In [7]:
zip_path = "/content/OCTDL.zip"     # <-- update if needed
extract_dir = "/content/OCTDL"      # destination folder


In [8]:
!unzip -q "$zip_path" -d "$extract_dir"


unzip:  cannot find or open /content/OCTDL.zip, /content/OCTDL.zip.zip or /content/OCTDL.zip.ZIP.


In [9]:
import os

root = "/kaggle/input/octdl-dataset"    # update if your extract dir is different
print("Folders inside OCTDL:", os.listdir(root))


Folders inside OCTDL: ['OCTDL']


In [10]:
import os
import shutil

# Define the paths
root_dir = "/kaggle/input"
nested_dir = os.path.join(root_dir, "OCTDL")

if os.path.exists(nested_dir):
    # Move contents from nested_dir to root_dir
    for item in os.listdir(nested_dir):
        s = os.path.join(nested_dir, item)
        d = os.path.join(root_dir, item)
        if os.path.isdir(s):
            shutil.move(s, d)
        else:
            shutil.move(s, d)

    # Remove the now empty nested_dir
    shutil.rmtree(nested_dir)
    print(f"Flattened directory structure: moved contents from '{nested_dir}' to '{root_dir}' and removed '{nested_dir}'")
else:
    print(f"No nested directory '{nested_dir}' found to flatten.")

print("Current folders inside OCTDL:", os.listdir(root_dir))

No nested directory '/kaggle/input/OCTDL' found to flatten.
Current folders inside OCTDL: ['octdl-dataset']


In [11]:
import os

root = "/kaggle/input/octdl-dataset/OCTDL"    # update if your extract dir is different
print("Folders inside OCTDL:", os.listdir(root))

Folders inside OCTDL: ['NO', 'AMD', 'VID', 'ERM', 'RVO', 'DME', 'RAO']


In [12]:
# === OCTDL: Optimized Dataset + Dataloader Setup (single cell) ===
# - Designed for Colab T4, persistent_workers=True for speed.

import os, json, time, random
from pathlib import Path
from collections import Counter
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# --------------------- CONFIG ---------------------
ROOT = "/kaggle/input/octdl-dataset/OCTDL"            # <- update if folder is in a different place
IMG_SIZE = 224
BATCH_SIZE = 32                    # target per-GPU batch for ViT/DeiT; lower for MaxViT if OOM
NUM_WORKERS = 0                    # Colab T4 sweet spot: 2-4
PIN_MEMORY = True
PERSISTENT_WORKERS = True
VAL_SPLIT = 0.20                   # stratified val split fraction
RANDOM_SEED = 42
CLASS_MAP_JSON = "/content/class_to_idx.json"
THROUGHPUT_BATCHES = 20            # number of batches to sample for a quick throughput test

# Repro
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# --------------------- Sanity: dataset root & classes ---------------------
root_path = Path(ROOT)
assert root_path.exists(), f"Dataset root not found at {ROOT}. Update ROOT and re-run."

classes = [p.name for p in sorted(root_path.iterdir()) if p.is_dir()]
assert len(classes) > 0, "No class folders found — ensure dataset follows ImageFolder layout."
print("Detected classes (sorted):", classes)

# --------------------- Gather samples and per-class counts ---------------------
IMG_EXTS = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")
class_to_idx = {cls: i for i, cls in enumerate(classes)}
all_samples = []
counts = Counter()

for cls in classes:
    cls_dir = root_path / cls
    files = [p for p in sorted(cls_dir.iterdir()) if p.suffix.lower() in IMG_EXTS]
    counts[cls] = len(files)
    for p in files:
        all_samples.append((str(p), class_to_idx[cls]))

total_images = len(all_samples)
print(f"Total images detected: {total_images}")
print("Per-class counts:")
for k,v in counts.items():
    print(f"  {k}: {v}")

# Save class map (ensures reproducible label ordering)
with open(CLASS_MAP_JSON, "w") as f:
    json.dump(class_to_idx, f, indent=2)
print(f"Saved class->index mapping to: {CLASS_MAP_JSON}")

# --------------------- Stratified split (no sklearn dependency) ---------------------
random.shuffle(all_samples)  # shuffle globally first
train_items, val_items = [], []
per_class_files = {}
for p, label in all_samples:
    per_class_files.setdefault(label, []).append(p)

for label, files in per_class_files.items():
    random.shuffle(files)
    n_val = max(1, int(len(files) * VAL_SPLIT))
    val_files = files[:n_val]
    train_files = files[n_val:]
    # ensure at least one train item per class if very small:
    if len(train_files) == 0 and len(val_files) > 1:
        train_files.append(val_files.pop())
    train_items += [(p, label) for p in train_files]
    val_items   += [(p, label) for p in val_files]

print(f"Train samples: {len(train_items)}  |  Val samples: {len(val_items)}")

# --------------------- Transforms ---------------------
# Use ImageNet normalization for timm pretrained backbones
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

