In [1]:
%run preprocess_functions.ipynb
%run data_augmentation.ipynb

In [2]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt
from torch import optim
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score, classification_report, accuracy_score, precision_recall_curve, auc, roc_auc_score, roc_curve, multilabel_confusion_matrix

In [3]:
# Folder holding the single-eye images; passed to CustomDataset as the image root.
DIRECTORY_PATH = "one_eye_images"
# Not referenced in the visible cells; presumably consumed by the augmentation
# notebook executed above -- TODO confirm.
AUGMENTED_DIRECTORY_PATH = "original_and_augmented_images"

# Model for single-eye input

Dataset class

In [4]:
class CustomDataset(Dataset):
    """Dataset yielding ``(eye_image, metadata, diagnosis)`` for one eye image.

    Args:
        dataset_X: table indexed positionally via ``.iloc``; each row must expose
            ``'image_id'``, ``'patient_age'`` and ``'patient_sex'``.
        dataset_y: table indexed positionally via ``.iloc``; each row must expose
            the six columns listed in ``LABEL_COLUMNS``.
        image_directory_path: directory that ``image_id`` is joined onto.
        transfrom: optional callable applied to the PIL image. The misspelled
            name is kept so existing keyword callers keep working. When it is
            falsy a default pipeline is built from the two flags below.
        include_preprocess_function: if True the default pipeline starts with the
            external ``preprocess`` function instead of ``Resize((224, 224))``.
        mode: ``'train'`` adds random flip/rotation to the preprocess pipeline;
            any other value builds the evaluation pipeline.
    """

    # Order of the entries in the returned diagnosis vector.
    LABEL_COLUMNS = ['diabetic_retinopathy', 'amd', 'hypertensive_retinopathy',
                     'normal_eye', 'glaucoma', 'cataract']
    # ImageNet channel statistics, matching the pretrained backbone; can be
    # replaced by dataset statistics.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, dataset_X, dataset_y, image_directory_path, transfrom=None, include_preprocess_function=False, mode='train'):
        self.X = dataset_X
        self.y = dataset_y
        self.image_directory_path = image_directory_path
        # The default pipeline is only built when no usable transform was supplied.
        self.transform = transfrom or self._build_default_transform(mode, include_preprocess_function)

    @classmethod
    def _build_default_transform(cls, mode, include_preprocess_function):
        """Assemble the default torchvision pipeline for the given flags."""
        if include_preprocess_function:
            # `preprocess` is not defined in this class; it is expected to come from
            # the preprocessing notebook executed via %run -- TODO confirm that it
            # returns an array PIL can wrap at the size the network expects.
            steps = [transforms.Lambda(lambda img: Image.fromarray(preprocess(np.array(img))))]
            if mode == 'train':
                # Augmentation is applied only on the training + preprocess path.
                steps += [
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(degrees=50),
                ]
        else:
            steps = [transforms.Resize((224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=cls.IMAGENET_MEAN, std=cls.IMAGENET_STD),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        """Number of samples, i.e. the length of the feature table."""
        return len(self.X)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(eye_image, metadata, diagnosis)``.

        ``metadata`` is a float32 tensor ``[patient_age, patient_sex]`` and
        ``diagnosis`` is a float32 tensor ordered like ``LABEL_COLUMNS``.
        """
        row = self.X.iloc[index]
        image_path = os.path.join(self.image_directory_path, row['image_id'])
        # The context manager releases the file handle deterministically, and
        # convert('RGB') guarantees the 3 channels that Normalize expects.
        with Image.open(image_path) as image_file:
            eye_image = image_file.convert('RGB')
        if self.transform:
            eye_image = self.transform(eye_image)

        diagnosis = self.y.iloc[index][self.LABEL_COLUMNS].to_numpy(dtype=np.float32)
        metadata = np.array([row['patient_age'], row['patient_sex']])

        return (
            eye_image,
            torch.tensor(metadata, dtype=torch.float32),
            torch.tensor(diagnosis, dtype=torch.float32),
        )

Model class

1. VGG16

In [5]:
class NetworkVGG(nn.Module):
    """Two-branch classifier: frozen VGG16 image features plus a small metadata MLP.

    The image branch is a pretrained VGG16 with its last child module removed and
    all parameters frozen. The metadata branch maps 2 input values to 64
    features. Both are concatenated and fed to a fully connected head that
    outputs 6 raw scores (no softmax/sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()

        # Freeze the pretrained weights, then keep every child except the last
        # one (the original classifier head).
        pretrained_vgg = models.vgg16(weights='VGG16_Weights.DEFAULT').requires_grad_(False)
        backbone_stages = list(pretrained_vgg.children())[:-1]
        self.eye_input = nn.Sequential(*backbone_stages)

        # Metadata branch: 2 -> 1024 -> 64.
        self.metadata_input = nn.Sequential(
            nn.Linear(2, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 64),
            nn.ReLU(inplace=True),
        )

        # Head over the concatenation of 25088 flattened image features and the
        # 64 metadata features.
        self.classifier = nn.Sequential(
            nn.Linear(25088 + 64, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 6),
        )

    def forward(self, eye_image, metadata):
        """Return the 6 class scores for a batch of images and their metadata."""
        image_features = torch.flatten(self.eye_input(eye_image), 1)
        metadata_features = self.metadata_input(metadata)
        fused = torch.cat((image_features, metadata_features), dim=1)
        return self.classifier(fused)

Training and evaluation functions

In [6]:
def train(dataloader, model, loss_function, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Each batch is an ``(image, metadata, diagnosis)`` triple; the model is called
    as ``model(image, metadata)`` and the loss as
    ``loss_function(prediction, diagnosis)``. Progress is printed on every second
    batch.

    Returns:
        The mean of the per-batch losses as a Python float.
    """
    size = len(dataloader.dataset)
    model.train()
    batch_losses = []
    for step, (image, metadata, diagnosis) in enumerate(dataloader):
        prediction = model(image, metadata)
        loss = loss_function(prediction, diagnosis)
        loss_value = loss.item()
        batch_losses.append(loss_value)

        # Standard backward pass: clear old gradients, backpropagate, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 2 == 0:
            # Approximate sample counter: batch index times current batch size.
            current = step * len(diagnosis)
            print(f"loss: {loss_value:>7f}, [{current:>5d}/{size:>5d}]")
    return sum(batch_losses) / len(dataloader)

In [7]:
def test_and_validation(dataloader, model, loss_function):
    val_loss = 0.
    model.eval()
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for (image, metadata, diagnosis) in dataloader: #(image, metadata, diagnosis)
            prediction = model(image, metadata) #, metadata
            batch_val_loss = loss_function(prediction, diagnosis)
            val_loss += batch_val_loss
            preds = torch.argmax(prediction, dim=1)
            all_predictions.extend(preds.cpu().numpy())
            all_labels.extend(torch.argmax(diagnosis, dim=1).cpu().numpy())
    avg_loss = val_loss / len(dataloader)
    f1 = f1_score(all_labels, all_predictions, average='macro')
    print(f"validation_loss: {avg_loss:>7f}, f1_score: {f1:>7f}")
    return avg_loss, f1

In [44]:
# def test(dataloader, model, loss_function):
#     num_batches = len(dataloader)
#     test_loss = 0.
#     model.eval()
#     with torch.no_grad():
#         for (image, metadata, diagnosis) in dataloader:
#             prediction = model(image, metadata)
#             batch_test_loss = loss_function(prediction, diagnosis)
#             test_loss += batch_test_loss.item()
#             print(test_loss)
#     return test_loss / num_batches

In [8]:
data = pd.read_csv("final_one_eye_dataset.csv")

# Target columns; one indicator column per diagnosis.
classes = ['diabetic_retinopathy', 'amd', 'hypertensive_retinopathy', 'normal_eye', 'glaucoma', 'cataract']
feature_columns = ['image_id', 'patient_age', 'patient_sex']
X = data[feature_columns]
y = data[classes]

# 60/20/20 split: hold out 20% for testing, then take 25% of the remaining 80%
# as the validation set. Both calls use the same fixed seed.
# NOTE(review): the split is not stratified -- consider it given the class imbalance.
X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.25, random_state=42)


In [9]:
# Count how many rows share each combination of label values (class balance check).
data[classes].value_counts()

diabetic_retinopathy  amd  hypertensive_retinopathy  normal_eye  glaucoma  cataract
0                     0    0                         1           0         0           3094
1                     0    0                         0           0         0           2376
0                     1    0                         0           0         0            468
                      0    0                         0           0         1            292
                           1                         0           0         0            277
                           0                         0           1         0            268
Name: count, dtype: int64

datasets and dataloaders

In [10]:
age_scaler = StandardScaler()

# Work on explicit copies: the split frames are slices of `data`, and writing a
# column into them directly triggers pandas' SettingWithCopyWarning.
X_train = X_train.copy()
X_val = X_val.copy()
X_test = X_test.copy()

# Fit the scaler on the training ages only, then reuse it for the other splits
# so no validation/test statistics leak into training.
X_train['patient_age'] = age_scaler.fit_transform(X_train[['patient_age']])
X_val['patient_age'] = age_scaler.transform(X_val[['patient_age']])
# Bug fix: the test ages were previously overwritten with the scaled *validation*
# ages (X_val was passed here instead of X_test).
X_test['patient_age'] = age_scaler.transform(X_test[['patient_age']])

train_dataset = CustomDataset(X_train, y_train, DIRECTORY_PATH, include_preprocess_function=True)
validation_dataset = CustomDataset(X_val, y_val, DIRECTORY_PATH, mode='test', include_preprocess_function=True)
test_dataset = CustomDataset(X_test, y_test, DIRECTORY_PATH, mode='test', include_preprocess_function=True)

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Evaluation loaders are not shuffled: batch composition (and therefore the
# mean of the per-batch weighted losses) stays the same from run to run.
validation_dataloader = DataLoader(validation_dataset, batch_size=128, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Compute "balanced" class weights from the index of the active label column
# in each training row.
y_train_indices = np.argmax(y_train.to_numpy(), axis=1)
class_weights = compute_class_weight('balanced', classes=np.unique(y_train_indices), y=y_train_indices)
class_weights = torch.tensor(class_weights, dtype=torch.float32)
# criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss
    #criterion = nn.BCEWithLogitsLoss() # Binary Cross-Entropy Loss

In [11]:
# Hyper-parameters for the VGG16-based model.
epochs_vgg16 = 5
learning_rate_vgg = 0.001

model_vgg = NetworkVGG()
# Weighted cross-entropy using the class weights computed from the training labels.
loss_function = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model_vgg.parameters(), lr=learning_rate_vgg)

# Per-epoch history; reset every time this cell is run.
train_losses, val_losses, f1_scores = [], [], []

for t in range(epochs_vgg16):
    print(f"Epoch {t+1}\n-------------------------------")
    epoch_train_loss = train(train_dataloader, model_vgg, loss_function, optimizer)
    epoch_val_loss, f1_score_val = test_and_validation(validation_dataloader, model_vgg, loss_function)
    # History is recorded only after both phases of the epoch have finished.
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    f1_scores.append(f1_score_val)
print("Done!")

Epoch 1
-------------------------------
loss: 1.845412, [    0/ 4065]
loss: 6.779254, [  256/ 4065]
loss: 5.933875, [  512/ 4065]
loss: 3.051322, [  768/ 4065]
loss: 2.319752, [ 1024/ 4065]
loss: 1.700814, [ 1280/ 4065]
loss: 1.901054, [ 1536/ 4065]
loss: 1.310896, [ 1792/ 4065]
loss: 1.481751, [ 2048/ 4065]
loss: 1.581696, [ 2304/ 4065]
loss: 1.563545, [ 2560/ 4065]
loss: 1.536388, [ 2816/ 4065]
loss: 1.741104, [ 3072/ 4065]
loss: 1.448448, [ 3328/ 4065]
loss: 1.248655, [ 3584/ 4065]
loss: 1.562592, [ 3840/ 4065]
validation_loss: 1.299685, f1_score: 0.304625
Epoch 2
-------------------------------
loss: 1.394958, [    0/ 4065]
loss: 1.868749, [  256/ 4065]
loss: 1.238630, [  512/ 4065]
loss: 1.134365, [  768/ 4065]
loss: 1.238057, [ 1024/ 4065]
loss: 1.540298, [ 1280/ 4065]
loss: 1.211996, [ 1536/ 4065]
loss: 1.449643, [ 1792/ 4065]
loss: 1.260749, [ 2048/ 4065]
loss: 1.168565, [ 2304/ 4065]
loss: 1.105432, [ 2560/ 4065]
loss: 1.121161, [ 2816/ 4065]
loss: 1.073504, [ 3072/ 4065]
loss

KeyboardInterrupt: 

In [16]:
def sas(a, b, c):
    """Append ``1`` to ``c`` in place and return the string ``"wer"``.

    ``a`` and ``b`` are accepted but not used.
    """
    c.append(1)
    return "wer"


# Scratch check: the list passed in is mutated by the call.
a, b, c = 1, 2, []
sas(a, b, c)
c

[1]

In [36]:
# Persist only the model's state_dict (not the whole module object) to disk.
torch.save(model_vgg.state_dict(), 'vgg_model.pth')

In [37]:
# Show the per-epoch training losses collected by the training loop.
train_losses

[1.3695705653576369,
 0.7781428319684575,
 0.5923326380467147,
 0.4782548225327824,
 0.4305681502551175,
 0.3712513761573963,
 0.3401051274176394,
 0.3178074627779843,
 0.29567769186550313,
 0.27077615076906225,
 0.2396045196592138]

In [38]:
# Show the per-epoch validation losses collected by the training loop.
val_losses

[tensor(1.1102),
 tensor(1.0592),
 tensor(1.0453),
 tensor(1.0799),
 tensor(1.0433),
 tensor(1.2642),
 tensor(1.0555),
 tensor(1.1894),
 tensor(1.3154),
 tensor(1.2449),
 tensor(1.2217)]

In [50]:
import pickle

# Pickle the loss histories and the fitted age scaler, one file per object.
pickle_targets = (
    ('train_losses.pkl', train_losses),
    ('val_losses.pkl', val_losses),
    ('scaler.pkl', age_scaler),
)
for pickle_path, payload in pickle_targets:
    with open(pickle_path, 'wb') as file:
        pickle.dump(payload, file)

# torch.save serialises the DataLoader objects themselves, including their datasets.
# NOTE(review): the default dataset transforms wrap a lambda, which the standard
# pickler usually cannot serialise -- confirm this works with the current datasets.
torch.save(train_dataloader, 'train_dataloader.pth')
torch.save(validation_dataloader, 'validation_dataloader.pth')
torch.save(test_dataloader, 'test_dataloader.pth')
torch.save(optimizer.state_dict(), 'optimizer.pth')