In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, classification_report
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# ----------------------------
# Config
# ----------------------------
# Locations of the APTOS 2019 label file and image folder.
DATA_CSV = "/kaggle/input/aptos2019-blindness-detection/train.csv"
IMG_DIR = "/kaggle/input/aptos2019-blindness-detection/train_images"

# Training hyperparameters.
IMG_SIZE = 380    # side length (pixels) that images are cropped/resized to
BATCH_SIZE = 32   # samples per DataLoader batch
EPOCHS = 25       # number of passes over the training split
LR = 1e-4         # initial learning rate handed to the optimizer
SEED = 42         # seed shared by the torch RNG and the train/val split

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Seed torch's random number generator so runs are repeatable.
torch.manual_seed(SEED)

# ----------------------------
# Dataset (APTOS)
# ----------------------------
class APTOSDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Each row of ``df`` must provide an ``id_code`` column (the image file
    name without its ``.png`` extension) and a ``diagnosis`` column (cast to
    ``int`` and returned as the label).

    Args:
        df: DataFrame with one row per sample.
        img_dir: Directory that contains the ``<id_code>.png`` files.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df, self.img_dir, self.transform = df, img_dir, transform

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            A ``(image, label)`` tuple; the image has been passed through
            ``transform`` when one was supplied.
        """
        record = self.df.iloc[idx]
        # Image files are stored as "<id_code>.png".
        file_name = record['id_code'] + ".png"
        target = int(record['diagnosis'])
        picture = Image.open(os.path.join(self.img_dir, file_name)).convert("RGB")
        if not self.transform:
            return picture, target
        return self.transform(picture), target

# ----------------------------
# Data Prep
# ----------------------------
# Read the label table and hold out 20% for validation, keeping the
# per-class proportions of "diagnosis" the same in both splits.
df = pd.read_csv(DATA_CSV)
train_df, val_df = train_test_split(
    df, test_size=0.2, stratify=df["diagnosis"], random_state=SEED
)

# Per-channel normalisation statistics used by both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop / flip / colour jitter / rotation
# augmentation, then tensor conversion and normalisation.
_train_steps = [
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize only, no augmentation.
_val_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_transform = transforms.Compose(_val_steps)

train_dataset = APTOSDataset(train_df, IMG_DIR, transform=train_transform)
val_dataset = APTOSDataset(val_df, IMG_DIR, transform=val_transform)

# Settings common to both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# ----------------------------
# Model
# ----------------------------
# ResNet-101 initialised from torchvision's pretrained checkpoint.
# `pretrained=True` is deprecated in torchvision; the explicit
# IMAGENET1K_V1 weights enum selects the same checkpoint file
# (resnet101-63fe2227.pth) that `pretrained=True` downloaded.
model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
# Replace the classification head with a 5-way output layer.
model.fc = nn.Linear(model.fc.in_features, 5)

# Replicate across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)
model = model.to(DEVICE)

# ----------------------------
# Loss and Optimizer
# ----------------------------
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    For each sample the loss is ``alpha[target] * (1 - p_t) ** gamma * CE``,
    where ``CE`` is the *unweighted* cross-entropy of the sample and
    ``p_t = exp(-CE)`` is the probability assigned to the true class.

    The class weight is applied after the modulating factor is computed.
    Folding it into the cross-entropy first would make ``exp(-CE)`` differ
    from the true-class probability and distort the ``(1 - p_t) ** gamma``
    term.

    Args:
        gamma: Focusing exponent; ``0`` reduces the loss to cross-entropy.
        alpha: Optional per-class weights (tensor or sequence of length
            ``num_classes``). ``None`` means no class weighting.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}"
            )
        self.gamma = gamma
        self.reduction = reduction
        if alpha is not None:
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        # Registered as a buffer so it follows the module across devices;
        # still readable as ``self.alpha`` (``None`` when unweighted).
        self.register_buffer('alpha', alpha)
        # Unweighted, unreduced cross-entropy: exactly -log(p_t) per sample.
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, input, target):
        """Return the focal loss of logits ``input`` against class indices ``target``."""
        ce = self.ce(input, target)
        p_t = torch.exp(-ce)
        loss = (1 - p_t) ** self.gamma * ce
        if self.alpha is not None:
            # Match the loss tensor's device/dtype in case the criterion
            # itself was never moved to the device the logits live on.
            class_weight = self.alpha.to(device=loss.device, dtype=loss.dtype)
            loss = class_weight[target] * loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# Focal loss with its default settings (gamma=2.0, no class weights, mean).
criterion = FocalLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# Cosine annealing that restarts every 5 scheduler steps with a constant
# period (T_mult=1).
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1)
# Gradient scaler for mixed-precision training. `torch.cuda.amp.GradScaler()`
# is deprecated in favour of `torch.amp.GradScaler("cuda", ...)`. Scaling is
# switched off explicitly on CPU, where it would otherwise be disabled with
# a warning.
scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE == "cuda"))

# ----------------------------
# Training Loop
# ----------------------------
# Each epoch: one optimisation pass over train_loader, one no-grad pass over
# val_loader, one scheduler step, and a checkpoint whenever validation
# accuracy improves on the best value seen so far.
best_acc = 0.0
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss, train_preds, train_targets = 0, [], []
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]", leave=False)

    for imgs, labels in train_bar:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        # `torch.cuda.amp.autocast()` is deprecated; `torch.amp.autocast`
        # is its replacement. Autocast is only enabled when running on CUDA.
        with torch.amp.autocast("cuda", enabled=(DEVICE == "cuda")):
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and update its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once; every .item() call forces a device sync.
        batch_loss = loss.item()
        train_loss += batch_loss
        preds = outputs.argmax(1).detach().cpu().numpy()
        train_preds.extend(preds)
        train_targets.extend(labels.cpu().numpy())
        train_bar.set_postfix(loss=batch_loss)

    train_acc = accuracy_score(train_targets, train_preds)

    # ---- validation pass ----
    model.eval()
    val_loss, val_preds, val_targets = 0, [], []
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]", leave=False)
    with torch.no_grad():
        for imgs, labels in val_bar:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            with torch.amp.autocast("cuda", enabled=(DEVICE == "cuda")):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            batch_loss = loss.item()
            val_loss += batch_loss
            preds = outputs.argmax(1).detach().cpu().numpy()
            val_preds.extend(preds)
            val_targets.extend(labels.cpu().numpy())
            val_bar.set_postfix(loss=batch_loss)

    val_acc = accuracy_score(val_targets, val_preds)
    # Reported losses are the mean of the per-batch losses.
    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {train_loss/len(train_loader):.4f}, "
          f"Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss/len(val_loader):.4f}, "
          f"Val Acc: {val_acc:.4f}")

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # Keep only the weights with the best validation accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "/kaggle/working/best_resnet101.pth")

# ----------------------------
# Evaluation
# ----------------------------
# Reload the best checkpoint written by the training loop and score it on
# the validation loader.
model.load_state_dict(torch.load("/kaggle/working/best_resnet101.pth", map_location=DEVICE))
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        # `torch.cuda.amp.autocast()` is deprecated; `torch.amp.autocast`
        # is its replacement. Autocast is only enabled when running on CUDA.
        with torch.amp.autocast("cuda", enabled=(DEVICE == "cuda")):
            outputs = model(imgs)
        # Predicted class = index of the largest logit.
        preds = outputs.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Class ids shown on the confusion-matrix axes.
class_ids = [0, 1, 2, 3, 4]

acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average="weighted")
precision = precision_score(all_labels, all_preds, average="weighted")
recall = recall_score(all_labels, all_preds, average="weighted")
# Pin the label set so the matrix is always 5x5 and lines up with the tick
# labels below, even if a class is absent from both labels and predictions.
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

print(f"\nðŸ“Š Accuracy: {acc:.4f}")
print(f"ðŸ“Š F1 Score: {f1:.4f}")
print(f"ðŸ“Š Precision: {precision:.4f}")
print(f"ðŸ“Š Recall: {recall:.4f}")
print("\nClassification Report:\n", classification_report(all_labels, all_preds))

plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("ResNet101 Confusion Matrix")
plt.tight_layout()
plt.show()


Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:00<00:00, 205MB/s] 
  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
                                                                             

Epoch 1/25 | Train Loss: 0.4020, Train Acc: 0.7289 | Val Loss: 0.2959, Val Acc: 0.7817


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
                                                                             

Epoch 2/25 | Train Loss: 0.2660, Train Acc: 0.7986 | Val Loss: 0.2835, Val Acc: 0.7981


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
                                                                             

Epoch 3/25 | Train Loss: 0.2274, Train Acc: 0.8238 | Val Loss: 0.2481, Val Acc: 0.7954


  with torch.cuda.amp.autocast():
                                                                                

KeyboardInterrupt: 