In [None]:
!pip install torchsummary

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import utils
from torchvision.io import read_image
from torchvision.io import ImageReadMode
import torchvision.models as models
import torchvision.transforms as T
from torchsummary import summary
from tqdm.auto import tqdm

In [None]:
# Seed torch's global RNG (all devices) so the new head's initialisation,
# DataLoader shuffling and the random augmentations repeat across runs.
# cuDNN kernels may still introduce small nondeterminism on GPU.
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

# Data Processing

In [None]:
# Kaggle RFMiD layout: every split directory contains a same-named subdirectory
# that holds the label CSV and the image folder.
root_dir = '/kaggle/input/retinal-disease-classification/'
train_dir, val_dir, test_dir = (
    os.path.join(root_dir, split_name, split_name)
    for split_name in ('Training_Set', 'Evaluation_Set', 'Test_Set')
)

In [None]:
class RetinaDataset(Dataset):
    """Multi-label retinal image dataset driven by a label CSV.

    Each CSV row describes one image. Column 0 holds the image id (the file is
    ``<img_dir>/<id><img_ext>``). Columns ``label_start_col:`` hold the labels,
    which are returned as a float32 vector.

    Args:
        annotations_file: path or file-like object accepted by ``pd.read_csv``.
        img_dir: directory containing the image files.
        transform: optional callable applied to the decoded image tensor.
        label_start_col: index of the first label column. The default of 2 skips
            the id column and column 1 (presumably ``Disease_Risk`` in RFMiD —
            confirm against the CSV header).
        img_ext: extension appended to the id to form the file name.
    """

    def __init__(self, annotations_file, img_dir, transform=None,
                 label_start_col=2, img_ext='.png'):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.img_ext = img_ext
        # Extract ids and label vectors once, instead of a pandas row lookup
        # (and numpy->tensor conversion) for every sample.
        self._ids = self.img_labels.iloc[:, 0].astype(str).tolist()
        self._labels = torch.as_tensor(
            self.img_labels.iloc[:, label_start_col:].to_numpy(dtype='float32'))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self._ids[idx] + self.img_ext)
        # Force 3-channel RGB decoding. Grayscale or alpha-channel PNGs would
        # otherwise yield 1 or 4 channels, breaking batching and the RGB backbone.
        image = read_image(img_path, mode=ImageReadMode.RGB)
        # Clone so callers cannot mutate the cached label table through the view.
        label = self._labels[idx].clone()
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# ImageNet channel mean/std. The ResNeXt backbone is loaded with ImageNet
# weights, so inputs are standardised with the statistics it was trained on.
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
IMG_SIZE = (256, 256)

# Training-time pipeline: resize + random augmentation on the uint8 image,
# then convert to float in [0, 1] and normalise.
transform = T.Compose([
    T.Resize(IMG_SIZE),
    T.RandomAdjustSharpness(2, 0.8),
    T.RandomHorizontalFlip(0.5),
    T.RandomVerticalFlip(0.5),
    T.ConvertImageDtype(torch.float32),
    T.Normalize(*imagenet_stats),
])

# Deterministic evaluation pipeline: same preprocessing, no augmentation.
test_transform = T.Compose([
    T.Resize(IMG_SIZE),
    T.ConvertImageDtype(torch.float32),
    T.Normalize(*imagenet_stats),
])

In [None]:
# Only the training split gets random augmentation. Validation and test use the
# deterministic `test_transform` so their metrics are not perturbed by random
# flips and sharpening.
train_data = RetinaDataset(os.path.join(train_dir, 'RFMiD_Training_Labels.csv'),
                           os.path.join(train_dir, 'Training'), transform=transform)

val_data = RetinaDataset(os.path.join(val_dir, 'RFMiD_Validation_Labels.csv'),
                         os.path.join(val_dir, 'Validation'), transform=test_transform)

test_data = RetinaDataset(os.path.join(test_dir, 'RFMiD_Testing_Labels.csv'),
                          os.path.join(test_dir, 'Test'), transform=test_transform)

In [None]:
# for i, (image, label) in enumerate(train_data):
#     print(i, image.size(), len(label))

In [None]:
BATCH_SIZE = 64

# Shuffle only for training. A fixed order for validation/test keeps the
# batch-averaged metrics reproducible between runs.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sanity-check one training batch: tensor shapes/dtypes and the first sample.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.shape}, Type: {train_features.dtype}")
print(f"Labels batch shape: {train_labels.shape},  Type: {train_labels.dtype}")
img = train_features[0]
label = train_labels[0]

fig, ax = plt.subplots()
# imshow wants HWC floats in [0, 1]. Clamp explicitly so any out-of-range values
# (e.g. if the transform normalises) are clipped without a matplotlib warning.
# Colours will look shifted in that case.
ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
ax.set_title(f"Positive label indices: {label.nonzero().flatten().tolist()}")
ax.axis('off')
plt.show()
print(f"Label: {label}")

# Model

In [None]:
# One logit per label; must equal the length of the dataset's label vector.
NUM_CLASSES = 45

model = models.resnext50_32x4d(pretrained=True)
# Swap the ImageNet head for a multi-label head, reusing the backbone's feature
# width rather than hard-coding it.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=NUM_CLASSES, bias=True)
model = model.to(device)

# torchsummary performs a forward pass on random inputs. Run it in eval mode so
# that dummy batch does not update the pretrained BatchNorm running statistics.
model.eval()
summary(model, (3, 256, 256), device=device)
model.train()

