In [None]:
%run preprocess_functions.ipynb
%run data_augmentation.ipynb

In [2]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import cv2
import torch
import matplotlib.pyplot as plt
from torch import optim
from torch import nn
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score, classification_report, accuracy_score, precision_recall_curve, auc, roc_auc_score, roc_curve, multilabel_confusion_matrix

In [3]:
torch.cuda.is_available()

True

In [4]:
def crop_image_to_circle(image):

    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 
    lower = np.array([3,3,3])
    higher = np.array([254,254,254])
    mask=cv2.inRange(img, lower, higher)
    
    
    contours, _ = cv2.findContours(image= mask,
                                   mode=cv2.RETR_EXTERNAL,
                                   method=cv2.CHAIN_APPROX_NONE)

    all_contours = np.vstack(contours)
    x, y, w, h = cv2.boundingRect(all_contours)  
    cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 5)
    cropped_img = image[y:y + h + 1, x:x + w + 1]
    return cropped_img

def preprocess(image, size_x=224, size_y=224):
    # cropping
    final_image = crop_image_to_circle(image)
    
    # blurring for reducing noise
    final_image = cv2.bilateralFilter(final_image, d=5, sigmaColor=75, sigmaSpace=75)
    # CLACHE in green channel
    green_channel = final_image[:, :, 1]
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_green_channel = clahe.apply(green_channel)
    final_image[:, :, 1] = enhanced_green_channel
    
    final_image = cv2.resize(final_image, (size_x, size_y), interpolation=cv2.INTER_LINEAR)
    
    return final_image

In [8]:
DIRECTORY_PATH = '/kaggle/input/fundus-images-odir/one_eye_images_ODIR_only'
# AUGMENTED_DIRECTORY_PATH ="original_and_augmented_images"

In [6]:
def save_model(train_dataloader, validation_dataloader, test_dataloader, optimizer, model, train_losses, val_losses, epoch):
    with open(f'train_losses_{epoch}.pkl', 'wb') as file:
        pickle.dump(train_losses, file)
    with open(f'val_losses_{epoch}.pkl', 'wb') as file:
        pickle.dump(val_losses, file)

    torch.save(model.state_dict(), f'vgg_model_{epoch}.pth')
    torch.save(train_dataloader, f'train_dataloader_{epoch}.pth')
    torch.save(validation_dataloader, f'validation_dataloader_{epoch}.pth')
    torch.save(test_dataloader, f'test_dataloader_{epoch}.pth')
    torch.save(optimizer.state_dict(), f'optimizer_{epoch}.pth')

# model for one eye input

Dataset class

In [9]:
class CustomDataset(Dataset):
    def __init__(self, dataset_X, dataset_y, image_directory_path, transfrom=None, include_preprocess_function=False, mode='train'):
        self.X = dataset_X
        self.y = dataset_y
        self.image_directory_path = image_directory_path
        self.preprosecc_status = include_preprocess_function
        
        
        if mode == 'train':
            if self.preprosecc_status:
                self.transform = transfrom or transforms.Compose([
                    # transforms.RandomHorizontalFlip(),
                    # transforms.RandomRotation(degrees=90),
                    # transforms.Lambda(lambda img: Image.fromarray(preprocess(np.array(img)))),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomVerticalFlip(),
                    transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.2, hue=0),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) #values from image net, better for pretrained models, can be changed to dataset values
                ])
            else:
                self.transform = transfrom or transforms.Compose([
                    # transforms.Lambda(lambda img: Image.fromarray(crop_image_to_circle(np.array(img)))), #1111111
                    transforms.Resize((224, 224)),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomVerticalFlip(),
                    transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.2, hue=0),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) #values from image net, better for pretrained models, can be changed to dataset values
                ])
        else:
            if self.preprosecc_status:
                self.transform = transfrom or transforms.Compose([
                    # transforms.Lambda(lambda img: Image.fromarray(preprocess(np.array(img)))),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
            else:
                self.transform = transfrom or transforms.Compose([
                    # transforms.Lambda(lambda img: Image.fromarray(crop_image_to_circle(np.array(img)))),
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) #values from image net, better for pretrained models, can be changed to dataset values
                    ])
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index):
        eye_image = cv2.imread(os.path.join(self.image_directory_path, self.X.iloc[index]['image_id']))
        if self.preprosecc_status:
            preprocessed_image = Image.fromarray(preprocess((eye_image)))
        else:
            preprocessed_image = Image.fromarray(crop_image_to_circle(eye_image))
            
        diagnosis = self.y.iloc[index][['diabetic_retinopathy', 'amd', 'hypertensive_retinopathy', 
                                                'normal_eye', 'glaucoma', 'cataract']].to_numpy(dtype=np.float32)
        age = self.X.iloc[index]['patient_age']
        sex = self.X.iloc[index]['patient_sex']
        if self.transform:
           preprocessed_image = self.transform(preprocessed_image)
        
        data = {"eye_image": preprocessed_image,
                "diagnosis": torch.tensor(diagnosis, dtype=torch.float32),
                "metadata": torch.tensor(np.array([age, sex]), dtype=torch.float32)}
        
        return data['eye_image'], data['metadata'], data['diagnosis']