train_tf = transforms.Compose([
    # convert grayscale to RGB if needed; if already RGB this is no-op
    transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.08, contrast=0.08),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_tf = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# --------------------- Dataset wrapper from (path,label) pairs ---------------------
class PathListDataset(Dataset):
    def __init__(self, items, transform=None):
        self.items = items
        self.transform = transform
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        p, label = self.items[idx]
        img = Image.open(p)
        # For safety, convert to RGB here so transforms get consistent PIL mode
        if img.mode != "RGB":
            img = img.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

train_ds = PathListDataset(train_items, transform=train_tf)
val_ds   = PathListDataset(val_items, transform=val_tf)

# --------------------- DataLoaders ---------------------
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
     prefetch_factor=None
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=max(1, NUM_WORKERS//2), pin_memory=PIN_MEMORY,
    prefetch_factor=None
)

# --------------------- Quick sanity checks ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# one-batch shape check
try:
    imgs, labels = next(iter(train_loader))
    print("One batch imgs.shape:", imgs.shape, "  labels.shape:", labels.shape)
    assert imgs.ndim == 4 and imgs.shape[1] == 3, "Expected [B,3,H,W] tensors"
    print("Label distribution in this batch:", dict(Counter(labels.tolist())))
except Exception as e:
    print("ERROR while fetching batch (check transforms / disk read):", e)

# throughput test (data + transforms only) over a few batches
print("\nRunning a quick throughput test (data + transforms)...")
start = time.time()
n = 0
for i, (x, y) in enumerate(train_loader):
    n += x.size(0)
    if i+1 >= THROUGHPUT_BATCHES:
        break
elapsed = time.time() - start
print(f"Loaded {n} images in {elapsed:.2f}s -> {n/elapsed:.2f} samples/sec (approx)")

# print sample entries and class map
print("\nSample train items (5):")
for p,l in train_items[:5]:
    print(" ", p, "->", [k for k,v in class_to_idx.items() if v==l][0])

print("\nclass_to_idx mapping saved to:", CLASS_MAP_JSON)
with open(CLASS_MAP_JSON, "w") as f:
    json.dump(class_to_idx, f, indent=2)

print("\nDataloaders ready — proceed to model creation & training loop.")


Detected classes (sorted): ['AMD', 'DME', 'ERM', 'NO', 'RAO', 'RVO', 'VID']
Total images detected: 1618
Per-class counts:
  AMD: 885
  DME: 143
  ERM: 133
  NO: 284
  RAO: 22
  RVO: 93
  VID: 58
Saved class->index mapping to: /content/class_to_idx.json
Train samples: 1298  |  Val samples: 320
Device: cuda
One batch imgs.shape: torch.Size([32, 3, 224, 224])   labels.shape: torch.Size([32])
Label distribution in this batch: {0: 18, 3: 7, 1: 5, 6: 1, 2: 1}

Running a quick throughput test (data + transforms)...
Loaded 640 images in 6.49s -> 98.60 samples/sec (approx)

Sample train items (5):
  /kaggle/input/octdl-dataset/OCTDL/AMD/amd_e_108.jpg -> AMD
  /kaggle/input/octdl-dataset/OCTDL/AMD/amd_l_149.jpg -> AMD
  /kaggle/input/octdl-dataset/OCTDL/AMD/amd_f_91.jpg -> AMD
  /kaggle/input/octdl-dataset/OCTDL/AMD/amd_e_30.jpg -> AMD
  /kaggle/input/octdl-dataset/OCTDL/AMD/amd_l_37.jpg -> AMD

class_to_idx mapping saved to: /content/class_to_idx.json

Dataloaders ready — proceed to model creat

In [13]:
# === Finalize class-imbalance: compute class weights and create weighted criterion ===
# Place this cell before your training cell. It will compute normalized inverse-frequency
# weights and create `criterion` (CrossEntropyLoss) that you can use directly in training.
#
# Requirements: train_items (list of (path,label)) OR train_loader available from previous cells.
# It will save weights to /content/class_weights.json for reproducibility.

import os, json
from collections import Counter
import torch
import math

# ----- Config: label smoothing you used previously (adjust if needed) -----
LABEL_SMOOTHING = 0.1  # keep same training logic as before

# ----- Helper: get per-class counts (robust) -----
def get_train_counts():
    # 1) Preferred: train_items list created by the split cell (list of (path,label))
    if 'train_items' in globals() and isinstance(train_items, (list,tuple)) and len(train_items)>0:
        labels = [lbl for _, lbl in train_items]
        return Counter(labels)
    # 2) If you built a train_loader / train_ds earlier
    if 'train_loader' in globals():
        ds = train_loader.dataset
        if hasattr(ds, "items"):                  # our PathListDataset
            return Counter([lbl for _, lbl in ds.items])
        if hasattr(ds, "targets"):                # torchvision ImageFolder
            return Counter(ds.targets)
        if hasattr(ds, "imgs"):                   # ImageFolder: .imgs = [(path,label),...]
            return Counter([lbl for _, lbl in ds.imgs])
    # 3) Fallback: read saved class_to_idx and count files on disk (last resort)
    try:
        with open('/content/class_to_idx.json','r') as f:
            class_to_idx = json.load(f)
        counts = Counter()
        for cls, idx in class_to_idx.items():
            folder = os.path.join('/content/OCTDL', cls)
            if os.path.isdir(folder):
                files = [fn for fn in os.listdir(folder) if fn.lower().endswith(('.png','.jpg','.jpeg','.tif','.tiff','.bmp'))]
                counts[idx] = len(files)
        if sum(counts.values())>0:
            return counts
    except Exception:
        pass
    raise RuntimeError("Could not infer train counts. Ensure train_items or train_loader.dataset exist or class_to_idx.json and folders are present.")

counts = get_train_counts()
num_classes = len(counts)
print("Per-class counts (label_idx -> count):")
for k in sorted(counts.keys()):
    print(f"  {k}: {counts[k]}")

# ----- Compute inverse-frequency weights (normalized) -----
inv_freq = [0.0] * num_classes
for i in range(num_classes):
    inv_freq[i] = 1.0 / max(1, counts.get(i, 1))

weights = torch.tensor(inv_freq, dtype=torch.float)
# Normalize so average weight ~1 (keeps loss scale comparable to unweighted CE)
weights = weights / weights.mean()

# Save weights for reproducibility
weights_path = "/content/class_weights.json"
with open(weights_path, "w") as f:
    json.dump({"weights": weights.tolist(), "counts": dict(counts)}, f, indent=2)
print(f"\nSaved class weights to {weights_path}")
print("Normalized class weights:", weights.tolist())

# ----- Create criterion (move to device inside training if you prefer) -----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = torch.nn.CrossEntropyLoss(weight=weights.to(device), label_smoothing=LABEL_SMOOTHING)
print(f"Criterion created: CrossEntropyLoss with label_smoothing={LABEL_SMOOTHING} and device={device}")

# Expose for downstream cells
globals()['class_weights_tensor'] = weights
globals()['criterion'] = criterion


Per-class counts (label_idx -> count):
  0: 708
  1: 115
  2: 107
  3: 228
  4: 18
  5: 75
  6: 47

Saved class weights to /content/class_weights.json
Normalized class weights: [0.08672407269477844, 0.5339186191558838, 0.5738378167152405, 0.26930108666419983, 3.411147117614746, 0.8186752796173096, 1.306396722793579]
Criterion created: CrossEntropyLoss with label_smoothing=0.1 and device=cuda


In [14]:
# === Sanity training run (FINAL, robust, Kaggle-safe) ===
# - Assumes train_loader and val_loader are ALREADY DEFINED
# - Works with PathListDataset / ImageFolder / Subset
# - No reliance on dataset internals (.classes, .dataset, json files)
# - AMP (new API), grad clipping, warmup + cosine LR
# - num_workers = 0 compatible
# - Short run by default (epochs=2)

import os, time, math
import numpy as np
import torch
import torch.nn as nn
import timm
import torch.optim as optim
from sklearn.metrics import confusion_matrix, roc_auc_score

# ---------------- CFG ----------------
CFG = {
    "model_name": "deit_base_patch16_224",
    "pretrained": True,
    "epochs": 2,                # sanity run
    "base_lr": 1e-4,
    "weight_decay": 1e-2,
    "accum_steps": 1,
    "grad_clip": 1.0,
    "warmup_pct": 0.05,
    "save_dir": "/kaggle/working/checkpoints_sanity",
    "log_every": 50,
}
os.makedirs(CFG["save_dir"], exist_ok=True)

# ---------------- Device ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------------- Infer num_classes robustly ----------------
def infer_num_classes(ds):
    labels = set()
    for i in range(len(ds)):
        _, y = ds[i]
        labels.add(int(y))
    labels = sorted(labels)
    return labels, len(labels)

label_set, num_classes = infer_num_classes(train_loader.dataset)
print("Detected label indices:", label_set)
print("num_classes =", num_classes)

# ---------------- Model ----------------
print("Creating model:", CFG["model_name"])
model = timm.create_model(
    CFG["model_name"],
    pretrained=CFG["pretrained"],
    num_classes=num_classes
).to(device)

# ---------------- Optimizer & Scheduler ----------------
optimizer = optim.AdamW(
    model.parameters(),
    lr=CFG["base_lr"],
    weight_decay=CFG["weight_decay"]
)

steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * CFG["epochs"]
warmup_steps = int(total_steps * CFG["warmup_pct"])

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# ---------------- Loss ----------------
# Uses existing class-weighted criterion if present
if "criterion" not in globals():
    print("Criterion not found — using unweighted CE (fallback)")
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
else:
    criterion = globals()["criterion"]

# ---------------- AMP Scaler (NEW API) ----------------
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

# ---------------- Validation ----------------
def validate(model, loader, criterion):
    model.eval()
    losses, all_probs, all_labels = [], [], []

    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            with torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
                out = model(imgs)
                loss = criterion(out, labels)

            losses.append(loss.item() * imgs.size(0))
            all_probs.append(torch.softmax(out, dim=1).cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    preds = all_probs.argmax(axis=1)

    avg_loss = sum(losses) / len(all_labels)
    acc = (preds == all_labels).mean()

    try:
        auc = roc_auc_score(all_labels, all_probs, multi_class="ovr")
    except Exception:
        auc = float("nan")

    cm = confusion_matrix(all_labels, preds)
    return avg_loss, acc, auc, cm

# ---------------- Training Loop ----------------
best_auc = -1.0

for epoch in range(1, CFG["epochs"] + 1):
    model.train()
    optimizer.zero_grad()

    running_loss, correct, total = 0.0, 0, 0
    epoch_start = time.time()

    for step, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.to(device)
        labels = labels.to(device)

        with torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
            out = model(imgs)
            loss = criterion(out, labels) / CFG["accum_steps"]

        scaler.scale(loss).backward()

        if (step + 1) % CFG["accum_steps"] == 0 or (step + 1) == len(train_loader):
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        running_loss += loss.item() * CFG["accum_steps"] * imgs.size(0)
        preds = out.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)

        if (step + 1) % CFG["log_every"] == 0:
            print(
                f"Epoch {epoch} | step {step+1}/{len(train_loader)} "
                f"| loss {running_loss/total:.4f} | acc {correct/total:.4f}"
            )

    train_loss = running_loss / total
    train_acc = correct / total
    val_loss, val_acc, val_auc, cm = validate(model, val_loader, criterion)
    epoch_time = time.time() - epoch_start

    print(
        f"\nEpoch {epoch} summary: "
        f"train_loss {train_loss:.4f} train_acc {train_acc:.4f} | "
        f"val_loss {val_loss:.4f} val_acc {val_acc:.4f} val_auc {val_auc:.4f} | "
        f"time {epoch_time:.1f}s"
    )
    print("Confusion matrix:")
    print(cm)

    ckpt = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
        "val_auc": val_auc,
    }

    torch.save(ckpt, os.path.join(CFG["save_dir"], f"last_{CFG['model_name']}.pth"))
    if not math.isnan(val_auc) and val_auc > best_auc:
        best_auc = val_auc
        torch.save(ckpt, os.path.join(CFG["save_dir"], f"best_{CFG['model_name']}.pth"))
        print("Saved new best checkpoint.")

