In [None]:
# Mount Google Drive: the annotation CSVs and images are read from /content/drive/MyDrive/FML_Project/.
from google.colab import drive

drive.mount('/content/drive', force_remount=False)

Mounted at /content/drive


In [None]:
!pip install segmentation-models-pytorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.2-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting timm==0.6.12
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting munch
  Downloading munch-2.5.0-py2.py3

In [None]:
import torch
import math
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance
import argparse
import os
import copy
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from pprint import pprint
import segmentation_models_pytorch as smp
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import f1_score
import cv2
from skimage.feature import hog
import timm
import random
from torch.utils.data import WeightedRandomSampler

# import torch.multiprocessing as mp
# mp.set_start_method('spawn', force=True)

# Map each raw DRSS score found in the CSVs to one of three severity classes (0, 1, 2).
LABELS_Severity = {35: 0,
                   43: 0,
                   47: 1,
                   53: 1,
                   61: 2,
                   65: 2,
                   71: 2,
                   85: 2}

# Seed every RNG the pipeline draws from, so Restart & Run All is reproducible:
# - Python `random`: train/val visit split and sequence padding in OCTDataset.
# - NumPy.
# - torch: RandomRotation, weight init and WeightedRandomSampler.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Single-channel statistics written as 1-tuples. They broadcast over the 3 channels
# produced by Grayscale(num_output_channels=3). The old `(.1706)` was a plain float,
# not a tuple; it normalized identically.
# NOTE(review): these look like raw-image statistics, but OCTDataset feeds Canny edge
# maps (0/255 values), so they are probably not the mean/std of the actual inputs.
# Recompute on the edge maps to confirm.
mean = (.1706,)
std = (.2112,)
normalize = transforms.Normalize(mean=mean, std=std)

# Training pipeline: replicate the single-channel edge map to 3 channels, downsample
# to 64x64, and apply a small random rotation as augmentation.
# NOTE(review): resizing thin 1-pixel Canny edges to 64x64 may blur or erase much of
# the edge signal. Worth checking visually.
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(size=(64, 64)),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    normalize,
])

# Deterministic pipeline for validation/testing: same as training, minus augmentation.
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    normalize,
])

