### Model

In [1]:
from vgg_features import vgg11_features, vgg13_features, vgg16_features, vgg19_features
import torch.nn as nn
import torch

In [2]:
class VGG_Classifier(nn.Module):
    def __init__(self, model_name, num_classes=5):
        super(VGG_Classifier, self).__init__()
        self.backend_model_bandwidth = 25088  # for vgg family of models
        self.model_loader = self.select_model_loader(model_name)
        self.model = self.model_loader(True)
        self.classifier_head = nn.Linear(self.backend_model_bandwidth, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.model(x)
        x = x.reshape(batch_size, -1)
        x = self.classifier_head(x)
        return x      

    def select_model_loader(self, model_name):
        if model_name == 'vgg11':
            model_loader = vgg11_features
        elif model_name == 'vgg13':
            model_loader = vgg13_features
        elif model_name == 'vgg16':
            model_loader = vgg16_features
        elif model_name == 'vgg19':
            model_loader = vgg19_features

        return model_loader

In [3]:
v11 = vgg11_features('vgg11')
v13 = vgg13_features('vgg13')
v16 = vgg16_features('vgg16')
v19 = vgg19_features('vgg19')

In [4]:
input = torch.randn(3, 3, 224, 224)

In [5]:
v11(input).view(3, -1).shape

torch.Size([3, 25088])

In [6]:
v13(input).view(3, -1).shape

torch.Size([3, 25088])

In [7]:
v16(input).view(3, -1).shape

torch.Size([3, 25088])

In [8]:
v19(input).view(3, -1).shape

torch.Size([3, 25088])

In [9]:
VGG_Classifier(model_name='vgg11', num_classes=5)(input).view(3, -1).shape

torch.Size([3, 5])

In [10]:
model = VGG_Classifier(model_name='vgg11', num_classes=5)

### Dataset Class

In [11]:
import os
import pandas as pd
import numpy as np
import ast
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms

class ECGImageDataset(Dataset):
    def __init__(self, info_df_path, transform=None):
        self.info_df = pd.read_csv(info_df_path)
        self.transform = transform

    def __len__(self):
        return len(self.info_df)

    def __getitem__(self, idx):
        img_path = self.info_df.iloc[idx]['Image Path']
        image = Image.open(img_path).convert('RGB')
        label = self.info_df.iloc[idx]['Label']

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label).long()

In [12]:
# Define transformations
img_size = 224  # or whatever size you want
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    normalize,
])

In [13]:
num_train_examples = 100
num_test_examples = 100
train_df = 'train-100HZ-files-and-labels.csv'
val_df = 'test-100HZ-files-and-labels.csv'

# Function to create a subset of the dataset
def create_subset(dataset, num_examples):
    # Ensure num_examples doesn't exceed the dataset length
    num_examples = min(len(dataset), num_examples)
    indices = np.random.choice(len(dataset), num_examples, replace=False)
    subset = torch.utils.data.Subset(dataset, indices)
    return subset

# Create train and test datasets
train_dataset = ECGImageDataset(train_df, transform=transform)
val_dataset = ECGImageDataset(val_df, transform=transform)

if num_train_examples is not None:
    train_subset = create_subset(train_dataset, num_train_examples)
else: 
    train_subset = train_dataset

if num_test_examples is not None:
    val_subset = create_subset(val_dataset, num_test_examples)
else:
    val_subset = val_dataset

# Create data loaders for the subsets
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=4, pin_memory=False)
val_loader = torch.utils.data.DataLoader(val_subset, batch_size=32, shuffle=True, num_workers=4, pin_memory=False)

In [14]:
for ex, lab in train_loader:
    print(ex.shape, lab.shape)
    output = model(ex)
    print(output.shape)
    # softmax = F.softmax(output, dim=1)
    # auroc = roc_auc_score(lab, softmax.detach().numpy(), multi_class='ovr', average='macro')
    # print(auroc)
    break

torch.Size([32, 3, 224, 224]) torch.Size([32])
torch.Size([32, 5])


In [16]:
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
from tqdm import tqdm

def train_model(model, criterion, optimizer, dataloaders, dataset_sizes, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            all_labels = []
            all_preds = []
            all_scores = []

            for inputs, labels in tqdm(dataloaders[phase], desc=f'{phase} phase', leave=False):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
                all_scores.extend(F.softmax(outputs, dim=1).cpu().detach().numpy())

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

            auroc = roc_auc_score(all_labels, all_scores, multi_class='ovr', average='macro')
            f1 = f1_score(all_labels, all_preds, average='macro')
            accuracy = accuracy_score(all_labels, all_preds)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} AUROC: {auroc:.4f} F1: {f1:.4f}')

    model.load_state_dict(best_model_wts)
    return model

dataloaders = {
    'train': train_loader,
    'val': val_loader
}

# Initialize model, criterion, optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VGG_Classifier('vgg11', num_classes=5).to(device)  # Adjust num_classes as per your dataset

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Define dataset sizes
dataset_sizes = {
    'train': len(train_dataset),
    'val': len(val_dataset)
}

# Train the model
model = train_model(model, criterion, optimizer, dataloaders, dataset_sizes, num_epochs=25)

                                                                                

train Loss: 0.0247 Acc: 0.0039 AUROC: 0.6294 F1: 0.1866


                                                                                

val Loss: 0.1030 Acc: 0.0313 AUROC: 0.3851 F1: 0.1250


                                                                                

train Loss: 0.0149 Acc: 0.0035 AUROC: 0.3954 F1: 0.1622


                                                                                

