# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, classification_report
import pandas as pd
from PIL import Image
import timm
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import random

# ----------------------------
# Config
# ----------------------------
# Input resolution and optimisation hyper-parameters.
IMG_SIZE = 224   # matches the "..._224" Swin variant created below
BATCH_SIZE = 32
EPOCHS = 20
LR = 1e-4
SEED = 42

# Dataset locations (Kaggle input mounts) and compute device.
DATA_CSV = "/kaggle/input/messidor2preprocess/messidor_data.csv"
IMG_DIR = "/kaggle/input/messidor2preprocess/messidor-2/messidor-2/preprocess"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed torch, NumPy and the stdlib RNG (in that order),
# then every CUDA device when running on GPU.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

# ----------------------------
# Dataset
# ----------------------------
class MessidorDataset(Dataset):
    """Map-style dataset that pairs image files with integer labels.

    Each row of ``df`` must provide:

    * ``id_code``   -- file name of the image, joined onto ``img_dir``
      (presumably already carrying its extension, e.g. ``"xyz.jpg"``).
    * ``diagnosis`` -- label value convertible with ``int()``.

    Args:
        df: DataFrame with the columns described above. Its index is reset
            (on a copy) so positional access lines up with ``__len__``.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows in the dataframe)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB and passed through ``transform`` when one
        was supplied; the label is returned as a plain ``int``.
        """
        row = self.df.iloc[idx]
        img_id = row['id_code']
        label = int(row['diagnosis'])
        img_path = os.path.join(self.img_dir, img_id)
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle as soon as convert() has produced its own copy,
        # so long-running DataLoader workers do not accumulate descriptors.
        with Image.open(img_path) as src:
            image = src.convert("RGB")
        # Explicit None check: any supplied callable is applied, regardless
        # of its truthiness.
        if self.transform is not None:
            image = self.transform(image)
        return image, label