In [None]:
class OCTDataset(Dataset):
    """Visit-level OCT dataset: one item per (Patient_ID, Week_Num) visit.

    Annotation rows are grouped by ``(Patient_ID, Week_Num, Severity_Label)``. The label
    comes from mapping the row's ``DRSS`` value through ``LABELS_Severity``. Each item is
    the stacked, transformed Canny edge maps of every image in that visit, plus its
    severity label.

    Args:
        annot: Optional annotation source overriding the subset's default CSV.
            Either a CSV path or a ``pandas.DataFrame`` (copied, never mutated).
            ``None`` keeps the default Drive paths in ``_DEFAULT_CSV``.
        unique_pairs: Required for ``subset='val'``. The
            ``(patient_id, week_num, severity_label)`` tuples to serve, normally
            ``trainset.unique_validation_pairs``. Ignored for the other subsets.
        subset: ``'train'``, ``'val'`` or ``'test'``. ``'train'`` randomly keeps
            ``TRAIN_FRACTION`` of the visits and exposes the rest as
            ``self.unique_validation_pairs``.
        transform: Callable applied to each edge-map ``PIL.Image``; it must return a
            tensor. ``None`` falls back to ``transforms.ToTensor()``.
        device: Stored as ``self.device``; not used by this class.

    Raises:
        ValueError: unknown ``subset``, or ``subset='val'`` without ``unique_pairs``.
    """

    # Target number of images per visit; shorter visits are padded with random duplicates.
    SEQUENCE_LENGTH = 49
    # (threshold1, threshold2) passed to cv2.Canny.
    CANNY_THRESHOLDS = (200, 500)
    # Fraction of the training CSV's visits kept for training; the remainder is validation.
    TRAIN_FRACTION = 0.8
    # Default annotation CSV per subset; 'val' carves its visits out of the training CSV.
    _DEFAULT_CSV = {
        'train': "/content/drive/MyDrive/FML_Project/df_prime_train.csv",
        'val': "/content/drive/MyDrive/FML_Project/df_prime_train.csv",
        'test': "/content/drive/MyDrive/FML_Project/df_prime_test.csv",
    }

    def __init__(self, annot=None, unique_pairs=None, subset='train', transform=None, device='cpu'):
        if subset not in self._DEFAULT_CSV:
            raise ValueError(f"subset must be 'train', 'val' or 'test', got {subset!r}")
        if subset == 'val' and unique_pairs is None:
            raise ValueError("subset='val' requires unique_pairs (e.g. trainset.unique_validation_pairs)")

        if annot is None:
            annot = self._DEFAULT_CSV[subset]
        if isinstance(annot, pd.DataFrame):
            # Copy so adding the Severity_Label column never mutates the caller's frame.
            self.annot = annot.copy()
        else:
            self.annot = pd.read_csv(annot)

        self.patient_ids = self.annot["Patient_ID"]
        self.week_nums = self.annot["Week_Num"]
        # Plain dict lookup (not Series.map) so an unexpected DRSS value raises KeyError
        # instead of silently becoming NaN.
        self.annot['Severity_Label'] = [LABELS_Severity[drss] for drss in self.annot['DRSS'].values]
        self.drss_class = self.annot['Severity_Label']

        if subset == 'val':
            self.unique_pairs = list(unique_pairs)
        else:
            # dict.fromkeys de-duplicates while keeping first-seen CSV order. Unlike a
            # set, this order is reproducible (so a seeded shuffle is too). Storing a
            # list lets __getitem__ index it directly.
            all_pairs = list(dict.fromkeys(zip(self.patient_ids, self.week_nums, self.drss_class)))
            if subset == 'train':
                random.shuffle(all_pairs)
                split_index = int(self.TRAIN_FRACTION * len(all_pairs))
                # NOTE(review): the split is per visit, not per patient, so one patient's
                # visits can land in both train and validation (optimistic val scores).
                self.unique_pairs = all_pairs[:split_index]
                self.unique_validation_pairs = all_pairs[split_index:]
            else:
                self.unique_pairs = all_pairs

        self.root = os.path.expanduser("/content/drive/MyDrive/FML_Project/")
        self.transform = transform
        self.nb_classes = len(np.unique(list(LABELS_Severity.values())))
        self.path_list = self.annot['File_Path'].values

        # One label per visit, aligned index-for-index with self.unique_pairs.
        self._labels = [pair[2] for pair in self.unique_pairs]
        assert len(self.unique_pairs) == len(self._labels)

        self.max_samples = len(self._labels)
        self.device = device

    def __getitem__(self, index):
        """Return ``(image_sequence, target)`` for the visit at ``index``.

        Each image of the visit is converted to grayscale, run through Canny, then through
        ``self.transform``. The results are stacked along a new first dimension. Visits with
        fewer than ``SEQUENCE_LENGTH`` images are padded with randomly chosen duplicates.
        Longer visits are returned as-is, which would break batching with default collation.

        Raises:
            ValueError: if no annotation rows match the visit.
        """
        patient_id, week_num, target = self.unique_pairs[index]
        visit_rows = self.annot[(self.annot['Patient_ID'] == patient_id) & (self.annot['Week_Num'] == week_num)]
        file_paths = [self.root + file_path for file_path in visit_rows['File_Path'].values.tolist()]
        if not file_paths:
            raise ValueError(f"No images found for Patient_ID={patient_id!r}, Week_Num={week_num!r}")

        images = []
        for fp in file_paths:
            # The context manager closes the file; convert() has already loaded the pixels.
            with Image.open(fp) as img:
                images.append(img.convert('RGB'))

        # Pad short visits by duplicating random existing images (range() is empty when none are missing).
        for _ in range(self.SEQUENCE_LENGTH - len(images)):
            images.append(images[random.randint(0, len(images) - 1)])

        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        image_sequence = []
        for image in images:
            edges = cv2.Canny(np.array(image.convert("L")), *self.CANNY_THRESHOLDS)
            image_sequence.append(to_tensor(Image.fromarray(edges)))

        return torch.stack(image_sequence, dim=0), target

    def __len__(self):
        if self.max_samples is not None:
            return min(len(self._labels), self.max_samples)
        else:
            return len(self._labels)

