In [1]:
from resnet1d import resnet1d_wang
import torch

In [2]:
import torch

# Generate a random tensor
random_input = torch.randn(2, 12, 5000)  # 1 sample, 12 channels, 5000 timesteps

# Instantiate the model with the correct input channels
model = resnet1d_wang(input_channels=12, num_classes=5)

# Get the model's output
output = model(random_input)
print(f"Output shape: {output.shape}")


Output shape: torch.Size([2, 5])


In [3]:
output

tensor([[ 0.3682,  0.2865, -0.5881, -0.2014,  0.8039],
        [ 0.0062,  0.0734,  0.2114,  0.4198, -0.3297]],
       grad_fn=<AddmmBackward0>)

In [14]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import wandb
import copy
import torch.nn as nn
import torch.nn.functional as F
from dataset_class_for_superclasses import ECGImageDataset
import numpy as np
from tqdm import tqdm

In [19]:
def train_model(model, criterion, optimizer, dataloaders, dataset_sizes, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            all_labels = []
            all_preds = []
            all_scores = []

            for inputs, labels in tqdm(dataloaders[phase], desc=f'{phase} phase', leave=False):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
                all_scores.extend(F.softmax(outputs, dim=1).cpu().detach().numpy())

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

            auroc = roc_auc_score(all_labels, all_scores, multi_class='ovr', average='macro')
            f1 = f1_score(all_labels, all_preds, average='macro')
            accuracy = accuracy_score(all_labels, all_preds)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} AUROC: {auroc:.4f} F1: {f1:.4f}')

    model.load_state_dict(best_model_wts)
    return model

In [20]:
# Define the custom flatten transform
class FlattenTransform:
    def __call__(self, sample):
        # Assuming sample is a Tensor with shape [C, H, W]
        return sample.view(1, -1)

# Define the data transformations including the custom flatten transform
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224
    transforms.ToTensor(),           # Convert image to Tensor
    FlattenTransform()              # Flatten the tensor
])

train_dataset = ECGImageDataset(info_df_path='train-100HZ-files-and-labels.csv', transform=data_transform)
val_dataset = ECGImageDataset(info_df_path='test-100HZ-files-and-labels.csv', transform=data_transform)

# Function to create a subset of the dataset
def create_subset(dataset, num_examples):
    # Ensure num_examples doesn't exceed the dataset length
    num_examples = min(len(dataset), num_examples)
    indices = np.random.choice(len(dataset), num_examples, replace=False)
    subset = torch.utils.data.Subset(dataset, indices)
    return subset
    
num_train_examples = 100
num_test_examples = 100
if num_train_examples is not None:
    train_subset = create_subset(train_dataset, num_train_examples)
else: 
    train_subset = train_dataset

if num_test_examples is not None:
    val_subset = create_subset(val_dataset, num_test_examples)
else:
    val_subset = val_dataset
    
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)

dataloaders = {
    'train': train_loader,
    'val': val_loader
}

# Initialize model, criterion, optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet1d_wang(num_classes=5, input_channels=1).to(device)  # Adjust num_classes as per your dataset

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Define dataset sizes
dataset_sizes = {
    'train': len(train_dataset),
    'val': len(val_dataset)
}

# Train the model
model = train_model(model, criterion, optimizer, dataloaders, dataset_sizes, num_epochs=25)

Epoch 0/24
----------


                                                                                                                        

train Loss: 0.0144 Acc: 0.0014 AUROC: 0.4939 F1: 0.1321


                                                                                                                        

val Loss: 0.3695 Acc: 0.0167 AUROC: 0.5004 F1: 0.0774
Epoch 1/24
----------


                                                                                                                        

train Loss: 0.0125 Acc: 0.0028 AUROC: 0.6679 F1: 0.2563


                                                                                                                        

val Loss: 0.8895 Acc: 0.0167 AUROC: 0.5148 F1: 0.0774
Epoch 2/24
----------


                                                                                                                        

train Loss: 0.0126 Acc: 0.0031 AUROC: 0.6390 F1: 0.2874


                                                                                                                        

val Loss: 1.4576 Acc: 0.0167 AUROC: 0.4582 F1: 0.0774
Epoch 3/24
----------


                                                                                                                        

train Loss: 0.0123 Acc: 0.0039 AUROC: 0.6717 F1: 0.3246


                                                                                                                        

val Loss: 1.8228 Acc: 0.0167 AUROC: 0.5503 F1: 0.0774
Epoch 4/24
----------


                                                                                                                        

train Loss: 0.0124 Acc: 0.0040 AUROC: 0.6439 F1: 0.2902


                                                                                                                        

val Loss: 2.0987 Acc: 0.0167 AUROC: 0.4513 F1: 0.0774
Epoch 5/24
----------


                                                                                                                        

train Loss: 0.0115 Acc: 0.0044 AUROC: 0.7198 F1: 0.3987


                                                                                                                        

val Loss: 2.7987 Acc: 0.0167 AUROC: 0.4313 F1: 0.0774
Epoch 6/24
----------


                                                                                                                        

train Loss: 0.0113 Acc: 0.0044 AUROC: 0.7429 F1: 0.4199


                                                                                                                        

val Loss: 3.9258 Acc: 0.0167 AUROC: 0.4652 F1: 0.0774
Epoch 7/24
----------


                                                                                                                        