print(f"\nSanity run finished. Best val AUC: {best_auc:.4f}")
print(f"Checkpoints saved in: {CFG['save_dir']}")




Device: cuda
Detected label indices: [0, 1, 2, 3, 4, 5, 6]
num_classes = 7
Creating model: vit_base_patch16_224





Epoch 1 summary: train_loss 2.4869 train_acc 0.0755 | val_loss 2.8900 val_acc 0.1656 val_auc 0.8210 | time 26.0s
Confusion matrix:
[[  0   2   0  12 161   1   1]
 [  0   6   0   1  21   0   0]
 [  0   1   3   6  16   0   0]
 [  0   0   0  40  16   0   0]
 [  0   0   0   0   4   0   0]
 [  0   0   0   3  15   0   0]
 [  0   0   0   4   7   0   0]]
Saved new best checkpoint.

Epoch 2 summary: train_loss 1.8608 train_acc 0.5593 | val_loss 2.5436 val_acc 0.6281 val_auc nan | time 25.5s
Confusion matrix:
[[106   1   6  21  22  19   2]
 [  0  17   3   0   0   8   0]
 [  1   0  17   6   1   1   0]
 [  0   0   8  46   2   0   0]
 [  0   0   0   0   4   0   0]
 [  1   0   6   1   0   9   1]
 [  1   0   3   5   0   0   2]]

