# Diabetic Retinopathy Classification

This notebook fine-tunes an ImageNet-pretrained ResNet-50 for binary 'Referral' / 'No Referral' diabetic retinopathy classification (Referral = DR grade ≥ 2) on the combined APTOS 2019 and MESSIDOR-2 datasets.

## 1. Environment Setup

In [None]:
import os

# --- Install Dependencies ---
print("\n⏳ Installing dependencies...")
!pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 -q
!pip install timm==0.9.16 pandas==2.2.2 scikit-learn -q
!pip install gdown -q
print("✅ Dependencies installed.")

# --- Set up device ---
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Download and Preprocess Datasets

In [None]:
import gdown
import os
import zipfile


def download_and_extract(url, archive_path, dataset_name):
    """Download a zip archive with gdown and unpack it into the working directory.

    If ``archive_path`` already exists the function only reports that and
    returns; otherwise it downloads the archive and extracts every member into
    the current working directory.

    Args:
        url: Google Drive URL passed to ``gdown.download``.
        archive_path: Local file name the archive is saved under.
        dataset_name: Human-readable name used in the progress messages.

    Raises:
        RuntimeError: If the download did not produce ``archive_path``.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    if os.path.exists(archive_path):
        print(f'{dataset_name} dataset already downloaded.')
        return

    print(f'Downloading {dataset_name} dataset...')
    downloaded = gdown.download(url, archive_path, quiet=False)
    # Fail here rather than in a later cell with a confusing "file not found".
    if downloaded is None or not os.path.exists(archive_path):
        raise RuntimeError(f'Failed to download the {dataset_name} dataset from {url}')

    try:
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall()
    except zipfile.BadZipFile:
        # Remove the corrupt file so the next run retries the download instead
        # of reporting the dataset as already present.
        os.remove(archive_path)
        raise
    print(f'{dataset_name} dataset downloaded and unzipped.')


# --- Download APTOS 2019 ---
aptos_url = 'https://drive.google.com/uc?id=162YPf4OhMVxj9TrQH0GnJv0n7z7gJWpj'
aptos_output = 'APTOS2019.zip'
download_and_extract(aptos_url, aptos_output, 'APTOS 2019')

# --- Download MESSIDOR-2 ---
messidor_url = 'https://drive.google.com/uc?id=1vOLBUK9xdzNV8eVkRjVdNrRwhPfaOmda'
messidor_output = 'MESSIDOR2.zip'
download_and_extract(messidor_url, messidor_output, 'MESSIDOR-2')

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

# --- Preprocess APTOS 2019 ---
aptos_df = pd.read_csv('train.csv')
aptos_df['image_path'] = aptos_df['id_code'].apply(lambda x: os.path.join('train_images', x + '.png'))

# --- Preprocess MESSIDOR-2 ---
messidor_df = pd.read_csv('messidor_data.csv')
# NOTE(review): assumes `image_id` carries no file extension and every image is
# a .jpg -- confirm against the CSV and the extracted folder.
messidor_df['image_path'] = messidor_df['image_id'].apply(lambda x: os.path.join('messidor-2', 'images', x + '.jpg'))

# --- Combine datasets ---
# Rename each source's grade column to the shared name 'grade' BEFORE
# concatenating. Renaming afterwards leaves two columns both called 'grade'
# (each half-filled with NaN), which makes `combined_df['grade']` a DataFrame
# and breaks the label computation below.
combined_df = pd.concat([
    aptos_df[['image_path', 'diagnosis']].rename(columns={'diagnosis': 'grade'}),
    messidor_df[['image_path', 'adjudicated_dr_grade']].rename(columns={'adjudicated_dr_grade': 'grade'}),
], ignore_index=True)

# A missing grade compares False against the threshold and would silently be
# labeled 'No Referral'; drop such rows instead of training on a made-up label.
n_ungraded = int(combined_df['grade'].isna().sum())
if n_ungraded:
    print(f'Dropping {n_ungraded} samples without a grade')
    combined_df = combined_df.dropna(subset=['grade']).reset_index(drop=True)

# --- Create binary labels ---
# Referral: grade >= 2
# No Referral: grade < 2
combined_df['label'] = (combined_df['grade'] >= 2).astype('int64')

# --- Split data ---
# Stratify on the binary label so both splits keep the same class balance.
train_df, val_df = train_test_split(combined_df, test_size=0.2, stratify=combined_df['label'], random_state=42)

print(f'Training samples: {len(train_df)}')
print(f'Validation samples: {len(val_df)}')
print(train_df['label'].value_counts())
print(val_df['label'].value_counts())

## 3. Implement Data Loading

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class DRDataset(Dataset):
    """Dataset that serves (image, label) pairs described by a DataFrame.

    Each row of ``df`` must provide an ``image_path`` column (file opened with
    PIL and converted to RGB) and a ``label`` column (returned unchanged).
    """

    def __init__(self, df, transform=None):
        """Store the sample table and the optional per-image transform."""
        self.transform = transform
        self.df = df

    def __len__(self):
        """Return the number of rows in the sample table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image of the row at position ``idx`` and return it with its label."""
        record = self.df.iloc[idx]
        picture = Image.open(record['image_path']).convert('RGB')

        if self.transform:
            picture = self.transform(picture)

        return picture, record['label']

# Per-channel mean / std applied by the Normalize step of both pipelines.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
_target_size = (256, 256)

# Training pipeline: resize, light augmentation (flip + rotation up to
# 10 degrees), then tensor conversion and normalization.
_train_pipeline = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std)
])

# Validation pipeline: same resize and normalization, no augmentation.
_val_pipeline = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std)
])

data_transforms = {'train': _train_pipeline, 'val': _val_pipeline}

train_dataset = DRDataset(train_df, transform=data_transforms['train'])
val_dataset = DRDataset(val_df, transform=data_transforms['val'])

# Only the training loader shuffles; batch size and worker count are shared.
_loader_options = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print(f'Train loader: {len(train_loader)} batches')
print(f'Validation loader: {len(val_loader)} batches')

## 4. Define the Model, Loss Function, and Optimizer

In [None]:
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# --- Define the model ---
# `pretrained=True` is deprecated in torchvision >= 0.13 (this notebook pins
# 0.18.1); the `weights=` enum is the supported way to request pretrained
# weights. IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Replace the ImageNet classification head with a 2-way head.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2) # Binary classification
model = model.to(device)

# --- Define loss function and optimizer ---
# All parameters (backbone and new head) are optimized with the same learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("Model, loss function, and optimizer are ready.")

## 5. Train and Evaluate the Model

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader`` and return the mean loss.

    The model is put in training mode; for every batch the gradients are
    cleared, the loss is back-propagated and the optimizer takes one step.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0.0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        weighted_loss_sum += batch_loss.item() * batch_images.size(0)

    return weighted_loss_sum / len(loader.dataset)

def evaluate_model(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    The model is put in eval mode. Predictions are the index of the largest
    output per sample; they are compared with the labels using the
    scikit-learn metric functions with their default arguments.

    Returns:
        Tuple ``(loss, accuracy, precision, recall, f1)`` where ``loss`` is the
        batch-size-weighted loss sum divided by ``len(loader.dataset)``.
    """
    model.eval()
    loss_sum = 0.0
    predicted, expected = [], []

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item() * batch_images.size(0)

            # torch.max over dim 1 returns (values, indices); keep the indices.
            batch_preds = torch.max(logits, 1)[1]
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(loader.dataset)
    scores = tuple(
        metric(expected, predicted)
        for metric in (accuracy_score, precision_score, recall_score, f1_score)
    )
    return (mean_loss, *scores)

num_epochs = 10
best_accuracy = 0.0

for epoch in range(num_epochs):
    # One pass over the training data followed by a full validation pass.
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_prec, val_rec, val_f1 = evaluate_model(model, val_loader, criterion, device)

    print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val Precision: {val_prec:.4f} | Val Recall: {val_rec:.4f} | Val F1: {val_f1:.4f}')

    # Checkpoint only when validation accuracy strictly improves.
    if not val_acc > best_accuracy:
        continue

    best_accuracy = val_acc
    torch.save(model.state_dict(), 'best_model.pth')
    print('Best model saved.')