In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from PIL import Image
import timm
from tqdm import tqdm   # ✅ progress bar
import random

# ----------------------------
# Config
# ----------------------------
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 20
LR = 1e-4
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every global RNG the script may draw from.
# - Python's `random` picks the sample image for the final prediction cell.
# - NumPy covers library code that uses its global RNG.
# - torch drives weight init, DataLoader shuffling and torchvision augmentations.
# NOTE: GPU kernels (e.g. cuDNN) may still be nondeterministic, so reruns are
# only as repeatable as the backend allows.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ----------------------------
# Dataset
# ----------------------------
class APTOSDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a DataFrame.

    Args:
        df: DataFrame with an ``id_code`` column (image file stem) and a
            ``diagnosis`` column (class label). Rows are addressed by
            position via ``iloc``, so the index need not be contiguous.
        img_dir: Directory containing the ``<id_code>.png`` files.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load row *idx*: RGB image (transformed if configured) and int label."""
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row["id_code"] + ".png")
        # The context manager closes the file handle deterministically.
        # convert() returns an independent in-memory image, so it remains
        # valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # Normalize to a plain Python int regardless of the pandas/NumPy
        # scalar type that iloc returns for this column.
        label = int(row["diagnosis"])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# ----------------------------
# Data Preparation
# ----------------------------
# Dataset root and image directory, defined once instead of repeating literals.
DATA_DIR = "/kaggle/input/aptos2019-blindness-detection"
TRAIN_IMG_DIR = os.path.join(DATA_DIR, "train_images")

# Standard ImageNet channel mean/std. Both pipelines share them, so train and
# validation inputs are normalized identically.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))

# Stratified 80/20 split keeps the diagnosis class proportions (approximately)
# equal in both subsets.
train_df, val_df = train_test_split(
    df, test_size=0.2, stratify=df["diagnosis"], random_state=SEED
)

# Training pipeline: random crop/scale, horizontal flip, colour jitter and a
# small rotation for augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation pipeline: deterministic resize only, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Both splits come from the competition's train set, hence the same image dir.
train_dataset = APTOSDataset(train_df, TRAIN_IMG_DIR, transform=train_transform)
val_dataset = APTOSDataset(val_df, TRAIN_IMG_DIR, transform=val_transform)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ----------------------------
# Model: Swin Transformer
# ----------------------------
# Pretrained Swin-B (patch 4, window 7, 224px input) from timm with a 5-way
# classification head. Module.to() returns the same module, so the call can
# be chained.
model = timm.create_model(
    "swin_base_patch4_window7_224",
    pretrained=True,
    num_classes=5,
).to(DEVICE)

# Cross-entropy objective with 0.1 label smoothing on the targets.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over every model parameter, with the learning rate annealed along a
# cosine curve across EPOCHS scheduler steps (stepped once per epoch below).
optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS)

# ----------------------------
# Training Loop with tqdm
# ----------------------------
def _run_epoch(loader, desc, train):
    """Run one pass of ``model`` over *loader*.

    Args:
        loader: DataLoader yielding ``(imgs, labels)`` batches.
        desc: Label shown on the tqdm progress bar.
        train: If True, put the model in train mode and step ``optimizer``
            after every batch. Otherwise use eval mode with gradients disabled.

    Returns:
        ``(mean_loss, preds, targets)``. ``mean_loss`` is the per-sample
        average of ``criterion``. ``preds`` and ``targets`` are flat lists of
        predicted and true class ids.
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    total_loss, n_samples = 0.0, 0
    preds_all, targets_all = [], []
    bar = tqdm(loader, desc=desc, leave=False)
    # Gradients are tracked only in the training pass (same effect as the
    # original no_grad() block for validation).
    with torch.set_grad_enabled(train):
        for imgs, labels in bar:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            if train:
                optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            # Weight each batch's mean loss by its size so a short final batch
            # does not skew the epoch average.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
            preds_all.extend(outputs.argmax(1).detach().cpu().numpy())
            targets_all.extend(labels.cpu().numpy())

            bar.set_postfix(loss=loss.item())

    # max(..., 1) guards against an empty loader.
    return total_loss / max(n_samples, 1), preds_all, targets_all


best_acc = 0.0
# Defined up front so the report cell can never hit a NameError.
best_preds, best_targets = [], []

for epoch in range(EPOCHS):
    tag = f"Epoch {epoch+1}/{EPOCHS}"

    # ---- Train ----
    train_loss, train_preds, train_targets = _run_epoch(train_loader, f"{tag} [Train]", train=True)
    train_acc = accuracy_score(train_targets, train_preds)

    # ---- Validation ----
    val_loss, val_preds, val_targets = _run_epoch(val_loader, f"{tag} [Val]", train=False)
    val_acc = accuracy_score(val_targets, val_preds)

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, "
          f"Val Acc: {val_acc:.4f}")

    scheduler.step()

    # Save best model. Always checkpoint the first epoch, so best_swin.pth and
    # best_preds/best_targets exist even if val accuracy never rises above 0.
    if epoch == 0 or val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_swin.pth")
        best_preds, best_targets = val_preds.copy(), val_targets.copy()

print("Training complete. Best Val Acc:", best_acc)

# ----------------------------
# Evaluation Report
# ----------------------------
# The five diagnosis classes (0-4) the model head was built with.
CLASS_IDS = list(range(5))
CLASS_NAMES = [f"Class {i}" for i in CLASS_IDS]

print("\nClassification Report:")
# Passing labels= explicitly keeps the report aligned with CLASS_NAMES even if
# some class never occurs in best_targets/best_preds. Without it, sklearn
# raises on the target_names length mismatch. zero_division=0 reports 0.0 for
# classes that are never predicted, instead of warning.
print(classification_report(best_targets, best_preds, labels=CLASS_IDS,
                            target_names=CLASS_NAMES, zero_division=0))

# Confusion Matrix: forcing labels= guarantees a 5x5 matrix that matches the
# five tick labels below.
cm = confusion_matrix(best_targets, best_preds, labels=CLASS_IDS)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()

# ----------------------------
# Single Image Prediction
# ----------------------------
# Restore the checkpoint with the best validation accuracy and switch the
# network to inference mode.
model.load_state_dict(torch.load("best_swin.pth", map_location=DEVICE))
model.eval()

# Draw one validation row uniformly at random. randrange(n) produces the same
# draw as randint(0, n - 1).
idx = random.randrange(len(val_df))
sample = val_df.iloc[idx]
img_path = os.path.join(
    "/kaggle/input/aptos2019-blindness-detection/train_images",
    f"{sample['id_code']}.png",
)

# Apply the deterministic validation preprocessing, then add a batch axis of 1.
image = Image.open(img_path).convert("RGB")
input_tensor = val_transform(image)
input_tensor = input_tensor.unsqueeze(0)
input_tensor = input_tensor.to(DEVICE)

with torch.no_grad():
    output = model(input_tensor)
    pred_class = output.argmax(1).item()

# Show the untransformed image with its true and predicted labels.
plt.imshow(image)
plt.axis("off")
plt.title(f"True: {sample['diagnosis']} | Pred: {pred_class}")
plt.show()