Sanity run finished. Best val AUC: 0.8210
Checkpoints saved in: /kaggle/working/checkpoints_sanity


class-weighted loss helped stability, but rare-classes can still be underexposed. Next logical step: try oversampling (WeightedRandomSampler) while keeping class weights as a fallback. Also add targeted augmentation for minority classes and/or try focal loss if oversampling overfits.

In [15]:
# === Replace train_loader with WeightedRandomSampler (oversampling) ===
from collections import Counter
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader

# CONFIG
OVERSAMPLE_MULTIPLIER = 1.0   # try 1.0 first; increase to 1.5 or 2.0 if still underexposed
REPLACEMENT = True
BATCH_SIZE = getattr(train_loader, "batch_size", 32)
NUM_WORKERS = getattr(train_loader, "num_workers", 0)
PIN_MEMORY = getattr(train_loader, "pin_memory", True)
PERSISTENT = getattr(train_loader, "persistent_workers", True)

# get per-sample labels in train order (train_items from earlier split)
if 'train_items' in globals() and len(train_items)>0:
    sample_labels = [lbl for _, lbl in train_items]
else:
    # fallback: try to read dataset attribute
    ds = train_loader.dataset
    if hasattr(ds, "items"):
        sample_labels = [lbl for _, lbl in ds.items]
    elif hasattr(ds, "targets"):
        sample_labels = list(ds.targets)
    elif hasattr(ds, "imgs"):
        sample_labels = [label for _, label in ds.imgs]
    else:
        raise RuntimeError("No train_items or dataset label list found. Ensure train split exists.")

# compute counts (label -> freq)
counts = Counter(sample_labels)
print("Train label counts (before sampler):", dict(counts))

# per-sample weight = 1 / class_count[label]
sample_weights = [1.0 / counts[int(l)] for l in sample_labels]
sample_weights_tensor = torch.DoubleTensor(sample_weights)

# compute num_samples for one epoch
num_samples = int(len(sample_labels) * OVERSAMPLE_MULTIPLIER)
sampler = WeightedRandomSampler(weights=sample_weights_tensor, num_samples=num_samples, replacement=REPLACEMENT)

# create new train_loader that uses the sampler (disable shuffle)
new_train_loader = DataLoader(
    train_loader.dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=PERSISTENT,
    prefetch_factor=None
)