#ConvLSTM
class ConvLSTMCell(nn.Module):
    """A single time step of a convolutional LSTM.

    One convolution over ``[input, h]`` (concatenated on the channel axis) produces all
    four gate pre-activations at once. ``padding = kernel_size // 2`` preserves the
    spatial size for odd kernel sizes.

    Args:
        input_channels: channels of the per-step input.
        hidden_channels: channels of the hidden and cell states.
        kernel_size: side length of the square convolution kernel.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size

        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, input, hidden):
        """Advance the cell by one step.

        Args:
            input: (N, input_channels, H, W) tensor for this step.
            hidden: ``(h, c)`` pair, each (N, hidden_channels, H, W).

        Returns:
            The updated ``(h, c)`` pair.
        """
        h_prev, c_prev = hidden
        pre_activations = self.conv(torch.cat([input, h_prev], dim=1))

        # Gate order along the channel axis: input, forget, candidate, output.
        i_pre, f_pre, g_pre, o_pre = torch.split(pre_activations, self.hidden_channels, dim=1)

        c_next = torch.sigmoid(f_pre) * c_prev + torch.sigmoid(i_pre) * torch.tanh(g_pre)
        h_next = torch.sigmoid(o_pre) * torch.tanh(c_next)
        return h_next, c_next

#define the neural network architecture
class OCTClassifier(torch.nn.Module):
    """ConvLSTM over an image sequence, then a U-Net pooled down to class logits.

    Every frame of the input sequence is fed through one ``ConvLSTMCell``. The final hidden
    state (``hidden_channels`` maps at input resolution) goes through an ``smp.Unet``.
    Its ``num_classes`` output maps are averaged spatially to give one logit per class.
    The logits are unnormalized, suitable for ``nn.CrossEntropyLoss``.

    Args:
        input_channels: channels per frame. The default 3 matches the
            Grayscale(num_output_channels=3) transforms.
        hidden_channels: ConvLSTM hidden/cell channels; also the U-Net's input channels.
        num_classes: number of output logits.
        kernel_size: ConvLSTM convolution kernel size.
        encoder_name: forwarded to ``smp.Unet``.
        encoder_weights: forwarded to ``smp.Unet``.
    """

    def __init__(self, input_channels=3, hidden_channels=64, num_classes=3, kernel_size=3,
                 encoder_name="resnet34", encoder_weights="imagenet"):
        super(OCTClassifier, self).__init__()
        self.conv_lstm = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
        self.unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=hidden_channels,
            classes=num_classes,
            activation=None,
        )
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Map a batch of image sequences to class logits.

        Args:
            x: tensor of shape (batch_size, sequence_length, channels, height, width).
                NOTE(review): smp U-Net encoders generally expect height/width divisible
                by 32 (64x64 in this notebook). Confirm before changing the resize.

        Returns:
            (batch_size, num_classes) tensor of logits.
        """
        batch_size, sequence_length, _, height, width = x.size()
        hidden_channels = self.conv_lstm.hidden_channels

        # Allocate the initial (h, c) directly on x's device and dtype, rather than on the CPU
        # followed by a copy. The hidden size is taken from the cell, so it cannot drift from the constructor.
        h = x.new_zeros(batch_size, hidden_channels, height, width)
        c = x.new_zeros(batch_size, hidden_channels, height, width)
        hidden_state = (h, c)

        for t in range(sequence_length):
            hidden_state = self.conv_lstm(x[:, t], hidden_state)

        logits_map = self.unet(hidden_state[0])
        return torch.flatten(self.avg_pool(logits_map), 1)

def create_oversampler(targets, num_samples=None, generator=None):
    """Build a WeightedRandomSampler that draws every class with equal expected frequency.

    Each index is weighted by ``1 / count(its class)``, so all classes carry the same total
    weight. Draws are with replacement.

    Args:
        targets: sequence of non-negative integer class labels, one per dataset index.
        num_samples: indices drawn per epoch; defaults to ``len(targets)``.
        generator: optional ``torch.Generator`` for reproducible draws.

    Returns:
        torch.utils.data.WeightedRandomSampler
    """
    targets = np.asarray(targets, dtype=np.int64)
    class_sample_counts = np.bincount(targets)
    # A class with zero count gets weight inf, but no index ever refers to it.
    class_weights = 1.0 / torch.tensor(class_sample_counts, dtype=torch.float)
    sample_weights = class_weights[torch.from_numpy(targets)]
    if num_samples is None:
        num_samples = len(targets)
    return WeightedRandomSampler(weights=sample_weights, num_samples=num_samples,
                                 replacement=True, generator=generator)

In [None]:
# Annotation CSVs and image root on the mounted Drive.
annot_train_prime = "/content/drive/MyDrive/FML_Project/df_prime_train.csv"
annot_test_prime = "/content/drive/MyDrive/FML_Project/df_prime_test.csv"
data_root = "/content/drive/MyDrive/FML_Project/"

# set up the device (GPU or CPU)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Found device:', device)

