In [23]:
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split


In [46]:
# Paths
IMG_DIR = Path("data/processed/classification/images")
LABEL_DIR = Path("data/processed/classification/labels")
MODEL_PATH = Path("models/classification_model.pth")
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 30
LR = 1e-4

In [25]:
# 1. Custom Dataset
class FundusDataset(Dataset):
    def __init__(self, img_files, transform=None):
        self.img_files = img_files
        self.transform = transform

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        label_path = LABEL_DIR / (img_path.stem + ".txt")

        image = Image.open(img_path).convert("RGB")
        label = int(open(label_path).read().strip())

        if self.transform:
            image = self.transform(image)

        return image, label

In [44]:
# 2. Data Loaders (updated with augmentation)
def get_dataloaders():
    all_images = list(IMG_DIR.glob("*.jpg"))
    valid_images = []
    for img_path in all_images:
        label_path = LABEL_DIR / f"{img_path.stem}.txt"
        if label_path.exists():
            valid_images.append(img_path)
        else:
            print(f"Missing label for: {img_path.name}")

    if len(valid_images) < 2:
        raise ValueError("Not enough valid samples for splitting!")

    # Split into train (60%), val (20%), test (20%)
    train_files, test_val_files = train_test_split(valid_images, test_size=0.4, random_state=42)
    val_files, test_files = train_test_split(test_val_files, test_size=0.5, random_state=42)

    # Augmentation for training data
    train_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    # Simple transform for validation/test
    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    train_ds = FundusDataset(train_files, train_transform)
    val_ds = FundusDataset(val_files, val_transform)
    test_ds = FundusDataset(test_files, val_transform)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

    return train_loader, val_loader, test_loader



In [35]:
# MixUp implementation


def mixup_data(x, y, alpha=0.4):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

In [43]:
# 3. Training Loop (with regularization)
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, _ = get_dataloaders()

    # Model with dropout
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Sequential(
        nn.Dropout(0.5),  # 50% dropout
        nn.Linear(model.fc.in_features, 5)
    )
    model.to(device)

    # Loss function with class weighting (adjust weights according to your data)
    class_weights = torch.tensor([1.0, 2.0, 2.0, 3.0, 3.0])  # Example weights
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)

    # Early stopping
    best_val_loss = float('inf')
    patience = 3
    trigger_times = 0

    for epoch in range(EPOCHS):
        model.train()
        total_loss, correct = 0.0, 0

        # Training phase
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()

        train_loss = total_loss / len(train_loader.dataset)
        train_acc = correct / len(train_loader.dataset)

        # Validation phase
        model.eval()
        val_loss, val_correct = 0.0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                val_correct += (outputs.argmax(1) == labels).sum().item()

        val_loss = val_loss / len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)
        
        # Update scheduler
        scheduler.step(val_loss)

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            trigger_times = 0
            # Save best model
            MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), MODEL_PATH)
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    print("Training complete. Best model saved to:", MODEL_PATH)

In [49]:
train()

Epoch 1/30, Train Loss: 1.5944, Train Acc: 0.3320, Val Loss: 1.2115, Val Acc: 0.6265
Epoch 2/30, Train Loss: 1.2281, Train Acc: 0.4899, Val Loss: 1.0794, Val Acc: 0.5542
Epoch 3/30, Train Loss: 1.0619, Train Acc: 0.5789, Val Loss: 1.1183, Val Acc: 0.5422
Epoch 4/30, Train Loss: 0.9319, Train Acc: 0.6275, Val Loss: 0.9880, Val Acc: 0.6506
Epoch 5/30, Train Loss: 0.7433, Train Acc: 0.7206, Val Loss: 1.1108, Val Acc: 0.6386
Epoch 6/30, Train Loss: 0.6591, Train Acc: 0.7611, Val Loss: 1.0784, Val Acc: 0.6024
Early stopping at epoch 7
Training complete. Best model saved to: models\classification_model.pth