# Quick check: show label distribution over first N batches
from collections import Counter
tmp_cnt = Counter()
for i, (_, lbls) in enumerate(new_train_loader):
    tmp_cnt.update(lbls.tolist())
    if i >= 30:   # sample ~30 batches
        break
print("Sampler-sampled label counts (approx over ~30 batches):", dict(tmp_cnt))

# Replace global train_loader
train_loader = new_train_loader
print("train_loader replaced with WeightedRandomSampler. num_samples per epoch:", num_samples)


Train label counts (before sampler): {0: 708, 2: 107, 3: 228, 5: 75, 6: 47, 1: 115, 4: 18}
Sampler-sampled label counts (approx over ~30 batches): {3: 145, 2: 150, 0: 129, 6: 146, 4: 144, 1: 161, 5: 117}
train_loader replaced with WeightedRandomSampler. num_samples per epoch: 1298


In [18]:
# === Training run using WeightedRandomSampler (5 epochs) ===
# - Detects sampler on train_loader and chooses unweighted CE automatically.
# - Prints per-epoch summary and per-class metrics.
# - Safe: will use existing model if present, otherwise instantiates vit_base_patch16_224 pretrained and trains.

import os, time, math, json
import numpy as np
from collections import Counter
import torch, torch.nn as nn, torch.optim as optim
import timm
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- CFG ----------
EPOCHS = 5
LR = 1e-4
WD = 1e-2
ACCUM_STEPS = 1
GRAD_CLIP = 1.0
WARMUP_PCT = 0.05
SAVE_DIR = "/content/checkpoints_sampler"
os.makedirs(SAVE_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- Model: reuse if exists, else create ----------
if 'model' in globals() and getattr(globals()['model'], 'parameters', None) is not None:
    print("Reusing existing model in workspace.")
    model = globals()['model']
else:
    print("No model found in workspace — creating fresh model.")
    model_name = globals().get('CFG',{}).get('model_name','vit_base_patch16_224')
    num_classes = len(json.load(open("/content/class_to_idx.json")))
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
model = model.to(device)

# ---------- Criterion: choose automatically ----------
# If train_loader has a sampler (oversampling), use unweighted CE; otherwise, use precomputed `criterion` if present.
use_weighted = False
try:
    sampler_present = getattr(train_loader, 'sampler', None) is not None
    if sampler_present:
        print("Detected sampler on train_loader -> using UNWEIGHTED CrossEntropyLoss (no class weights).")
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    else:
        if 'criterion' in globals():
            print("No sampler detected -> using existing `criterion` from workspace (likely weighted).")
            criterion = globals()['criterion']
        else:
            criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
except Exception:
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

criterion = criterion.to(device)

# ---------- Optimizer, scheduler, scaler ----------
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

steps_per_epoch = math.ceil(len(train_loader) / 1.0)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = int(total_steps * WARMUP_PCT)

def lr_lambda(step):
    if step < warmup_steps:
        return float(step) / max(1, warmup_steps)
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))

# ---------- Helpers: validate (returns per-class arrays and prints table) ----------
with open("/content/class_to_idx.json","r") as f:
    class_to_idx = json.load(f)
idx_to_class = {int(v): k for k,v in class_to_idx.items()}

def validate_and_report(model, loader, criterion):
    model.eval()
    losses = []
    all_logits = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            out = model(imgs)
            loss = criterion(out, labels)
            losses.append(loss.item() * imgs.size(0))
            probs = torch.softmax(out, dim=1).cpu().numpy()
            all_logits.append(probs)
            all_labels.append(labels.cpu().numpy())
    if len(all_labels) == 0:
        return {}, {}
    all_logits = np.concatenate(all_logits, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    preds = all_logits.argmax(axis=1)
    avg_loss = sum(losses) / all_labels.shape[0]
    acc = (preds == all_labels).mean()
    try:
        macro_auc = roc_auc_score(all_labels, all_logits, multi_class='ovr')
    except Exception:
        macro_auc = float('nan')

    # per-class PRF and AUC
    prec, rec, f1, support = precision_recall_fscore_support(all_labels, preds, labels=list(range(all_logits.shape[1])), zero_division=0)
    per_class = []
    for i in range(all_logits.shape[1]):
        try:
            y_true = (all_labels == i).astype(int)
            auc = roc_auc_score(y_true, all_logits[:, i])
        except Exception:
            auc = float('nan')
        per_class.append({
            "class_idx": i, "class_name": idx_to_class[i],
            "support": int(support[i]), "precision": float(prec[i]),
            "recall": float(rec[i]), "f1": float(f1[i]), "auc": float(auc)
        })

    cm = confusion_matrix(all_labels, preds, labels=list(range(all_logits.shape[1])))
    metrics = {"loss": avg_loss, "acc": acc, "macro_auc": macro_auc, "confusion_matrix": cm, "per_class": per_class}
    return metrics, (all_labels, all_logits)

# ---------- Training loop ----------
best_auc = -1.0
global_step = 0
for epoch in range(1, EPOCHS+1):
    model.train()
    epoch_loss = 0.0
    correct = 0
    total = 0
    t0 = time.time()
    optimizer.zero_grad()
    for step, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            out = model(imgs)
            loss = criterion(out, labels) / ACCUM_STEPS
        scaler.scale(loss).backward()
        if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            global_step += 1
        epoch_loss += loss.item() * ACCUM_STEPS * imgs.size(0)
        preds = out.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)

    train_loss = epoch_loss / total if total>0 else float('nan')
    train_acc = correct / total if total>0 else float('nan')

    # validate
    metrics, preds_pack = validate_and_report(model, val_loader, criterion)
    print(f"\nEpoch {epoch} | train_loss {train_loss:.4f} train_acc {train_acc:.4f} | val_loss {metrics['loss']:.4f} val_acc {metrics['acc']:.4f} val_auc {metrics['macro_auc']:.4f} | epoch_time {time.time()-t0:.1f}s")
    print("Confusion matrix:")
    print(metrics['confusion_matrix'])

    # print per-class table
    print("\nPer-class results:")
    print(f"{'class':>8} | {'supp':>4} | {'prec':>5} | {'rec':>5} | {'f1':>5} | {'auc':>5}")
    print("-"*50)
    for row in metrics['per_class']:
        print(f"{row['class_name']:>8} | {row['support']:4d} | {row['precision']:5.3f} | {row['recall']:5.3f} | {row['f1']:5.3f} | {row['auc']:5.3f}")

    # save checkpoint
    ckpt = {'epoch': epoch, 'model_state': model.state_dict(), 'optim_state': optimizer.state_dict(), 'val_auc': metrics['macro_auc']}
    torch.save(ckpt, os.path.join(SAVE_DIR, f"epoch{epoch}_sampler.pth"))
    if not (isinstance(metrics['macro_auc'], float) and math.isnan(metrics['macro_auc'])) and metrics['macro_auc'] > best_auc:
        best_auc = metrics['macro_auc']
        torch.save(ckpt, os.path.join(SAVE_DIR, f"best_sampler.pth"))
        print("Saved new best checkpoint.")