train Loss: 0.0109 Acc: 0.0048 AUROC: 0.7556 F1: 0.4465


                                                                                                                        

val Loss: 5.3427 Acc: 0.0167 AUROC: 0.4154 F1: 0.0774
Epoch 8/24
----------


                                                                                                                        

train Loss: 0.0101 Acc: 0.0052 AUROC: 0.8177 F1: 0.4750


                                                                                                                        

val Loss: 6.7077 Acc: 0.0167 AUROC: 0.4378 F1: 0.0774
Epoch 9/24
----------


                                                                                                                        

train Loss: 0.0109 Acc: 0.0040 AUROC: 0.7491 F1: 0.2813


                                                                                                                        

val Loss: 8.0160 Acc: 0.0167 AUROC: 0.4796 F1: 0.0774
Epoch 10/24
----------


                                                                                                                        

train Loss: 0.0112 Acc: 0.0051 AUROC: 0.7040 F1: 0.4544


                                                                                                                        

val Loss: 9.4585 Acc: 0.0167 AUROC: 0.5510 F1: 0.0774
Epoch 11/24
----------


                                                                                                                        

train Loss: 0.0105 Acc: 0.0048 AUROC: 0.7565 F1: 0.4313


                                                                                                                        

val Loss: 10.2965 Acc: 0.0167 AUROC: 0.5184 F1: 0.0774
Epoch 12/24
----------


                                                                                                                        

train Loss: 0.0105 Acc: 0.0051 AUROC: 0.7836 F1: 0.4294


                                                                                                                        

val Loss: 10.0957 Acc: 0.0167 AUROC: 0.5071 F1: 0.0774
Epoch 13/24
----------


                                                                                                                        

train Loss: 0.0107 Acc: 0.0047 AUROC: 0.7535 F1: 0.3700


                                                                                                                        

val Loss: 8.4480 Acc: 0.0167 AUROC: 0.5020 F1: 0.0774
Epoch 14/24
----------


                                                                                                                        

train Loss: 0.0104 Acc: 0.0047 AUROC: 0.7687 F1: 0.3635


                                                                                                                        

val Loss: 8.3531 Acc: 0.0167 AUROC: 0.4602 F1: 0.0774
Epoch 15/24
----------


                                                                                                                        

train Loss: 0.0101 Acc: 0.0049 AUROC: 0.7875 F1: 0.4315


                                                                                                                        

val Loss: 9.4591 Acc: 0.0167 AUROC: 0.4903 F1: 0.0774
Epoch 16/24
----------


                                                                                                                        

train Loss: 0.0100 Acc: 0.0052 AUROC: 0.7877 F1: 0.4192


                                                                                                                        

val Loss: 11.0607 Acc: 0.0167 AUROC: 0.5000 F1: 0.0774
Epoch 17/24
----------


                                                                                                                        

train Loss: 0.0098 Acc: 0.0051 AUROC: 0.7812 F1: 0.4429


                                                                                                                        

val Loss: 13.9700 Acc: 0.0167 AUROC: 0.5000 F1: 0.0774
Epoch 18/24
----------


                                                                                                                        

train Loss: 0.0102 Acc: 0.0052 AUROC: 0.6884 F1: 0.4470


                                                                                                                        

val Loss: 16.1189 Acc: 0.0167 AUROC: 0.5000 F1: 0.0774
Epoch 19/24
----------


                                                                                                                        

train Loss: 0.0099 Acc: 0.0053 AUROC: 0.7283 F1: 0.4118


                                                                                                                        

val Loss: 14.0908 Acc: 0.0167 AUROC: 0.5219 F1: 0.0774
Epoch 20/24
----------


                                                                                                                        

train Loss: 0.0094 Acc: 0.0055 AUROC: 0.7231 F1: 0.4671


                                                                                                                        

val Loss: 14.0315 Acc: 0.0167 AUROC: 0.4969 F1: 0.0774
Epoch 21/24
----------


                                                                                                                        

train Loss: 0.0102 Acc: 0.0054 AUROC: 0.6860 F1: 0.3981


                                                                                                                        

val Loss: 17.1565 Acc: 0.0167 AUROC: 0.4827 F1: 0.0774
Epoch 22/24
----------


                                                                                                                        

train Loss: 0.0096 Acc: 0.0052 AUROC: 0.7872 F1: 0.4398


                                                                                                                        

val Loss: 25.4911 Acc: 0.0167 AUROC: 0.5000 F1: 0.0774
Epoch 23/24
----------


                                                                                                                        

train Loss: 0.0091 Acc: 0.0054 AUROC: 0.7500 F1: 0.4184


                                                                                                                        

val Loss: 23.4217 Acc: 0.0167 AUROC: 0.5000 F1: 0.0774
Epoch 24/24
----------


                                                                                                                        

train Loss: 0.0095 Acc: 0.0048 AUROC: 0.7263 F1: 0.2966


                                                                                                                        

val Loss: 18.9434 Acc: 0.0167 AUROC: 0.5000 F1: 0.0774




In [17]:
for ex, lab in train_loader:
    print(ex.shape, lab.shape)
    output = model(ex)
    print(output.shape)
    softmax = F.softmax(output, dim=1)
    auroc = roc_auc_score(lab, softmax.detach().numpy(), multi_class='ovr', average='macro')
    print(auroc)
    break

torch.Size([32, 1, 150528]) torch.Size([32])
torch.Size([32, 5])
0.4262988545294048