# Training

In [None]:
# Optimisation settings: Adam over every model parameter, and a multi-label
# binary cross-entropy loss that takes raw logits.
lr = 1e-3
epochs = 10
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCEWithLogitsLoss()

In [None]:
# Probability cut-off for turning sigmoid outputs into 0/1 predictions.
THRESHOLD = 0.5


def multilabel_accuracy(logits, targets, threshold=THRESHOLD):
    """Fraction of (sample, class) entries whose thresholded prediction equals the target.

    Args:
        logits: raw model outputs (the loss is BCEWithLogitsLoss), same shape as `targets`.
        targets: 0/1 float labels.
        threshold: probability cut-off applied after the sigmoid.

    Returns:
        A Python float in [0, 1].
    """
    preds = (torch.sigmoid(logits) > threshold).float()
    return (preds == targets).float().mean().item()


for epoch in range(1, epochs + 1):
    for phase, loader in (('Training', train_dataloader), ('Validation', valid_dataloader)):
        is_train = phase == 'Training'
        # Train mode for updates; eval mode during validation so BatchNorm uses
        # its running statistics and does not update them with validation data.
        model.train(is_train)
        losses, accs = [], []
        with torch.set_grad_enabled(is_train):
            with tqdm(loader, unit="batch", desc=f"Epoch {epoch} {phase}") as tepoch:
                for images, labels in tepoch:
                    images, labels = images.to(device), labels.to(device)

                    outputs = model(images)
                    loss = criterion(outputs, labels)

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    acc = multilabel_accuracy(outputs.detach(), labels)
                    losses.append(loss.item())
                    accs.append(acc)
                    tepoch.set_postfix(loss=loss.item(), accuracy=acc)

        print(f'Epoch {epoch} {phase} - Loss: {torch.tensor(losses).mean().item()}, Acc: {torch.tensor(accs).mean().item()}')


In [None]:
# Checkpoint path. NOTE(review): despite the name, this stores the weights as
# they are after the last training epoch; no validation-based selection happens.
BEST_MODEL = 'bestmodel.pth'
torch.save(obj=model.state_dict(), f=BEST_MODEL)

In [None]:
# map_location lets a checkpoint written from a GPU session be restored on a
# CPU-only kernel instead of failing on CUDA deserialisation.
model.load_state_dict(torch.load(BEST_MODEL, map_location=device))

In [None]:
def _count_matches(y_pred, y_true, th, pred_value, true_value):
    """Per-column count of entries whose thresholded prediction equals `pred_value`
    and whose target equals `true_value`.

    `y_pred` is binarised with `y_pred > th`. Counts are summed over dim 0,
    so a (batch, classes) input yields one count per class.
    """
    assert y_pred.shape == y_true.shape
    hits = ((y_pred > th).float() == pred_value) & (y_true == true_value)
    return hits.sum(dim=0)


def true_positive(y_pred, y_true, th=0.5):
    """Per-class true positives: predicted 1 (score > th) and labelled 1."""
    return _count_matches(y_pred, y_true, th, pred_value=1, true_value=1)


def false_positive(y_pred, y_true, th=0.5):
    """Per-class false positives: predicted 1 (score > th) but labelled 0."""
    return _count_matches(y_pred, y_true, th, pred_value=1, true_value=0)


def false_negative(y_pred, y_true, th=0.5):
    """Per-class false negatives: predicted 0 (score <= th) but labelled 1."""
    return _count_matches(y_pred, y_true, th, pred_value=0, true_value=1)

In [None]:
# One accumulator slot per model output (class).
OUT_FINAL = model.fc.out_features
true_pos, false_pos, false_neg = torch.zeros(OUT_FINAL), torch.zeros(OUT_FINAL), torch.zeros(OUT_FINAL)
losses, accs = [], []

model.eval()
with torch.no_grad():
    progress = tqdm(test_dataloader, desc='Evaluation')
    for images, labels in progress:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)

        # The model emits logits (trained with BCEWithLogitsLoss). Map them to
        # probabilities before applying the 0.5 decision threshold.
        probs = torch.sigmoid(logits)
        acc = ((probs > 0.5).float() == labels).float().mean().item()

        losses.append(loss.item())
        accs.append(acc)
        progress.set_postfix(loss=loss.item(), accuracy=acc)

        # Per-class confusion counts, accumulated on CPU. The helpers threshold
        # the probabilities at their default th=0.5.
        probs, labels = probs.cpu(), labels.cpu()
        true_pos += true_positive(probs, labels)
        false_pos += false_positive(probs, labels)
        false_neg += false_negative(probs, labels)

print(f'Test- Loss: {torch.tensor(losses).mean().item()}, Acc: {torch.tensor(accs).mean().item()}')


In [None]:
# Per-class precision TP / (TP + FP). The 1e-10 keeps classes with no positive
# predictions at 0 instead of NaN.
precision = torch.div(true_pos, true_pos.add(false_pos).add(1e-10))
precision

In [None]:
# Per-class recall TP / (TP + FN). The 1e-10 keeps classes with no positive
# labels at 0 instead of NaN.
recall = torch.div(true_pos, true_pos.add(false_neg).add(1e-10))
recall

In [None]:
# Per-class F1: harmonic mean of precision and recall (1e-10 avoids 0/0).
f1 = torch.div(2 * (precision * recall), precision.add(recall).add(1e-10))
f1