print(f"\nDone. Best val AUC in this run: {best_auc:.4f}. Checkpoints: {SAVE_DIR}")
# expose model and latest preds for inspection
globals()['model'] = model
globals()['val_all'] = preds_pack


Device: cuda
Reusing existing model in workspace.
Detected sampler on train_loader -> using UNWEIGHTED CrossEntropyLoss (no class weights).


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))
  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):



Epoch 1 | train_loss 0.6059 train_acc 0.9307 | val_loss 0.7612 val_acc 0.8656 val_auc 0.9540 | epoch_time 30.1s
Confusion matrix:
[[169   0   1   6   0   1   0]
 [  0  25   0   0   0   2   1]
 [  3   4  13   2   2   2   0]
 [  4   0   1  49   0   2   0]
 [  0   0   0   0   4   0   0]
 [  1   5   0   3   0   9   0]
 [  1   0   0   1   0   1   8]]

Per-class results:
   class | supp |  prec |   rec |    f1 |   auc
--------------------------------------------------
     AMD |  177 | 0.949 | 0.955 | 0.952 | 0.982
     DME |   28 | 0.735 | 0.893 | 0.806 | 0.987
     ERM |   26 | 0.867 | 0.500 | 0.634 | 0.963
      NO |   56 | 0.803 | 0.875 | 0.838 | 0.979
     RAO |    4 | 0.667 | 1.000 | 0.800 | 1.000
     RVO |   18 | 0.529 | 0.500 | 0.514 | 0.875
     VID |   11 | 0.889 | 0.727 | 0.800 | 0.891
Saved new best checkpoint.


  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):



Epoch 2 | train_loss 0.5520 train_acc 0.9584 | val_loss 0.9719 val_acc 0.8219 val_auc 0.9704 | epoch_time 30.4s
Confusion matrix:
[[148   1  18   5   0   5   0]
 [  0  23   3   0   0   1   1]
 [  0   0  26   0   0   0   0]
 [  1   0  13  41   0   1   0]
 [  0   0   0   0   4   0   0]
 [  0   2   1   3   0  12   0]
 [  1   0   0   1   0   0   9]]

Per-class results:
   class | supp |  prec |   rec |    f1 |   auc
--------------------------------------------------
     AMD |  177 | 0.987 | 0.836 | 0.905 | 0.968
     DME |   28 | 0.885 | 0.821 | 0.852 | 0.988
     ERM |   26 | 0.426 | 1.000 | 0.598 | 0.980
      NO |   56 | 0.820 | 0.732 | 0.774 | 0.915
     RAO |    4 | 1.000 | 1.000 | 1.000 | 1.000
     RVO |   18 | 0.632 | 0.667 | 0.649 | 0.952
     VID |   11 | 0.900 | 0.818 | 0.857 | 0.989
Saved new best checkpoint.


  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):



Epoch 3 | train_loss 0.5444 train_acc 0.9607 | val_loss 0.7036 val_acc 0.9000 val_auc 0.9869 | epoch_time 31.1s
Confusion matrix:
[[163   0   1  11   0   1   1]
 [  0  23   1   0   0   4   0]
 [  0   0  25   1   0   0   0]
 [  0   0   4  52   0   0   0]
 [  0   0   0   0   4   0   0]
 [  1   2   0   3   0  12   0]
 [  1   0   0   1   0   0   9]]