Model class

1. vgg 

In [11]:
class NetworkVGG(nn.Module):
    def __init__(self):
        super(NetworkVGG, self).__init__()
        
        self.eye_input = models.vgg16(weights='VGG16_Weights.DEFAULT').requires_grad_(False)
        self.eye_input = nn.Sequential(*list(self.eye_input.children())[:-1]) #removing the classifier layer 
 
        self.metadata_input = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(25088 + 64, 4096), #4096
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 6),
            # nn.Softmax(dim=1)
        )
        
    def forward(self, eye_image, metadata): 
        vgg = self.eye_input(eye_image)
        vgg = torch.flatten(vgg, 1)
        meta = self.metadata_input(metadata)
        concatenated = torch.cat((vgg, meta), dim=1)
        result = self.classifier(concatenated)
        return result

train function

In [32]:
def train(dataloader, model, loss_function, optimizer, device="cpu"):
    size = len(dataloader.dataset)
    running_loss = 0.
    all_predictions = []
    all_labels = []
    model.train()
    for batch, (image, metadata, diagnosis) in enumerate(dataloader): #(image, metadata, diagnosis)
        image = image.to(device)
        metadata = metadata.to(device)
        diagnosis = diagnosis.to(device)
        prediction = model(image, metadata)
        prediction = prediction.to(device)
        loss = loss_function(prediction, diagnosis)
        running_loss += loss.item()

        softmax = nn.Softmax(dim=1)
        probabilities = softmax(prediction)
        preds = torch.argmax(probabilities, dim=1)
        all_predictions.extend(preds.cpu().numpy())
        all_labels.extend(torch.argmax(diagnosis, dim=1).cpu().numpy())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 2 == 0:
            loss, current = loss.item(), (batch + 1) * len(diagnosis)
            print(f"loss: {loss:>7f}, [{current:>5d}/{size:>5d}]")
    f1 = f1_score(all_labels, all_predictions, average='macro')
    training_loss = running_loss / len(dataloader)
    print(f"training_loss: {training_loss:>7f}, f1_score: {f1:>7f}")
    return training_loss, f1

In [33]:
def validation(dataloader, model, loss_function, device="cpu"):
    val_loss = 0.
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for (image, metadata, diagnosis) in dataloader: #(image, metadata, diagnosis)
            image = image.to(device)
            metadata = metadata.to(device)
            diagnosis = diagnosis.to(device)
            prediction = model(image, metadata) #, metadata
            prediction = prediction.to(device)
            batch_val_loss = loss_function(prediction, diagnosis)
            val_loss += batch_val_loss
            softmax = nn.Softmax(dim=1)
            probabilities = softmax(prediction)
            preds = torch.argmax(probabilities, dim=1)
            all_predictions.extend(preds.cpu().numpy())
            all_labels.extend(torch.argmax(diagnosis, dim=1).cpu().numpy())
    validation_loss = val_loss / len(dataloader)
    f1 = f1_score(all_labels, all_predictions, average='macro')
    classification_overall = classification_report(all_labels, all_predictions)
    print(f"validation_loss: {validation_loss:>7f}, f1_score: {f1:>7f}")
    return validation_loss, f1, classification_overall

In [23]:
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [28]:
data = pd.read_csv("/kaggle/input/fundus-images-odir/odir_one_eye_data_only.csv")
classes = ['diabetic_retinopathy', 'amd', 'hypertensive_retinopathy', 'normal_eye', 'glaucoma', 'cataract']
X = data[['image_id', 'patient_age', 'patient_sex']]
y = data[classes]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_test, X_val, y_test, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)


In [16]:
data[classes].value_counts()

diabetic_retinopathy  amd  hypertensive_retinopathy  normal_eye  glaucoma  cataract
0                     0    0                         1           0         0           3094
1                     0    0                         0           0         0           1667
0                     0    0                         0           0         1            292
                                                                 1         0            268
                      1    0                         0           0         0            242
                      0    1                         0           0         0            110
Name: count, dtype: int64

In [17]:
y_train[classes].value_counts()

diabetic_retinopathy  amd  hypertensive_retinopathy  normal_eye  glaucoma  cataract
0                     0    0                         1           0         0           2490
1                     0    0                         0           0         0           1314
0                     0    0                         0           0         1            244
                                                                 1         0            212
                      1    0                         0           0         0            185
                      0    1                         0           0         0             93
Name: count, dtype: int64

datasets and dataloaders

In [29]:
age_scaler = StandardScaler()
with open(f'age_scaler.pkl', 'wb') as file:
    pickle.dump(age_scaler, file)