val Loss: 0.0952 Acc: 0.0334 AUROC: 0.4242 F1: 0.1297


                                                                                

train Loss: 0.0130 Acc: 0.0048 AUROC: 0.3857 F1: 0.1468


                                                                                

val Loss: 0.0958 Acc: 0.0334 AUROC: 0.3986 F1: 0.1297


                                                                                

train Loss: 0.0100 Acc: 0.0050 AUROC: 0.5070 F1: 0.1516


                                                                                

val Loss: 0.1808 Acc: 0.0334 AUROC: 0.4541 F1: 0.1297


                                                                                

train Loss: 0.0115 Acc: 0.0050 AUROC: 0.5698 F1: 0.1516


                                                                                

val Loss: 0.0947 Acc: 0.0334 AUROC: 0.4496 F1: 0.1297


                                                                                

train Loss: 0.0098 Acc: 0.0050 AUROC: 0.5364 F1: 0.1516


                                                                                

val Loss: 0.1179 Acc: 0.0334 AUROC: 0.4464 F1: 0.1297


                                                                                

train Loss: 0.0104 Acc: 0.0050 AUROC: 0.5061 F1: 0.1516


                                                                                

val Loss: 0.0951 Acc: 0.0334 AUROC: 0.4761 F1: 0.1297


                                                                                

train Loss: 0.0103 Acc: 0.0050 AUROC: 0.4111 F1: 0.1516


                                                                                

val Loss: 0.0969 Acc: 0.0334 AUROC: 0.4707 F1: 0.1297


                                                                                

train Loss: 0.0102 Acc: 0.0050 AUROC: 0.4078 F1: 0.1516


                                                                                

val Loss: 0.0974 Acc: 0.0334 AUROC: 0.4742 F1: 0.1297


                                                                                

train Loss: 0.0094 Acc: 0.0050 AUROC: 0.5359 F1: 0.1516


                                                                                

val Loss: 0.0967 Acc: 0.0334 AUROC: 0.4856 F1: 0.1297


                                                                                

train Loss: 0.0096 Acc: 0.0050 AUROC: 0.5135 F1: 0.1516


                                                                                

val Loss: 0.0969 Acc: 0.0334 AUROC: 0.4846 F1: 0.1297


                                                                                

train Loss: 0.0095 Acc: 0.0050 AUROC: 0.5757 F1: 0.1516


                                                                                

val Loss: 0.0930 Acc: 0.0334 AUROC: 0.4955 F1: 0.1297


                                                                                

train Loss: 0.0100 Acc: 0.0050 AUROC: 0.6089 F1: 0.1516


                                                                                

val Loss: 0.0954 Acc: 0.0334 AUROC: 0.5107 F1: 0.1297


                                                                                

train Loss: 0.0099 Acc: 0.0050 AUROC: 0.6110 F1: 0.1516


                                                                                

val Loss: 0.1093 Acc: 0.0334 AUROC: 0.4807 F1: 0.1297


                                                                                

train Loss: 0.0102 Acc: 0.0050 AUROC: 0.6179 F1: 0.1516


                                                                                

val Loss: 0.1004 Acc: 0.0334 AUROC: 0.4716 F1: 0.1297


                                                                                

train Loss: 0.0094 Acc: 0.0050 AUROC: 0.5749 F1: 0.1516


                                                                                

val Loss: 0.0943 Acc: 0.0334 AUROC: 0.4797 F1: 0.1297


                                                                                

train Loss: 0.0095 Acc: 0.0050 AUROC: 0.6082 F1: 0.1516


                                                                                

val Loss: 0.1014 Acc: 0.0334 AUROC: 0.4731 F1: 0.1297


                                                                                

train Loss: 0.0095 Acc: 0.0050 AUROC: 0.5969 F1: 0.1516


                                                                                

val Loss: 0.0942 Acc: 0.0334 AUROC: 0.4954 F1: 0.1297


                                                                                

train Loss: 0.0092 Acc: 0.0050 AUROC: 0.6751 F1: 0.1516


                                                                                

val Loss: 0.0939 Acc: 0.0334 AUROC: 0.4812 F1: 0.1297


                                                                                

train Loss: 0.0094 Acc: 0.0050 AUROC: 0.6619 F1: 0.1516


                                                                                

val Loss: 0.0984 Acc: 0.0334 AUROC: 0.4723 F1: 0.1297


                                                                                

train Loss: 0.0089 Acc: 0.0050 AUROC: 0.6686 F1: 0.1516


                                                                                

val Loss: 0.1081 Acc: 0.0334 AUROC: 0.4862 F1: 0.1297


                                                                                

train Loss: 0.0088 Acc: 0.0049 AUROC: 0.6649 F1: 0.1500


                                                                                

val Loss: 0.1054 Acc: 0.0334 AUROC: 0.5040 F1: 0.1297


                                                                                

train Loss: 0.0088 Acc: 0.0050 AUROC: 0.6852 F1: 0.1516


                                                                                

val Loss: 0.0946 Acc: 0.0334 AUROC: 0.5038 F1: 0.1297


                                                                                

train Loss: 0.0088 Acc: 0.0050 AUROC: 0.7104 F1: 0.1516


                                                                                

val Loss: 0.1040 Acc: 0.0334 AUROC: 0.5042 F1: 0.1297


                                                                                

train Loss: 0.0084 Acc: 0.0052 AUROC: 0.7070 F1: 0.2165


                                                                                

val Loss: 0.1370 Acc: 0.0347 AUROC: 0.4952 F1: 0.1848