Per-class results:
   class | supp |  prec |   rec |    f1 |   auc
--------------------------------------------------
     AMD |  177 | 0.988 | 0.921 | 0.953 | 0.978
     DME |   28 | 0.920 | 0.821 | 0.868 | 0.995
     ERM |   26 | 0.806 | 0.962 | 0.877 | 0.997
      NO |   56 | 0.765 | 0.929 | 0.839 | 0.981
     RAO |    4 | 1.000 | 1.000 | 1.000 | 1.000
     RVO |   18 | 0.706 | 0.667 | 0.686 | 0.967
     VID |   11 | 0.900 | 0.818 | 0.857 | 0.991
Saved new best checkpoint.


  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):



Epoch 4 | train_loss 0.4689 train_acc 0.9923 | val_loss 0.6757 val_acc 0.9031 val_auc 0.9904 | epoch_time 30.8s
Confusion matrix:
[[163   0   1  10   0   2   1]
 [  0  24   1   0   0   1   2]
 [  0   1  23   2   0   0   0]
 [  1   0   1  54   0   0   0]
 [  0   0   0   0   4   0   0]
 [  0   3   0   3   0  12   0]
 [  1   0   0   1   0   0   9]]

Per-class results:
   class | supp |  prec |   rec |    f1 |   auc
--------------------------------------------------
     AMD |  177 | 0.988 | 0.921 | 0.953 | 0.986
     DME |   28 | 0.857 | 0.857 | 0.857 | 0.995
     ERM |   26 | 0.885 | 0.885 | 0.885 | 0.996
      NO |   56 | 0.771 | 0.964 | 0.857 | 0.984
     RAO |    4 | 1.000 | 1.000 | 1.000 | 1.000
     RVO |   18 | 0.800 | 0.667 | 0.727 | 0.975
     VID |   11 | 0.750 | 0.818 | 0.783 | 0.997
Saved new best checkpoint.


  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):


KeyboardInterrupt: 

In [None]:
# === Focused training (CLEAN, portable, no Colab assumptions) ===

import os, time, math
import torch
import torch.nn as nn
import torch.optim as optim
import timm

# ---------------- CONFIG ----------------
CFG = {
    "model_name": "deit_base_patch16_224",
    "pretrained": True,

    "max_epochs": 40,
    "target_val_acc": 0.95,

    "lr": 1e-4,
    "weight_decay": 1e-2,

    "accum_steps": 8,
    "grad_clip": 0.5,
    "warmup_pct": 0.06,

    "num_workers": 0,
    "label_smoothing": 0.1,

    "input_size": 384,
    "save_dir": "/kaggle/working/checkpoints"
}

os.makedirs(CFG["save_dir"], exist_ok=True)

# ---------------- Device ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------------- Infer num_classes robustly ----------------
def infer_num_classes(ds):
    labels = set()
    for i in range(len(ds)):
        _, y = ds[i]
        labels.add(int(y))
    return len(labels)

num_classes = infer_num_classes(train_loader.dataset)
print("Detected num_classes:", num_classes)

# ---------------- Model (safe reuse) ----------------
reuse_ok = (
    "model" in globals()
    and isinstance(globals()["model"], torch.nn.Module)
    and getattr(globals()["model"], "num_classes", None) == num_classes
)

if reuse_ok:
    model = globals()["model"]
    print("Reusing existing compatible model.")
else:
    print("Creating new model:", CFG["model_name"])
    model = timm.create_model(
        CFG["model_name"],
        pretrained=CFG["pretrained"],
        num_classes=num_classes
    )

model = model.to(device)

# ---------------- Rebuild DataLoaders (workers=0 safe) ----------------
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_loader.dataset,
    batch_size=train_loader.batch_size,
    sampler=train_loader.sampler if hasattr(train_loader, "sampler") else None,
    shuffle=False if hasattr(train_loader, "sampler") else True,
    num_workers=0,
    pin_memory=True
)