if __name__ == "__main__":
    # End-to-end script: load the label CSV, split train/val, fine-tune a
    # pretrained Swin Transformer, keep the checkpoint with the best
    # validation accuracy, then report metrics and a confusion matrix for
    # that checkpoint on the validation split.

    # Number of output classes (labels 0..4). Kept in one place so the model
    # head, label validation and confusion-matrix axes always agree.
    num_classes = 5
    class_ids = list(range(num_classes))

    # ----------------------------
    # Data Prep
    # ----------------------------
    df = pd.read_csv(DATA_CSV)

    # Ensure labels are ints 0–4
    if df["diagnosis"].dtype not in ["int64", "int32"]:
        label_map = {cls: i for i, cls in enumerate(sorted(df["diagnosis"].unique()))}
        df["diagnosis"] = df["diagnosis"].map(label_map)

    # Fail early with a readable message instead of an opaque error from the
    # loss function when a label falls outside the classifier's range.
    if not df["diagnosis"].between(0, num_classes - 1).all():
        raise ValueError(
            f"'diagnosis' labels must lie in [0, {num_classes - 1}]; "
            f"found {sorted(df['diagnosis'].unique().tolist(), key=str)}"
        )

    print("Unique labels:", np.unique(df["diagnosis"]))
    print("Label distribution:\n", df["diagnosis"].value_counts())

    # Stratified 80/20 split so each class keeps its proportion in both sets.
    train_df, val_df = train_test_split(
        df, test_size=0.2, stratify=df["diagnosis"], random_state=SEED
    )

    # Training uses random augmentation; validation only resizes.
    # Both normalise with the same per-channel mean/std.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    train_dataset = MessidorDataset(train_df, IMG_DIR, transform=train_transform)
    val_dataset = MessidorDataset(val_df, IMG_DIR, transform=val_transform)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    use_cuda = DEVICE == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=4, pin_memory=use_cuda)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=4, pin_memory=use_cuda)

    # ----------------------------
    # Model (Swin Transformer)
    # ----------------------------
    model_name = "swin_base_patch4_window7_224"
    print(f"Creating model: {model_name}")
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs via DataParallel.")
        model = nn.DataParallel(model)

    model = model.to(DEVICE)

    # ----------------------------
    # Loss, Optimizer, Scheduler, AMP
    # ----------------------------
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    # Gradient scaling (and autocast below) are active only on CUDA.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    # ----------------------------
    # Training loop
    # ----------------------------
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; the evaluation section below loads that file unconditionally.
    best_acc = -1.0
    best_ckpt = "/kaggle/working/best_swin.pth"
    print(f"Starting training for {EPOCHS} epochs. Steps/epoch={len(train_loader)}, LR={LR}")

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        train_preds, train_targets = [], []
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]", leave=False)

        for imgs, labels in train_bar:
            imgs = imgs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)

            optimizer.zero_grad()
            with torch.amp.autocast(device_type="cuda", enabled=use_cuda):
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Weight the batch-mean loss by batch size so the epoch average
            # stays correct when the last batch is smaller.
            train_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim=1).detach().cpu().numpy()
            train_preds.extend(preds)
            train_targets.extend(labels.cpu().numpy())
            train_bar.set_postfix(loss=loss.item())

        train_loss /= len(train_loader.dataset)
        train_acc = accuracy_score(train_targets, train_preds)

        # ----------------------------
        # Validation
        # ----------------------------
        model.eval()
        val_loss = 0.0
        val_preds, val_targets = [], []
        val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]", leave=False)
        with torch.no_grad():
            for imgs, labels in val_bar:
                imgs = imgs.to(DEVICE, non_blocking=True)
                labels = labels.to(DEVICE, non_blocking=True)
                with torch.amp.autocast(device_type="cuda", enabled=use_cuda):
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)

                val_loss += loss.item() * imgs.size(0)
                preds = outputs.argmax(dim=1).detach().cpu().numpy()
                val_preds.extend(preds)
                val_targets.extend(labels.cpu().numpy())
                val_bar.set_postfix(loss=loss.item())

        val_loss /= len(val_loader.dataset)
        val_acc = accuracy_score(val_targets, val_preds)
        # zero_division=0 matches the final-evaluation metrics and avoids
        # warnings in epochs where some class is never predicted.
        val_f1 = f1_score(val_targets, val_preds, average="weighted", zero_division=0)

        print(f"Epoch {epoch+1}/{EPOCHS} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")

        scheduler.step()

        if val_acc > best_acc:
            best_acc = val_acc
            # Save the unwrapped weights so the checkpoint loads with or
            # without DataParallel.
            state_dict_to_save = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
            torch.save(state_dict_to_save, best_ckpt)
            print(f"✅ New best model saved. Best Val Acc: {best_acc:.4f}")

    print(f"\n✅ Training complete. Best Val Acc: {best_acc:.4f}")

    # ----------------------------
    # Evaluation (load best)
    # ----------------------------
    state_dict = torch.load(best_ckpt, map_location=DEVICE)
    if hasattr(model, "module"):
        model.module.load_state_dict(state_dict)
    else:
        model.load_state_dict(state_dict)
    model.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            with torch.amp.autocast(device_type="cuda", enabled=use_cuda):
                outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average="weighted")
    precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
    recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
    # Fix the label set so the matrix is always num_classes x num_classes and
    # lines up with the tick labels, even if a class never shows up in the
    # validation targets or predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    print(f"\n📊 Accuracy: {acc:.4f}")
    print(f"📊 F1 Score: {f1:.4f}")
    print(f"📊 Precision: {precision:.4f}")
    print(f"📊 Recall: {recall:.4f}")
    print("\nClassification Report:\n", classification_report(all_labels, all_preds, zero_division=0))

    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_ids, yticklabels=class_ids)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Swin Transformer Confusion Matrix")
    plt.tight_layout()
    plt.show()