# The training subset performs the random visit split; the validation subset serves the
# held-out visits. Validation uses the deterministic test_transform (no RandomRotation),
# so its metrics are not perturbed by augmentation.
trainset = OCTDataset(annot=annot_train_prime, subset='train', transform=train_transform, device=device)
valset = OCTDataset(annot=annot_train_prime, subset='val', unique_pairs=trainset.unique_validation_pairs,
                    transform=test_transform, device=device)

# define the hyperparameters
batch_size = 8
learning_rate = 1e-4
num_epochs = 5

# Oversample minority classes so training batches are roughly class-balanced.
oversampler = create_oversampler(trainset._labels)
trainloader = DataLoader(trainset, batch_size=batch_size, sampler=oversampler, num_workers=4)
# Evaluation order does not affect the metrics; keep it fixed for reproducibility.
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False)
print('Train and validation loaders complete')

print(len(trainset), len(valset))

Found device: cuda:0
Train and Test loader complete
396 99




In [None]:
# initialize the model and optimizer
model = OCTClassifier().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
print('Model definition complete')

# Class-balanced loss weights: n_samples / (n_classes * count_c).
# minlength keeps one weight per class (matching the model's logits) even if a class is
# missing from the training split. np.maximum(..., 1) avoids dividing by zero in that case.
class_counts = np.bincount(trainset._labels, minlength=trainset.nb_classes)
print(class_counts)
total_samples = len(trainset)
class_weights = torch.FloatTensor(total_samples / (len(class_counts) * np.maximum(class_counts, 1))).to(device)
# NOTE(review): trainloader already oversamples minority classes (create_oversampler).
# Weighting the loss as well compensates for imbalance twice and can bias predictions
# toward the rarest class. Consider keeping only one of the two mechanisms.
criterion = nn.CrossEntropyLoss(weight=class_weights)
print(class_weights)

LOG_EVERY = 10  # batches between progress messages

# Track the weights with the best validation balanced accuracy; the test cell uses them.
best_val_bacc = -1.0
best_state = copy.deepcopy(model.state_dict())

# train the model
print('len of trainloader:' + str(len(trainloader)))
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        if i % LOG_EVERY == 0:
            print('Training start for train batch: ' + str(i))
        inputs = inputs.to(device)
        labels = labels.to(device).long()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    train_loss = running_loss / len(trainloader)

    # Validation pass on the held-out visits.
    model.eval()
    val_true, val_pred = [], []
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs.to(device))
            val_pred.extend(outputs.argmax(dim=1).cpu().numpy())
            val_true.extend(labels.numpy())
    val_bacc = balanced_accuracy_score(val_true, val_pred)

    print(f'Epoch {epoch + 1} | Train Loss: {train_loss:.3f} | Val Balanced Acc: {val_bacc:.3f}')
    if val_bacc > best_val_bacc:
        best_val_bacc = val_bacc
        best_state = copy.deepcopy(model.state_dict())

# Continue with the weights that scored best on validation.
model.load_state_dict(best_state)

Model definition complete
[125 192  79]
tensor([1.0560, 0.6875, 1.6709], device='cuda:0')
len of trainloader:50
Training start for train batch: 0


In [None]:
testset = OCTDataset(annot=annot_test_prime, subset='test', transform=test_transform, device=device)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
print(len(testloader))

# evaluate the model on the test set
model.eval()
true_labels = []
pred_labels = []
with torch.no_grad():  # gradients are not needed for inference
    for inputs, labels in testloader:
        outputs = model(inputs.to(device))
        # argmax over raw logits equals argmax over softmax probabilities.
        pred_labels.extend(outputs.argmax(dim=1).cpu().numpy())
        # Labels stay on the CPU; only the model inputs need the device.
        true_labels.extend(labels.numpy())

# Balanced accuracy averages per-class recall, so it is robust to the class imbalance.
balanced_accuracy = balanced_accuracy_score(true_labels, pred_labels)
f1 = f1_score(true_labels, pred_labels, average='weighted')

print('Balanced accuracy:', balanced_accuracy)
print('F1 score:', f1)

21
Testing start for batch: 0
Testing start for batch: 1
Testing start for batch: 2
Testing start for batch: 3
Testing start for batch: 4
Testing start for batch: 5
Testing start for batch: 6
Testing start for batch: 7
Testing start for batch: 8
Testing start for batch: 9
Testing start for batch: 10
Testing start for batch: 11
Testing start for batch: 12
Testing start for batch: 13
Testing start for batch: 14
Testing start for batch: 15
Testing start for batch: 16
Testing start for batch: 17
Testing start for batch: 18
Testing start for batch: 19
Testing start for batch: 20
Balanced accuracy: 0.3461538461538462
F1 score: 0.08493647488925812