X_train['patient_age'] = age_scaler.fit_transform(X_train[['patient_age']])
X_val['patient_age'] = age_scaler.transform(X_val[['patient_age']])
X_test['patient_age'] = age_scaler.transform(X_test[['patient_age']])

train_dataset = CustomDataset(X_train, y_train, DIRECTORY_PATH, include_preprocess_function=True)
validation_dataset = CustomDataset(X_val, y_val, DIRECTORY_PATH, mode='test', include_preprocess_function=True)
test_dataset = CustomDataset(X_test, y_test, DIRECTORY_PATH, mode='test', include_preprocess_function=True)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)


    # Compute class weights using class indices
y_train_indices = np.argmax(y_train, axis=1)
class_weights = compute_class_weight('balanced', classes=np.unique(y_train_indices), y=y_train_indices)
class_weights = torch.tensor(class_weights, dtype=torch.float32)

In [40]:
epochs_vgg16 = 200
learning_rate_vgg = 0.0001
model_vgg = NetworkVGG().to(device)
loss_function = nn.CrossEntropyLoss(weight=class_weights.to(device))
optimizer = optim.Adam(model_vgg.parameters(), lr=learning_rate_vgg)
train_losses = []
val_losses = []
f1_scores_train = []
f1_scores_val = []
reports = []

for t in range(epochs_vgg16):
    print(f"Epoch {t+1}\n-------------------------------")
    epoch_train_loss, f1_score_train = train(train_dataloader, model_vgg, loss_function, optimizer, device)
    epoch_val_loss, f1_score_val, report = validation(validation_dataloader, model_vgg, loss_function, device)
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    f1_scores_val.append(f1_score_val)
    f1_scores_train.append(f1_score_train)
    if (t + 1) % 50 == 0:
        save_model(train_dataloader, validation_dataloader, test_dataloader, optimizer, model_vgg, train_losses, val_losses, epoch=t+1)
        with open(f"f1_train_scores_epoch{t+1}.pkl", 'wb') as file:
            pickle.dump(f1_scores_train, file)
        with open(f"f1_tval_scores_epoch{t+1}.pkl", 'wb') as file:
            pickle.dump(f1_scores_val, file)
print("Done!")

Epoch 1
-------------------------------
loss: 1.572621, [   64/ 4538]
loss: 1.273519, [  192/ 4538]
loss: 1.943160, [  320/ 4538]
loss: 2.025402, [  448/ 4538]
loss: 1.661308, [  576/ 4538]
loss: 1.753855, [  704/ 4538]
loss: 1.459298, [  832/ 4538]
loss: 1.844298, [  960/ 4538]
loss: 0.890791, [ 1088/ 4538]
loss: 1.457503, [ 1216/ 4538]
loss: 1.744282, [ 1344/ 4538]
loss: 1.136464, [ 1472/ 4538]
loss: 1.435172, [ 1600/ 4538]
loss: 1.338549, [ 1728/ 4538]
loss: 1.896249, [ 1856/ 4538]
loss: 1.216908, [ 1984/ 4538]
loss: 1.376751, [ 2112/ 4538]
loss: 1.852203, [ 2240/ 4538]
loss: 1.856295, [ 2368/ 4538]
loss: 1.754092, [ 2496/ 4538]
loss: 1.669850, [ 2624/ 4538]
loss: 2.236166, [ 2752/ 4538]
loss: 1.087487, [ 2880/ 4538]
loss: 1.304839, [ 3008/ 4538]
loss: 1.669696, [ 3136/ 4538]
loss: 1.424150, [ 3264/ 4538]
loss: 1.592016, [ 3392/ 4538]
loss: 1.566654, [ 3520/ 4538]
loss: 1.422950, [ 3648/ 4538]
loss: 1.571306, [ 3776/ 4538]
loss: 1.858013, [ 3904/ 4538]
loss: 1.455936, [ 4032/ 4538]


KeyboardInterrupt: 

In [115]:
with open(f'f1_train_scores.pkl', 'wb') as file:
    pickle.dump(f1_scores_train, file)
with open(f'f1_tval_scores.pkl', 'wb') as file:
    pickle.dump(f1_scores_val, file)
save_model(train_dataloader, validation_dataloader, test_dataloader, optimizer, model_vgg, train_losses, val_losses, epoch=epochs_vgg16)

In [113]:
print(report)

              precision    recall  f1-score   support

           0       0.82      0.56      0.66       486
           1       0.65      0.91      0.76        97
           2       0.35      0.82      0.49        49
           3       0.75      0.72      0.74       609
           4       0.54      0.92      0.68        66
           5       0.64      0.98      0.77        48

    accuracy                           0.70      1355
   macro avg       0.62      0.82      0.68      1355
weighted avg       0.74      0.70      0.70      1355



In [37]:
torch.cuda.device_count()


2

In [38]:

torch.cuda.current_device()


0

In [39]:
torch.cuda.get_device_name(0)

'Tesla T4'