val_loader = DataLoader(
    val_loader.dataset,
    batch_size=val_loader.batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

# ---------------- Criterion ----------------
# If sampler is used, DO NOT weight loss again
criterion = nn.CrossEntropyLoss(label_smoothing=CFG["label_smoothing"]).to(device)

# ---------------- Optimizer / Scheduler / AMP ----------------
optimizer = optim.AdamW(
    model.parameters(),
    lr=CFG["lr"],
    weight_decay=CFG["weight_decay"]
)

steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * CFG["max_epochs"]
warmup_steps = int(total_steps * CFG["warmup_pct"])

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

# ---------------- Quick validation ----------------
@torch.no_grad()
def quick_validate(model, loader):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        out = model(imgs)
        loss = criterion(out, labels)
        total_loss += loss.item() * imgs.size(0)
        correct += (out.argmax(1) == labels).sum().item()
        total += imgs.size(0)
    return total_loss / total, correct / total

# ---------------- Training loop ----------------
best_val_acc = -1.0
best_epoch = -1

for epoch in range(1, CFG["max_epochs"] + 1):
    model.train()
    optimizer.zero_grad()

    epoch_loss, correct, total = 0.0, 0, 0
    t0 = time.time()

    for step, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.to(device)
        labels = labels.to(device)

        with torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
            out = model(imgs)
            loss = criterion(out, labels) / CFG["accum_steps"]

        scaler.scale(loss).backward()

        if (step + 1) % CFG["accum_steps"] == 0 or (step + 1) == len(train_loader):
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        epoch_loss += loss.item() * CFG["accum_steps"] * imgs.size(0)
        correct += (out.argmax(1) == labels).sum().item()
        total += imgs.size(0)

    train_loss = epoch_loss / total
    train_acc = correct / total

    val_loss, val_acc = quick_validate(model, val_loader)
    dt = time.time() - t0

    print(
        f"Epoch {epoch}/{CFG['max_epochs']} | "
        f"time {dt:.1f}s | "
        f"train_loss {train_loss:.4f} train_acc {train_acc:.4f} | "
        f"val_loss {val_loss:.4f} val_acc {val_acc:.4f}"
    )

    # Save last
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optim_state": optimizer.state_dict(),
            "val_acc": val_acc,
            "cfg": CFG,
        },
        os.path.join(CFG["save_dir"], f"last_{CFG['model_name']}.pth")
    )

    # Save best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        best_path = os.path.join(CFG["save_dir"], f"best_{CFG['model_name']}.pth")
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optim_state": optimizer.state_dict(),
                "val_acc": val_acc,
                "cfg": CFG,
            },
            best_path
        )
        print("Saved new best checkpoint.")

    # Early stop
    if val_acc >= CFG["target_val_acc"]:
        print(f"Target val_acc {CFG['target_val_acc']:.2f} reached at epoch {epoch}. Stopping.")
        break

globals()["model"] = model
print(f"Finished. Best val_acc: {best_val_acc:.4f} at epoch {best_epoch}")
print(f"Checkpoints saved in: {CFG['save_dir']}")


In [None]:
# === CELL 1: Evaluation & Metrics ===
import os, json, numpy as np, torch
from sklearn.metrics import precision_recall_fscore_support, classification_report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- class names ----------
if 'class_names' in globals() and isinstance(class_names, (list,tuple)):
    names = list(class_names)
else:
    try:
        with open("/content/class_to_idx.json","r") as f:
            ct = json.load(f)
        names = [None]*len(ct)
        for k,v in ct.items():
            names[int(v)] = k
    except Exception:
        names = train_loader.dataset.classes

num_classes = len(names)
print("Classes:", names)

# ---------- load best checkpoint ----------
ckpt_path = os.path.join("/kaggle/working/checkpoints", f"best_{CFG['model_name']}.pth")
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt["model_state"], strict=False)
model.to(device).eval()

print(f"Loaded best checkpoint | epoch={ckpt.get('epoch')} val_acc={ckpt.get('val_acc')}")

# ---------- inference ----------
y_true, y_pred, probs_list = [], [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        preds = logits.argmax(dim=1).cpu().numpy()
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds)
        probs_list.append(probs)

y_true = np.array(y_true)
y_pred = np.array(y_pred)
probs_all = np.vstack(probs_list)

# ---------- metrics ----------
acc = (y_true == y_pred).mean()
print(f"\nValidation accuracy: {acc:.4f}\n")

precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, labels=range(num_classes), zero_division=0
)

print(classification_report(y_true, y_pred, target_names=names, digits=4))


In [None]:
# === CELL 2: Confusion Matrix ===
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(7,6))
sns.heatmap(
    cm, annot=True, fmt="d", cmap="Blues",
    xticklabels=names, yticklabels=names
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.show()


In [None]:
# === CELL 3: Training Dynamics & Transformer Plots ===
import matplotlib.pyplot as plt
import numpy as np

# -------------------------------------------------
# Plot 1: Validation Accuracy vs Epoch
# -------------------------------------------------
epochs, val_accuracies = None, None

if "history" in globals():
    epochs = [h["epoch"] for h in history]
    val_accuracies = [h["val_acc"] for h in history]
else:
    try:
        val_accuracies = np.load("val_accuracies.npy")
        epochs = np.arange(1, len(val_accuracies) + 1)
    except Exception:
        print("No validation history found — skipping val acc curve.")

if epochs is not None:
    plt.figure(figsize=(6,4))
    plt.plot(epochs, val_accuracies, marker="o")
    plt.axhline(0.95, color="red", linestyle="--", label="Target 95%")
    plt.xlabel("Epoch")
    plt.ylabel("Validation Accuracy")
    plt.title("Validation Accuracy vs Epoch")
    plt.grid(True)
    plt.legend()
    plt.show()

# -------------------------------------------------
# Plot 2: Per-class Recall (MOST IMPORTANT)
# -------------------------------------------------
plt.figure(figsize=(8,4))
plt.bar(names, recall)
plt.ylabel("Recall")
plt.title("Per-class Recall (Sensitivity)")
plt.ylim(0, 1.05)
plt.grid(axis="y")
plt.show()

# -------------------------------------------------
# Plot 3: Transformer Confidence Histogram
# -------------------------------------------------
max_conf = probs_all.max(axis=1)

plt.figure(figsize=(6,4))
plt.hist(max_conf, bins=20, edgecolor="black")
plt.xlabel("Max Softmax Probability")
plt.ylabel("Number of Samples")
plt.title("Prediction Confidence Distribution")
plt.grid(True)
plt.show()
