<a href="https://colab.research.google.com/github/shlokoyo123/Diabetic-Retinopathy/blob/main/DR_Code_Updated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# === DR (GPU/Colab) — CELL 1: Train + Save + Batch Inference + Download ===
# - Runs the full pipeline
# - Saves models/scalers to disk for UI cell
# - Runs batch inference on uploaded images
# - Generates & downloads artifacts
# - No Gradio here

# ------------------ CONFIG ------------------
# Mount Google Drive at /content/drive so that ZIP_PATH (which lives under
# /content/drive/MyDrive) is readable. `google.colab` is supplied by the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

import os, io, json, base64, zipfile, hashlib, random, time, math, shutil, sys, re
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageDraw

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
import torchvision.transforms as T

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, cohen_kappa_score
from tqdm.auto import tqdm
from collections import Counter
import joblib  # for persisting sklearn objects

# --- Paths & knobs ---
BASE_UNZIP_DIR      = "/content/train_subset_unzipped"                # extraction target; also searched during auto-detect
IMAGES_DIR_CONFIG   = "/content/train_subset_unzipped/train_subset"   # leave blank to auto-detect
LABELS_CSV          = "/content/trainLabels.csv"                      # needs columns: image, level (overwritten by the synthetic fallback)
ZIP_PATH            = "/content/drive/MyDrive/train_subset.zip"       # ok if missing
ROOT_OUT            = "/content/dr_one_cell"                          # root for every artifact written by this cell
FILES_DIR           = str(Path(ROOT_OUT)/"public")
MODELS_DIR          = str(Path(ROOT_OUT)/"models")

# Classification granularity
BINARY_MODE         = True   # True: collapse levels >=1 to 1 (level 0 vs. any higher level). Helps accuracy on limited data.

IMAGE_SIZE          = 224    # side length (pixels) of the square preprocessed images
# Earlier VAL_FRAC experiments, kept for reference
# (the ResNet figures are presumably validation accuracy in % — TODO confirm):
#VAL_FRAC            = 0.40
#VAL_FRAC            = 0.20 --> RESNET WAS 74.19
#VAL_FRAC            = 0.15 --> RESNET WAS 68...
VAL_FRAC            = 0.20   # fraction of matched images held out for validation
SEED                = 1337   # seeds torch / numpy / random, the split and the per-class cap

# Speed/quality
#MAX_TRAIN_PER_CLASS = 800
MAX_TRAIN_PER_CLASS = 1500   # per-class cap on training rows; None or <=0 disables the cap
BATCH_SIZE          = 64

# Autoencoder (improved)
# NOTE: with AE_EPOCHS = 0 the AE training loop never executes, so the encoder used
# for the "autoencoder_linear" results keeps its random initialisation.
AE_EPOCHS             = 0
AE_LATENT_DIM         = 384
AE_NOISE_STD          = 0.03   # std of Gaussian noise added to the AE input (0 disables)
AE_WEIGHT_DECAY       = 1e-5
AE_LR                 = 1e-3
AE_SUP_HEAD_EPOCHS    = 0      # 0 skips the optional supervised head on the frozen encoder
AE_SUP_LR             = 2e-3
COLOR_JITTER_STRENGTH = 0.12   # half-width of the brightness/contrast/saturation jitter range

# Fine-tune (ResNet50 only here; Cell 2 is robust if FT ckpts aren't present)
FT_STAGE1_EPOCHS      = 6      # stage 1: only layer3, layer4 and fc are trainable
FT_STAGE2_EPOCHS      = 10     # stage 2: all parameters trainable
FT_BASE_LR            = 2e-4   # stage-1 learning rate; stage 2 starts at half of this

AUTO_DOWNLOAD         = True  # download artifacts at the end of this cell

# ------------------ SETUP ------------------
# Seed the three RNGs used by this cell (torch, numpy, stdlib random).
torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Create the output tree up front; exist_ok makes re-running the cell safe.
Path(ROOT_OUT).mkdir(parents=True, exist_ok=True)
Path(FILES_DIR).mkdir(parents=True, exist_ok=True)
Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)

# ------------------ SAFE UNZIP ------------------
import zipfile
def safe_unzip(zip_path: str, out_dir: str) -> bool:
    """Extract ``zip_path`` into ``out_dir`` without ever raising.

    Returns True only when the archive was extracted. A missing path, a
    directory, a file that is not a zip archive, or any error during
    extraction prints a message and returns False.
    """
    archive_missing = (not zip_path) or (not os.path.exists(zip_path))
    if archive_missing:
        print(f"[info] No zip at {zip_path} (skipping unzip).")
        return False

    if os.path.isdir(zip_path):
        print(f"[info] ZIP_PATH is a directory: {zip_path} — skipping unzip.")
        return False

    try:
        if zipfile.is_zipfile(zip_path):
            os.makedirs(out_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as archive:
                archive.extractall(out_dir)
            print(f"Extracted {zip_path} to {out_dir}")
            return True
        print(f"[warn] Exists but not a valid .zip: {zip_path} — skipping unzip.")
        return False
    except Exception as e:
        # Best-effort by design: the caller falls back to other image sources.
        print(f"[warn] Unzip failed: {e} — skipping unzip.")
        return False

# Best-effort extraction of the training archive. The boolean result is discarded:
# if nothing was extracted, the image-directory resolution further down falls back
# to auto-detection and finally to a synthetic dataset.
_ = safe_unzip(ZIP_PATH, BASE_UNZIP_DIR)

# ------------------ IMAGE DISCOVERY ------------------
def list_images(root):
    """Return, as strings, every path under ``root`` (recursively) whose suffix is
    a known image extension (case-insensitive); ``[]`` if ``root`` does not exist."""
    image_suffixes = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp")
    base = Path(root)
    found = []
    if base.exists():
        for entry in base.rglob("*"):
            if entry.suffix.lower() in image_suffixes:
                found.append(str(entry))
    return found

def dir_has_images(d):
    """Report whether directory ``d`` contains at least one image (recursively).

    A falsy ``d`` is returned unchanged (so the result is falsy); a missing
    directory or any error while scanning yields False.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"}
    try:
        if not d:
            return d
        folder = Path(d)
        if not folder.exists():
            return False
        for entry in folder.rglob("*"):
            if entry.suffix.lower() in image_suffixes:
                return True
        return False
    except Exception:
        return False

def find_image_dir(base_dir: str, min_images: int = 1):
    """Locate a directory holding images under ``base_dir``.

    Search order: ``base_dir`` itself, then its immediate sub-directories, then a
    recursive scan (returning the parent of the first image found, provided at
    least ``min_images`` images exist). Returns the directory as a string, or
    None when ``base_dir`` is missing or nothing qualifies. ``min_images`` is
    only consulted by the recursive scan.
    """
    suffixes = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"}

    def _is_image_file(entry):
        return entry.is_file() and entry.suffix.lower() in suffixes

    root = Path(base_dir)
    if not root.exists():
        return None

    # 1) images directly inside the base directory
    for entry in root.iterdir():
        if _is_image_file(entry):
            return str(root)

    # 2) images directly inside a first-level sub-directory
    for child in root.iterdir():
        if not child.is_dir():
            continue
        if any(_is_image_file(entry) for entry in child.iterdir()):
            return str(child)

    # 3) anywhere below the base directory
    found = [entry for entry in root.rglob("*") if _is_image_file(entry)]
    if len(found) >= min_images:
        return str(found[0].parent)
    return None

# ------------------ SYNTHETIC FALLBACK ------------------
def make_synth_dataset(root_dir, labels_csv, n_per_class=40, size=300):
    """Write a small synthetic 5-class image set and its labels CSV.

    Images go to ``<root_dir>/train_subset`` as ``<class>_<index>.jpg``: a dark
    background, one pale ring, and a number of reddish dots that grows with the
    class index. ``labels_csv`` is (over)written with columns ``image`` and
    ``level``. Drawing is driven by a ``random.Random(SEED)`` instance, so the
    output is reproducible. Returns the image directory as a string.
    """
    rng = random.Random(SEED)
    out_dir = Path(root_dir) / "train_subset"
    out_dir.mkdir(parents=True, exist_ok=True)

    records = []
    for cls in range(5):
        for idx in range(n_per_class):
            name = f"{cls}_{idx:03d}"

            # dark, slightly varying background
            background = (rng.randint(10, 30), rng.randint(10, 30), rng.randint(10, 30))
            canvas = Image.new("RGB", (size, size), background)
            pen = ImageDraw.Draw(canvas)

            # one pale ring, kept at least 10 px away from the border
            ring_r = rng.randint(16, 28)
            cx = rng.randint(ring_r + 10, size - ring_r - 10)
            cy = rng.randint(ring_r + 10, size - ring_r - 10)
            pen.ellipse([cx - ring_r, cy - ring_r, cx + ring_r, cy + ring_r],
                        outline=(250, 240, 220), width=3)

            # class-dependent number of reddish dots
            n_spots = cls * 8 + rng.randint(0, 5)
            for _ in range(n_spots):
                sx = rng.randint(8, size - 8)
                sy = rng.randint(8, size - 8)
                spot_r = rng.randint(2, 5)
                spot_col = (rng.randint(180, 255), rng.randint(50, 120), rng.randint(50, 120))
                pen.ellipse([sx - spot_r, sy - spot_r, sx + spot_r, sy + spot_r], fill=spot_col)

            canvas.save(out_dir / f"{name}.jpg", quality=92)
            records.append({"image": name, "level": cls})

    pd.DataFrame(records).to_csv(labels_csv, index=False)
    print(f"[synth] Wrote {len(records)} images to {out_dir} and labels to {labels_csv}")
    return str(out_dir)

# Resolve images dir or synthesize
# Resolution order:
#   1. the configured IMAGES_DIR_CONFIG, if it contains images;
#   2. otherwise a directory auto-detected under BASE_UNZIP_DIR;
#   3. otherwise a freshly generated synthetic dataset.
# NOTE: step 3 also overwrites LABELS_CSV with the synthetic labels.
IMAGES_DIR = IMAGES_DIR_CONFIG
if not dir_has_images(IMAGES_DIR):
    guess = find_image_dir(BASE_UNZIP_DIR, min_images=1)
    if guess:
        print(f"[auto-detect] Using images from: {guess}")
        IMAGES_DIR = guess
if not dir_has_images(IMAGES_DIR):
    IMAGES_DIR = make_synth_dataset(BASE_UNZIP_DIR, LABELS_CSV, n_per_class=40, size=300)

print("[debug] IMAGES_DIR =", IMAGES_DIR)

# ------------------ LOAD LABELS & MATCH ------------------
# Read the labels table and normalise dtypes: `image` is the filename stem (string),
# `level` the integer grade.
labels = pd.read_csv(LABELS_CSV)
assert {"image","level"}.issubset(labels.columns), "LABELS_CSV must have columns: image, level"
labels["image"] = labels["image"].astype(str)
labels["level"] = labels["level"].astype(int)
if BINARY_MODE:
    # Collapse every level >= 1 into class 1; level 0 stays class 0.
    labels["level"] = (labels["level"] >= 1).astype(int)

# stem -> level lookup; images are matched to labels by filename without extension.
label_map = dict(zip(labels["image"], labels["level"]))
def stem_only(p): return Path(p).stem

# Keep only the images on disk that have a label; unlabeled files are dropped silently.
rows = []
all_imgs = list_images(IMAGES_DIR)
for p in all_imgs:
    stem = stem_only(p)
    lvl = label_map.get(stem)
    if lvl is not None:
        rows.append({"filename": p, "level": int(lvl)})
df_all = pd.DataFrame(rows)
if df_all.empty:
    # Nothing matched: stop here, every later step needs df_all.
    print("[fatal] No matched images. Check filename stems vs CSV 'image' column.")
    raise SystemExit(1)

# ------------------ PREPROCESS ------------------
def crop_black(im, thresh=8):
    """Crop ``im`` to the bounding box of pixels whose grayscale value exceeds
    ``thresh``; return ``im`` unchanged when no such pixel exists."""
    mask = im.convert("L").point(lambda v: 255 if v > thresh else 0, mode="1")
    box = mask.getbbox()
    if box:
        return im.crop(box)
    return im

def preprocess_all(df_all, proc_dir, size):
    """Preprocess every image listed in ``df_all["filename"]`` into ``proc_dir``.

    Each image is EXIF-transposed, converted to RGB, cropped to its non-black
    bounding box, autocontrasted, equalized, center-cropped to a square and
    resized to ``size`` x ``size``. Outputs that already exist are reused.

    Returns a copy of the rows of ``df_all`` that have a usable output, with an
    added ``proc_path`` column. Images that fail to load or process are logged
    and dropped from the result; previously a single failure made the
    ``proc_path`` assignment raise a length-mismatch error.

    NOTE: the output name is the source basename, so two sources sharing a
    basename map to the same output file.
    """
    proc_root = Path(proc_dir)
    proc_root.mkdir(parents=True, exist_ok=True)
    kept_positions = []   # positional indices of rows with a usable output
    out_paths = []        # matching output paths, same order as kept_positions
    print(f"Preprocessing {len(df_all)} images to {size}x{size}...")
    for pos, p in enumerate(tqdm(df_all["filename"])):
        outp = str(proc_root / Path(p).name)
        if not Path(outp).exists():
            try:
                with Image.open(p) as im:
                    im = ImageOps.exif_transpose(im).convert("RGB")
                    im = crop_black(im)
                    im = ImageOps.autocontrast(im, cutoff=1)          # less aggressive than default
                    im = ImageOps.equalize(im)                        # boosts vessel contrast
                    # center-crop square then resize (fundus is circular; corners are junk)
                    w, h = im.size
                    s = min(w, h)
                    im = im.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
                    im = im.resize((size, size))
                    im.save(outp, quality=95)
            except Exception as e:
                # Best-effort: skip unreadable images but keep rows and paths aligned.
                print("[pre] skip", p, e)
                continue
        kept_positions.append(pos)
        out_paths.append(outp)
    df = df_all.iloc[kept_positions].copy()
    df["proc_path"] = out_paths
    return df

# Preprocessed images are stored per target size, so changing IMAGE_SIZE does not
# reuse files produced for another size.
PROC_DIR = Path(ROOT_OUT)/f"processed_{IMAGE_SIZE}"
df_all = preprocess_all(df_all, PROC_DIR, IMAGE_SIZE)
# Persist the filename / level / proc_path table for inspection.
df_all.to_csv(Path(ROOT_OUT)/"_matched_files.csv", index=False)

# ------------------ SPLIT ------------------
# Single stratified train/val split: VAL_FRAC of the rows go to validation and the
# class proportions of `level` are preserved; SEED makes the split reproducible.
sss = StratifiedShuffleSplit(n_splits=1, test_size=VAL_FRAC, random_state=SEED)
idx_tr, idx_va = next(sss.split(df_all["proc_path"], df_all["level"]))
DF_TR = df_all.iloc[idx_tr].reset_index(drop=True)
DF_VA = df_all.iloc[idx_va].reset_index(drop=True)
print(f"[split] train={len(DF_TR)} val={len(DF_VA)}")

# ------------------ CACHED TENSORS ------------------
# Per-channel normalisation statistics (the standard ImageNet values), shaped
# (3,1,1) so they broadcast over a CxHxW tensor.
MEAN = torch.tensor([0.485,0.456,0.406]).view(3,1,1)
STD  = torch.tensor([0.229,0.224,0.225]).view(3,1,1)
# On-disk cache of normalised image tensors (one .pt file per image).
CACHE_DIR = Path(ROOT_OUT)/"_tensor_cache"

def pil_to_tensor_resized(path, size):
    """Load the image at ``path`` and return a normalised float32 tensor of shape
    (3, size, size): EXIF-transposed, RGB, resized, scaled to [0, 1], then
    standardised with the module-level MEAN / STD."""
    with Image.open(path) as handle:
        rgb = ImageOps.exif_transpose(handle).convert("RGB")
        resized = rgb.resize((size, size))
        hwc = np.asarray(resized, dtype=np.float32) / 255.0
        chw = hwc.transpose(2, 0, 1)          # HWC -> CHW
        return (torch.from_numpy(chw) - MEAN) / STD

def cache_key(path, size):
    """Build the cache filename for ``path`` at ``size``.

    The key is the SHA-1 hex digest of ``"<basename>|<size>|<mtime>"`` (mtime
    truncated to whole seconds) followed by ``.pt``; modifying the file or
    changing the size therefore yields a different key.
    """
    source = Path(path)
    modified = int(source.stat().st_mtime)
    token = f"{source.name}|{size}|{modified}"
    digest = hashlib.sha1(token.encode()).hexdigest()
    return digest + ".pt"

def cached_tensor_for(path, size, cache_dir):
    """Return the normalised tensor for ``path``, using ``cache_dir`` as a disk cache.

    On a cache hit the stored tensor is loaded onto the CPU. If the cached file
    cannot be loaded it is reported, deleted (best effort) and rebuilt. On a
    miss the tensor is computed with ``pil_to_tensor_resized`` and saved.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    fp = cache_dir / cache_key(path, size)
    if fp.exists():
        try:
            return torch.load(fp, map_location="cpu")
        except Exception as e:
            # Unreadable entry (e.g. truncated by an interrupted run): drop and rebuild.
            print(f"[cache] unreadable entry {fp.name} ({e}); rebuilding")
            try:
                fp.unlink()
            except OSError:
                # Only filesystem errors are ignored here; a bare `except` would also
                # have swallowed KeyboardInterrupt / SystemExit.
                pass
    t = pil_to_tensor_resized(path, size)
    torch.save(t, fp)
    return t

class CachedTensorDS(Dataset):
    """Dataset over a DataFrame with ``proc_path`` and ``level`` columns that
    serves (tensor, int label) pairs, reading tensors through the disk cache."""

    def __init__(self, df, size, cache_dir):
        # Re-index so positional access in __getitem__ is 0..len-1.
        self.df = df.reset_index(drop=True)
        self.size = size
        self.cache_dir = cache_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        tensor = cached_tensor_for(row["proc_path"], self.size, self.cache_dir)
        return tensor, int(row["level"])

def make_loader(df, bs, shuffle=False, sampler=None):
    """Build a single-process DataLoader of cached tensors for ``df``.

    ``shuffle`` is only honoured when no ``sampler`` is supplied.
    """
    dataset = CachedTensorDS(df, IMAGE_SIZE, CACHE_DIR)
    use_shuffle = (sampler is None) and shuffle
    return DataLoader(
        dataset,
        batch_size=bs,
        shuffle=use_shuffle,
        sampler=sampler,
        num_workers=0,
        pin_memory=False,
    )

# ------------------ CAP + SAMPLER ------------------
def stratified_cap(df, label_col="level", cap=MAX_TRAIN_PER_CLASS, seed=SEED):
    """Return ``df`` with at most ``cap`` randomly sampled rows per ``label_col``
    value (re-indexed). A ``cap`` of None or <= 0 returns a re-indexed copy."""
    if cap is None or cap <= 0:
        return df.copy().reset_index(drop=True)
    capped_groups = [
        group.sample(n=min(len(group), cap), random_state=seed)
        for _, group in df.groupby(label_col)
    ]
    return pd.concat(capped_groups, ignore_index=True)

# Training frame after the per-class cap; validation is never capped.
DF_TR_CAP = stratified_cap(DF_TR, "level", MAX_TRAIN_PER_CLASS, SEED)
print(f"[speed] Using {len(DF_TR_CAP)} capped train / {len(DF_VA)} val")

# Per-class counts in the capped training set; a class that is absent counts as 1
# so the divisions below stay finite.
counts = Counter(DF_TR_CAP["level"].tolist())
num_classes = 2 if BINARY_MODE else 5
class_counts = np.array([counts.get(i, 1) for i in range(num_classes)], dtype=np.float32)
# Inverse-frequency weights: rarer classes are drawn more often by the sampler.
class_weights_for_sampler = (class_counts.sum() / (class_counts+1e-6))
sample_weights = DF_TR_CAP["level"].map({i:class_weights_for_sampler[i] for i in range(num_classes)}).values
# Draw len(train) samples per epoch, with replacement, according to those weights.
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Class-balanced training loader and plain, ordered validation loader.
TR_LOADER = make_loader(DF_TR_CAP, BATCH_SIZE, shuffle=False, sampler=sampler)
VA_LOADER = make_loader(DF_VA,       BATCH_SIZE, shuffle=False)

# ------------------ IMPROVED AUTOENCODER ------------------
class ImprovedAE(nn.Module):
    """Convolutional autoencoder.

    Encoder: four stride-2 conv/BatchNorm/GELU stages (3 -> 64 -> 128 -> 256 -> 512
    channels), global average pooling and a linear projection to ``latent_dim``.
    Decoder: a linear layer reshaped to 512x14x14 followed by four stride-2
    transposed convolutions back to 3 channels. The 14x14 reshape matches the
    encoder's spatial size for 224x224 inputs.
    """

    def __init__(self, latent_dim=AE_LATENT_DIM):
        super().__init__()

        # ---- encoder ----
        widths = (3, 64, 128, 256, 512)
        encoder_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            encoder_layers.extend([
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ])
        self.enc_conv = nn.Sequential(*encoder_layers)  # 224 -> 112 -> 56 -> 28 -> 14
        self.enc_gap = nn.AdaptiveAvgPool2d(1)
        self.enc_fc = nn.Linear(512, latent_dim)

        # ---- decoder ----
        self.dec_fc = nn.Linear(latent_dim, 512 * 14 * 14)
        decoder_layers = []
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            decoder_layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ])
        # final stage has no normalisation or activation
        decoder_layers.append(
            nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1)
        )
        self.dec_deconv = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Map an image batch to latent vectors of size ``latent_dim``."""
        pooled = self.enc_gap(self.enc_conv(x)).flatten(1)
        return self.enc_fc(pooled)

    def decode(self, z):
        """Map latent vectors back to image space."""
        seed_map = self.dec_fc(z).view(-1, 512, 14, 14)
        return self.dec_deconv(seed_map)

    def forward(self, x):
        """Reconstruct ``x`` through the bottleneck."""
        return self.decode(self.encode(x))

def _denorm01(x):
    """Undo the MEAN/STD standardisation of ``x`` (on ``x``'s device)."""
    dev = x.device
    return (x * STD.to(dev)) + MEAN.to(dev)
def _renorm(x01):
    """Apply the MEAN/STD standardisation to ``x01`` (on ``x01``'s device)."""
    dev = x01.device
    return (x01 - MEAN.to(dev)) / STD.to(dev)

def ae_augment_tensor(x, cj=COLOR_JITTER_STRENGTH):
    """Randomly augment a normalised image batch ``x`` of shape (B, C, H, W).

    One set of random parameters is drawn per call and shared by the whole
    batch: horizontal flip (p=0.5), vertical flip (p=0.2), an affine transform
    (rotation within +/-10 degrees, scale in [0.95, 1.05], x-shear within +/-4
    degrees) and, when ``cj`` > 0, brightness / contrast / saturation jitter
    with factors in [1 - cj, 1 + cj]. Colour jitter is applied in de-normalised
    [0, 1] space and the result is re-normalised.
    """
    n_imgs, _, _, _ = x.shape
    dev = x.device

    def _draw(lo, hi):
        # one uniform sample as a Python float
        return float(torch.empty(1, device=dev).uniform_(lo, hi))

    out = x
    if torch.rand(1, device=dev) < 0.5:
        out = torch.flip(out, dims=[3])
    if torch.rand(1, device=dev) < 0.2:
        out = torch.flip(out, dims=[2])

    angle = _draw(-10.0, 10.0)
    scale = _draw(0.95, 1.05)
    shear = _draw(-4.0, 4.0)
    out = torch.stack(
        [
            TF.affine(out[k], angle=angle, translate=(0, 0), scale=scale, shear=[shear, 0.0],
                      interpolation=InterpolationMode.BILINEAR)
            for k in range(n_imgs)
        ],
        dim=0,
    )

    if cj and cj > 0:
        x01 = _denorm01(out).clamp(0, 1)
        brightness = 1.0 + _draw(-cj, cj)
        contrast = 1.0 + _draw(-cj, cj)
        saturation = 1.0 + _draw(-cj, cj)
        x01 = TF.adjust_brightness(x01, brightness)
        x01 = TF.adjust_contrast(x01, contrast)
        x01 = TF.adjust_saturation(x01, saturation)
        out = _renorm(x01.clamp(0, 1))
    return out

print(f"\n[AE] Training Improved Autoencoder for {AE_EPOCHS} epochs...")
# Start from an empty tensor cache; failure to delete it is deliberately ignored.
if Path(CACHE_DIR).exists():
    try:
        shutil.rmtree(CACHE_DIR)
        print(f"[info] Cleared tensor cache: {CACHE_DIR}")
    except Exception: pass

# The AE is trained on a plainly shuffled loader (no class-balancing sampler).
TR_LOADER_AE = make_loader(DF_TR_CAP, BATCH_SIZE, shuffle=True)
ae_model = ImprovedAE(latent_dim=AE_LATENT_DIM).to(device)
recon_crit = nn.MSELoss()
optimizer = torch.optim.AdamW(ae_model.parameters(), lr=AE_LR, weight_decay=AE_WEIGHT_DECAY)
# Cosine schedule over the whole run; max(1, ...) keeps T_max valid when AE_EPOCHS is 0.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, AE_EPOCHS))
# Mixed precision is only enabled on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))

ae_model.train()
# When AE_EPOCHS is 0 this loop body never runs and ae_model keeps its initial weights.
for ep in range(1, AE_EPOCHS+1):
    run = 0.0
    pbar = tqdm(TR_LOADER_AE, desc=f"AE {ep}/{AE_EPOCHS}")
    for xb, _ in pbar:
        xb = xb.to(device)
        # The input is augmented (and optionally noised) while the reconstruction
        # target below is the clean, un-augmented batch xb.
        x_in = ae_augment_tensor(xb)
        if AE_NOISE_STD > 0:
            x_in = (x_in + torch.randn_like(x_in)*AE_NOISE_STD).clamp(-5, 5)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            xh = ae_model(x_in); loss = recon_crit(xh, xb)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        run += loss.item(); pbar.set_postfix(loss=f"{loss.item():.4f}")
    # The learning rate is stepped once per epoch.
    scheduler.step()
    print(f"[AE] ep {ep}: recon={run/len(TR_LOADER_AE):.4f}")
ae_model.eval()

# Optional supervised head (frozen encoder)
class AESupHead(nn.Module):
    """Linear classification head on top of a frozen ``ImprovedAE`` encoder.

    The encoder's parameters are excluded from gradient updates and the encoder
    is kept in eval mode at all times, including after ``.train()``. Without
    that, calling ``.train()`` on this module would switch the encoder's
    BatchNorm layers to batch statistics and keep updating their running
    mean/variance, so the "frozen" encoder would still change.

    Args:
        encoder: trained (or initialised) autoencoder whose ``encode`` is used.
        num_classes: number of output logits.
        in_dim: size of the latent vector produced by ``encoder.encode``.
    """

    def __init__(self, encoder: ImprovedAE, num_classes=2 if BINARY_MODE else 5, in_dim=AE_LATENT_DIM):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()
        self.head = nn.Linear(in_dim, num_classes)

    def train(self, mode: bool = True):
        """Set the training mode of the head only; the encoder stays in eval mode."""
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        # No graph is built through the encoder; only the head receives gradients.
        with torch.no_grad():
            z = self.encoder.encode(x)
        return self.head(z)

# Optional: train only the linear head of AESupHead on the class-balanced TR_LOADER.
# Skipped entirely when AE_SUP_HEAD_EPOCHS is 0.
if AE_SUP_HEAD_EPOCHS > 0:
    print(f"[AE] Supervised head fine-tune ({AE_SUP_HEAD_EPOCHS} epochs; frozen encoder)")
    sup_model = AESupHead(ae_model).to(device)
    # Only the head's parameters are handed to the optimizer.
    opt_sup = torch.optim.AdamW(sup_model.head.parameters(), lr=AE_SUP_LR)
    ce = nn.CrossEntropyLoss()
    sup_model.train()
    for ep in range(1, AE_SUP_HEAD_EPOCHS+1):
        # run_loss accumulates the sample-weighted loss, seen the sample count,
        # cor the number of correct predictions.
        run_loss, seen, cor = 0.0, 0, 0
        for xb, yb in TR_LOADER:
            xb, yb = xb.to(device), torch.as_tensor(yb, device=device)
            logits = sup_model(xb); loss = ce(logits, yb)
            opt_sup.zero_grad(set_to_none=True); loss.backward(); opt_sup.step()
            run_loss += loss.item()*yb.size(0); seen += yb.size(0)
            cor += (logits.argmax(1)==yb).sum().item()
        # max(1, seen) guards against an empty loader.
        print(f"[AE-sup] ep {ep}: train_acc={cor/max(1,seen):.3f} loss={(run_loss/max(1,seen)):.4f}")
    sup_model.eval()

# ------------------ BACKBONES & LINEAR HEADS ------------------
def build_resnet50_feats():
    """Return an eval-mode ResNet50 feature extractor on ``device``.

    ImageNet (IMAGENET1K_V2) weights are requested; if obtaining them fails the
    network is randomly initialised. The classification layer is replaced by an
    identity so the model outputs pooled features.
    """
    try:
        net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    except Exception as e:
        print("[warn] ResNet50 weights download failed, using random init:", e)
        net = models.resnet50(weights=None)
    net.fc = nn.Identity()
    net.eval()
    return net.to(device)

def build_vit_b16_feats():
    """Return an eval-mode ViT-B/16 feature extractor on ``device``.

    ImageNet (IMAGENET1K_V1) weights are requested; if obtaining them fails the
    network is randomly initialised. The classification heads are replaced by
    an identity so the model outputs features instead of logits.
    """
    try:
        net = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
    except Exception as e:
        print("[warn] ViT-B/16 weights download failed, using random init:", e)
        net = models.vit_b_16(weights=None)
    net.heads = nn.Identity()
    net.eval()
    return net.to(device)

def build_ae_feats(trained_ae: ImprovedAE):
    """Wrap the encoder half of ``trained_ae`` as an eval-mode feature extractor
    on ``device``. The wrapper shares the autoencoder's encoder modules; it
    does not copy them."""

    class EncOnly(nn.Module):
        """Encoder-only view: conv stack -> global average pool -> linear."""

        def __init__(self, ae: ImprovedAE):
            super().__init__()
            self.enc_conv = ae.enc_conv
            self.enc_gap = ae.enc_gap
            self.enc_fc = ae.enc_fc

        def forward(self, x):
            pooled = self.enc_gap(self.enc_conv(x)).flatten(1)
            return self.enc_fc(pooled)

    extractor = EncOnly(trained_ae)
    return extractor.eval().to(device)

@torch.no_grad()
def extract_features(backbone, loader):
    """Run ``backbone`` over ``loader`` without gradients.

    Returns ``(features, labels)`` as NumPy arrays concatenated over batches.
    4-D backbone outputs are averaged over their last two dimensions first.
    """
    feature_chunks = []
    label_chunks = []
    for images, targets in tqdm(loader, desc="Extracting"):
        activations = backbone(images.to(device))
        if activations.dim() == 4:
            activations = activations.mean((2, 3))
        feature_chunks.append(activations.cpu().numpy())
        label_chunks.append(np.asarray(targets))
    return np.concatenate(feature_chunks), np.concatenate(label_chunks)

def fit_eval_linear(Xtr, ytr, Xva, yva, label):
    """Fit a standardiser plus class-balanced logistic regression on the training
    features and score it on the validation features.

    Prints and returns ``(accuracy, quadratic-weighted kappa, scaler, clf)``;
    the scaler is fitted on the training features only.
    """
    scaler = StandardScaler()
    train_scaled = scaler.fit_transform(Xtr)
    val_scaled = scaler.transform(Xva)

    clf = LogisticRegression(max_iter=600, class_weight="balanced")
    clf = clf.fit(train_scaled, ytr)

    val_pred = clf.predict(val_scaled)
    acc = accuracy_score(yva, val_pred)
    qwk = cohen_kappa_score(yva, val_pred, weights="quadratic")
    print(f"[{label}] ACC={acc:.4f} QWK={qwk:.4f}")
    return acc, qwk, scaler, clf

# Ordered, un-sampled loaders ("NS" = no sampler): every training image is visited
# exactly once, so the extracted features line up with the labels in ytr / yva.
TR_LOADER_NS = make_loader(DF_TR_CAP, BATCH_SIZE, shuffle=False)
VA_LOADER_NS = make_loader(DF_VA,       BATCH_SIZE, shuffle=False)
results = []

# --- Linear probe on ResNet50 features; this pass also yields ytr / yva.
print("\n[FAST] ResNet50 features…")
resnet = build_resnet50_feats()
Xtr_r, ytr = extract_features(resnet, TR_LOADER_NS)
Xva_r, yva = extract_features(resnet, VA_LOADER_NS)
acc_r, qwk_r, sc_r, clf_r = fit_eval_linear(Xtr_r, ytr, Xva_r, yva, "resnet50_linear")
results.append({"model":"resnet50_linear","val_acc":acc_r,"val_qwk":qwk_r})

# --- Linear probe on ViT-B/16 features; labels are reused because the loaders
# iterate in the same order.
print("\n[FAST] ViT-B/16 features…")
vit = build_vit_b16_feats()
Xtr_v, _ = extract_features(vit, TR_LOADER_NS)
Xva_v, _ = extract_features(vit, VA_LOADER_NS)
acc_v, qwk_v, sc_v, clf_v = fit_eval_linear(Xtr_v, ytr, Xva_v, yva, "vit_b16_linear")
results.append({"model":"vit_b16_linear","val_acc":acc_v,"val_qwk":qwk_v})

# --- Linear probe on the autoencoder's latent vectors.
print("\n[FAST] Improved AE encoder features…")
ae_backbone = build_ae_feats(ae_model)
Xtr_ae, _ = extract_features(ae_backbone, TR_LOADER_NS)
Xva_ae, _ = extract_features(ae_backbone, VA_LOADER_NS)
acc_ae, qwk_ae, sc_ae, clf_ae = fit_eval_linear(Xtr_ae, ytr, Xva_ae, yva, "autoencoder_linear")
results.append({"model":"autoencoder_linear","val_acc":acc_ae,"val_qwk":qwk_ae})

# ------------------ SUMMARY + PLOTS ------------------
# Rank models by validation QWK, breaking ties by accuracy (best first), and persist.
summary_df = pd.DataFrame(results).sort_values(["val_qwk","val_acc"], ascending=False).reset_index(drop=True)
summary_csv = Path(ROOT_OUT)/"summary_models.csv"
summary_df.to_csv(summary_csv, index=False)

# Bar chart of validation accuracy per model.
plt.figure(figsize=(7,4))
plt.bar(summary_df["model"], summary_df["val_acc"]); plt.ylim(0,1)
plt.ylabel("Val ACC"); plt.title("Accuracy by Model"); plt.grid(axis="y", alpha=0.3); plt.xticks(rotation=12)
plt.tight_layout(); plt.savefig(Path(ROOT_OUT)/"acc_by_model.png", dpi=140); plt.close()

# Bar chart of validation QWK per model.
# NOTE(review): kappa can be negative; such bars fall outside ylim(0,1) and are not visible.
plt.figure(figsize=(7,4))
plt.bar(summary_df["model"], summary_df["val_qwk"]); plt.ylim(0,1)
plt.ylabel("Val QWK"); plt.title("QWK by Model"); plt.grid(axis="y", alpha=0.3); plt.xticks(rotation=12)
plt.tight_layout(); plt.savefig(Path(ROOT_OUT)/"qwk_by_model.png", dpi=140); plt.close()

print("\n=== SUMMARY (higher is better) ===")
print(summary_df)

# ------------------ Fine-tune ResNet50 (end-to-end) ------------------
def build_train_pil_dataset(df):
    """Return a Dataset that loads the preprocessed images listed in
    ``df["proc_path"]`` with PIL, applies random training augmentations and
    yields (normalised tensor, int label) pairs at IMAGE_SIZE resolution."""

    class DRTrainPIL(Dataset):
        """Augmenting PIL-backed training dataset."""

        def __init__(self, df, size):
            self.df = df.reset_index(drop=True)
            self.size = size
            bilinear = InterpolationMode.BILINEAR
            augmentations = [
                # upscale slightly so the random crop has room to move
                T.Resize(int(size * 1.10), interpolation=bilinear),
                T.RandomResizedCrop(size, scale=(0.85, 1.0), ratio=(0.9, 1.1)),
                T.RandomHorizontalFlip(p=0.5),
                T.RandomVerticalFlip(p=0.2),
                T.RandomRotation(degrees=10, interpolation=bilinear, fill=0),
                T.ColorJitter(brightness=0.10, contrast=0.25, saturation=0.10, hue=0.00),
                T.RandomPerspective(distortion_scale=0.15, p=0.25, interpolation=bilinear),
                T.ToTensor(),
                T.Normalize(mean=MEAN.view(-1).tolist(), std=STD.view(-1).tolist()),
            ]
            self.tf = T.Compose(augmentations)

        def __len__(self):
            return len(self.df)

        def __getitem__(self, i):
            row = self.df.iloc[i]
            with Image.open(row.proc_path) as handle:
                rgb = ImageOps.exif_transpose(handle).convert("RGB")
                sample = self.tf(rgb)
            return sample, int(row.level)

    return DRTrainPIL(df, IMAGE_SIZE)

def make_train_aug_loader(df, bs, sampler=None):
    """Build the augmented training DataLoader for ``df``.

    Batches are shuffled unless a ``sampler`` is given. Two worker processes
    are used, and memory is pinned when running on CUDA.
    """
    dataset = build_train_pil_dataset(df)
    no_sampler = sampler is None
    on_gpu = device.type == "cuda"
    return DataLoader(
        dataset,
        batch_size=bs,
        shuffle=no_sampler,
        sampler=sampler,
        num_workers=2,
        pin_memory=on_gpu,
    )

# Augmented loader for fine-tuning: plain shuffling (no balancing sampler) with a
# batch size of 16; class imbalance is handled through the loss weights below.
TR_LOADER_AUG = make_train_aug_loader(DF_TR_CAP, bs=16, sampler=None)

# Inverse-frequency class weights, rescaled so that they sum to num_classes.
# A class missing from the capped training set is treated as having one sample.
freq = torch.tensor([counts.get(c, 1) for c in range(num_classes)], dtype=torch.float32)
class_weights = (1.0 / freq).to(device)
class_weights = class_weights * (num_classes / class_weights.sum())

class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-(1 - p_t) ** gamma * log(p_t)``, averaged over
    the batch. ``gamma = 0`` reduces to plain cross-entropy."""

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, target, weight=None):
        """Compute the loss.

        Args:
            logits: (N, C) raw class scores.
            target: (N,) integer class indices.
            weight: optional (C,) per-class weights; each sample's loss is
                multiplied by the weight of its target class before the mean
                (the mean is not renormalised by the weights).
        """
        log_probs = F.log_softmax(logits, dim=1)
        index = target.unsqueeze(1)
        pt = torch.exp(log_probs).gather(1, index).squeeze(1)
        logpt = log_probs.gather(1, index).squeeze(1)
        modulator = (1 - pt).clamp(0, 1) ** self.gamma
        per_sample = -modulator * logpt
        if weight is not None:
            per_sample = per_sample * weight[target]
        return per_sample.mean()

def build_resnet50_for_ft(num_classes=num_classes):
    """Return a ResNet50 on ``device`` whose final layer is a fresh linear layer
    with ``num_classes`` outputs. ImageNet (IMAGENET1K_V2) weights are requested
    for the backbone; if obtaining them fails, random initialisation is used."""
    try:
        net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    except Exception as e:
        print("[warn] ResNet50 weights download failed for FT, using random init:", e)
        net = models.resnet50(weights=None)
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, num_classes)
    return net.to(device)

@torch.no_grad()
def evaluate_with_tta(model, loader, tta=2):
    """Evaluate ``model`` on ``loader`` with flip-based test-time augmentation.

    Logits are averaged over ``tta`` views taken, in order, from: identity,
    horizontal flip, vertical flip, both flips (cycling if ``tta`` > 4).
    ``tta=1`` and ``tta=2`` behave as before. Previously every view other than
    the second was the unflipped image, so ``tta=4`` averaged three identical
    passes with one flipped pass. Values below 1 are treated as 1.

    Returns:
        (accuracy, quadratic-weighted Cohen's kappa) over the whole loader.
    """
    # dims passed to torch.flip for each view; None means "use the batch as is"
    flip_views = (None, [3], [2], [2, 3])
    n_views = max(1, int(tta))

    model.eval()
    all_y, all_p = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        logits_sum = 0
        for t in range(n_views):
            dims = flip_views[t % len(flip_views)]
            x_in = xb if dims is None else torch.flip(xb, dims=dims)
            logits_sum = logits_sum + model(x_in)
        logits = logits_sum / float(n_views)
        all_p.append(logits.argmax(1).cpu())
        all_y.append(torch.as_tensor(yb).cpu())
    y_true = torch.cat(all_y).numpy()
    y_pred = torch.cat(all_p).numpy()
    return accuracy_score(y_true, y_pred), cohen_kappa_score(y_true, y_pred, weights="quadratic")

def finetune_resnet50(epochs_stage1=FT_STAGE1_EPOCHS, epochs_stage2=FT_STAGE2_EPOCHS, base_lr=FT_BASE_LR, min_lr=1e-6):
    """Fine-tune a ResNet50 classifier in two stages and return the best checkpoint.

    Stage 1 trains only layer3, layer4 and fc at ``base_lr``; stage 2 unfreezes
    everything and trains at ``base_lr * 0.5`` with early stopping. Both stages
    use focal loss with the module-level ``class_weights``, iterate over
    ``TR_LOADER_AUG`` and are validated on ``VA_LOADER`` after every epoch. The
    weights with the highest validation QWK seen across both stages are restored
    at the end.

    Returns:
        (model, final_acc, final_qwk) where the metrics are re-computed on the
        validation loader after restoring the best weights.
    """
    model = build_resnet50_for_ft(num_classes=num_classes)

    # Stage 1: freeze everything, then re-enable the two deepest blocks and the head.
    for p in model.parameters(): p.requires_grad = False
    for p in model.layer3.parameters(): p.requires_grad = True
    for p in model.layer4.parameters(): p.requires_grad = True
    for p in model.fc.parameters():     p.requires_grad = True

    crit = FocalLoss(gamma=2.0)
    # Only the trainable parameters are given to the optimizer.
    opt  = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=base_lr, weight_decay=1e-4)
    # Stage-1 cosine decay bottoms out at 10% of base_lr (or min_lr if that is larger).
    sched= torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1,epochs_stage1), eta_min=max(min_lr, base_lr*0.1))

    # Best-so-far checkpoint, shared by both stages.
    best = {"qwk": -1, "state": None}
    for ep in range(1, epochs_stage1+1):
        model.train()
        for xb, yb in TR_LOADER_AUG:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb); loss = crit(logits, yb, weight=class_weights)
            loss.backward(); opt.step()
        sched.step()
        # Stage 1 is evaluated with tta=4.
        acc, qwk = evaluate_with_tta(model, VA_LOADER, tta=4)
        print(f"[FT-stage1] ep {ep}: acc={acc:.4f} qwk={qwk:.4f}")
        if qwk > best["qwk"]:
            # Clone the tensors so later training steps do not alter the snapshot.
            best = {"qwk": qwk, "state": {k: v.clone() for k, v in model.state_dict().items()}}

    # Stage 2: unfreeze the whole network with a fresh optimizer and schedule.
    for p in model.parameters(): p.requires_grad = True
    opt  = torch.optim.AdamW(model.parameters(), lr=base_lr*0.5, weight_decay=1e-4)
    sched= torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1,epochs_stage2), eta_min=min_lr)
    # Early stopping: give up after `patience` consecutive epochs without a new best QWK.
    patience = 4; no_improve = 0

    for ep in range(1, epochs_stage2+1):
        model.train()
        for xb, yb in TR_LOADER_AUG:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb); loss = crit(logits, yb, weight=class_weights)
            loss.backward(); opt.step()
        sched.step()
        # Stage 2 is evaluated with tta=2.
        acc, qwk = evaluate_with_tta(model, VA_LOADER, tta=2)
        print(f"[FT-stage2] ep {ep}: acc={acc:.4f} qwk={qwk:.4f}")
        if qwk > best["qwk"]:
            best = {"qwk": qwk, "state": {k: v.clone() for k, v in model.state_dict().items()}}; no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print("[early-stop] no QWK improvement"); break

    # Restore the best snapshot (if any epoch ran) and report its validation metrics.
    if best["state"] is not None: model.load_state_dict(best["state"])
    final_acc, final_qwk = evaluate_with_tta(model, VA_LOADER, tta=2)
    print(f"[FT-final] acc={final_acc:.4f} qwk={final_qwk:.4f}")
    return model, final_acc, final_qwk

print("\n[FT] Fine-tuning ResNet50 end-to-end …")
ft_model, ft_acc, ft_qwk = finetune_resnet50()
# Append the fine-tuned model to the summary, re-rank (QWK first, then accuracy)
# and overwrite the summary CSV written earlier.
summary_df = pd.concat([summary_df, pd.DataFrame([{"model":"resnet50_finetune","val_acc":ft_acc,"val_qwk":ft_qwk}])], ignore_index=True)
summary_df = summary_df.sort_values(["val_qwk","val_acc"], ascending=False).reset_index(drop=True)
summary_df.to_csv(Path(ROOT_OUT)/"summary_models.csv", index=False)
print("\n=== UPDATED SUMMARY (with fine-tune) ===")
print(summary_df)

# ==========================================================================================
# ===================== EXTRA RESULTS: CM + ROC + QUALITY + CONF + CASES + DIAGRAM + REDUCTION
# ==========================================================================================
from sklearn.metrics import confusion_matrix, roc_curve, auc
import itertools

# Output folder for the additional diagnostics (confusion matrices, ROC curves, ...).
EXTRA_DIR = Path(ROOT_OUT) / "extra_results"
EXTRA_DIR.mkdir(parents=True, exist_ok=True)

def save_confusion_matrix(y_true, y_pred, labels, title, out_png):
    """Plot the confusion matrix of ``y_pred`` against ``y_true`` (rows = true,
    columns = predicted, ordered by ``labels``), save it to ``out_png`` and
    return the matrix."""
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    plt.figure(figsize=(5.2, 4.6))
    plt.imshow(cm, interpolation="nearest")
    plt.title(title)
    plt.colorbar()

    positions = np.arange(len(labels))
    plt.xticks(positions, labels, rotation=20)
    plt.yticks(positions, labels)

    # Cells above 60% of the largest count get white text, the others black.
    peak = cm.max()
    thresh = peak * 0.6 if peak > 0 else 0
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            count = cm[row, col]
            text_color = "white" if count > thresh else "black"
            plt.text(col, row, format(count, "d"),
                     horizontalalignment="center",
                     color=text_color)

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()
    return cm

def save_roc_curve(y_true_bin, y_score, title, out_png):
    """Plot the ROC curve for binary labels ``y_true_bin`` and scores
    ``y_score`` (with a chance diagonal), save it to ``out_png`` and return
    the area under the curve."""
    fpr, tpr, _ = roc_curve(y_true_bin, y_score)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(5.6, 4.6))
    plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.3f}")
    plt.plot([0, 1], [0, 1], linestyle="--", label="Chance")

    plt.xlim(0, 1)
    plt.ylim(0, 1.02)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend(loc="lower right")
    plt.grid(alpha=0.25)

    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()
    return roc_auc

@torch.no_grad()
def predict_ft_probs(model, loader):
    """Run ``model`` in eval mode over ``loader`` (no test-time augmentation).

    Returns ``(probs, preds, labels)`` as NumPy arrays concatenated over
    batches: softmax probabilities, their argmax, and the loader's labels.
    """
    model.eval()
    prob_chunks, pred_chunks, label_chunks = [], [], []
    for images, targets in loader:
        batch_probs = torch.softmax(model(images.to(device)), dim=1)
        prob_chunks.append(batch_probs.detach().cpu().numpy())
        pred_chunks.append(torch.argmax(batch_probs, dim=1).cpu().numpy())
        label_chunks.append(np.asarray(targets))
    return (
        np.concatenate(prob_chunks),
        np.concatenate(pred_chunks),
        np.concatenate(label_chunks),
    )

def annotate_case(img_path, lines, out_path):
    """Save a copy of ``img_path`` resized to IMAGE_SIZE x IMAGE_SIZE with each
    string in ``lines`` drawn as white text on a black strip, stacked from the
    top in 18-pixel steps."""
    line_step = 18
    with Image.open(img_path) as handle:
        canvas = ImageOps.exif_transpose(handle).convert("RGB")
        canvas = canvas.resize((IMAGE_SIZE, IMAGE_SIZE))
        pen = ImageDraw.Draw(canvas)
        for row, text in enumerate(lines):
            top = 6 + row * line_step
            # black backing strip first, then the text on top of it
            pen.rectangle([2, top - 1, IMAGE_SIZE - 2, top + 16], fill=(0, 0, 0))
            pen.text((6, top), text, fill=(255, 255, 255))
        canvas.save(out_path, quality=95)
# (1) CONFUSION MATRICES (VAL)
# Class ids in display order: two classes in binary mode, otherwise levels 0-4.
labels_list = [0,1] if BINARY_MODE else list(range(5))

def preds_from_linear(Xva, scaler, clf):
    """Scale ``Xva`` with ``scaler`` and classify it with ``clf``.

    Returns ``(pred, proba)``; ``proba`` is ``None`` when ``clf`` has no
    ``predict_proba`` method.
    """
    scaled = scaler.transform(Xva)
    labels = clf.predict(scaled)
    if not hasattr(clf, "predict_proba"):
        return labels, None
    return labels, clf.predict_proba(scaled)

# Validation predictions of the three linear heads (labels plus class
# probabilities when the classifier exposes them).
pred_r, proba_r = preds_from_linear(Xva_r, sc_r, clf_r)
pred_v, proba_v = preds_from_linear(Xva_v, sc_v, clf_v)
pred_ae, proba_ae = preds_from_linear(Xva_ae, sc_ae, clf_ae)

# One confusion matrix per linear head, all scored against yva.
for _cm_pred, _cm_title, _cm_file in (
    (pred_r, "Confusion Matrix: ResNet50 Linear (Val)", "cm_resnet50_linear.png"),
    (pred_v, "Confusion Matrix: ViT-B/16 Linear (Val)", "cm_vit_b16_linear.png"),
    (pred_ae, "Confusion Matrix: Autoencoder Linear (Val)", "cm_autoencoder_linear.png"),
):
    _ = save_confusion_matrix(yva, _cm_pred, labels_list, _cm_title, EXTRA_DIR/_cm_file)

# The fine-tuned network is scored against the labels yielded by VA_LOADER.
ft_probs, ft_pred, ft_y = predict_ft_probs(ft_model, VA_LOADER)
_ = save_confusion_matrix(
    ft_y,
    ft_pred,
    labels_list,
    "Confusion Matrix: ResNet50 Fine-tuned (Val)",
    EXTRA_DIR/"cm_resnet50_finetune.png",
)

print("[extra] Confusion matrices saved to:", EXTRA_DIR)

# (2) ROC CURVES (binary)
if not BINARY_MODE:
    print("[extra] ROC curves skipped (BINARY_MODE=False).")
else:
    # Column 1 of each probability matrix is used as the positive-class score.
    y_true_bin = (yva == 1).astype(int)
    auc_r = save_roc_curve(
        y_true_bin, proba_r[:,1],
        "ROC: ResNet50 Linear (Val)", EXTRA_DIR/"roc_resnet50_linear.png",
    )
    auc_v = save_roc_curve(
        y_true_bin, proba_v[:,1],
        "ROC: ViT-B/16 Linear (Val)", EXTRA_DIR/"roc_vit_b16_linear.png",
    )
    auc_ae = save_roc_curve(
        y_true_bin, proba_ae[:,1],
        "ROC: Autoencoder Linear (Val)", EXTRA_DIR/"roc_autoencoder_linear.png",
    )
    auc_ft = save_roc_curve(
        (ft_y==1).astype(int), ft_probs[:,1],
        "ROC: ResNet50 Fine-tuned (Val)", EXTRA_DIR/"roc_resnet50_finetune.png",
    )
    print(f"[extra] AUCs: resnet_lin={auc_r:.3f} vit_lin={auc_v:.3f} ae_lin={auc_ae:.3f} resnet_ft={auc_ft:.3f}")

# (3) CONFIDENCE COMPARISON (VAL)
def max_conf_from_proba(proba):
    """Return the largest value of each row of ``proba`` (max over axis 1)."""
    per_row_peak = proba.max(axis=1)
    return per_row_peak

# Mean of the per-sample maximum probability for every model on validation.
_val_conf_rows = [
    ("resnet50_linear", float(max_conf_from_proba(proba_r).mean())),
    ("vit_b16_linear", float(max_conf_from_proba(proba_v).mean())),
    ("autoencoder_linear", float(max_conf_from_proba(proba_ae).mean())),
    ("resnet50_finetune", float(np.max(ft_probs, axis=1).mean())),
]
val_conf = pd.DataFrame(_val_conf_rows, columns=["model", "mean_max_conf"])
val_conf.to_csv(EXTRA_DIR/"val_confidence_comparison.csv", index=False)

# Bar chart of the same table; the y axis is fixed to the [0, 1] range.
plt.figure(figsize=(7,4))
plt.bar(val_conf["model"], val_conf["mean_max_conf"])
plt.title("Validation: Mean Confidence (Max Prob) by Model")
plt.ylabel("Mean max probability")
plt.ylim(0,1)
plt.xticks(rotation=12)
plt.grid(axis="y", alpha=0.25)
plt.tight_layout()
plt.savefig(EXTRA_DIR/"val_confidence_comparison.png", dpi=160)
plt.close()

# (4) QUALITY DISTRIBUTIONS (from user inference CSV if available)
# Reads the wide comparison CSV written by the batch-inference step and, when
# it carries a "quality" column, saves one box plot per quality metric
# (grouped by quality label) plus a bar chart of the label counts.
qual_csv = Path(ROOT_OUT)/"inference_user_comparison_wide.csv"
if not qual_csv.exists():
    print("[extra] Quality plots skipped: run user inference (upload images) to create inference_user_comparison_wide.csv")
else:
    qdf = pd.read_csv(qual_csv)

    if "quality" not in qdf.columns:
        # Every plot below groups by this column; without it there is nothing to draw.
        print("[extra] Quality plots skipped: no 'quality' column in", qual_csv.name)
    else:
        # Keep a fixed Good -> Usable -> Poor order, restricted to labels actually present.
        _present_quality = set(qdf["quality"].unique())
        labels = [q for q in ["Good","Usable","Poor"] if q in _present_quality]

        for col in ["blur_var","brightness","contrast","pct_underexposed","pct_overexposed"]:
            if col not in qdf.columns:
                continue
            groups = [qdf.loc[qdf["quality"]==q, col].dropna().values for q in labels]
            if len(groups) < 2:
                # A single group gives no comparison; skip before opening a
                # figure so no unclosed figure is left behind.
                continue
            plt.figure(figsize=(7,4))
            plt.boxplot(groups, labels=labels)
            plt.title(f"Quality Distribution: {col} (User Uploads)")
            plt.ylabel(col)
            plt.grid(axis="y", alpha=0.25)
            plt.tight_layout()
            plt.savefig(EXTRA_DIR/f"quality_boxplot_{col}.png", dpi=160)
            plt.close()

        counts_q = qdf["quality"].value_counts()
        plt.figure(figsize=(6,4))
        plt.bar(counts_q.index.astype(str), counts_q.values)
        plt.title("User Uploads: Quality Label Distribution")
        plt.ylabel("Count")
        plt.grid(axis="y", alpha=0.25)
        plt.tight_layout()
        plt.savefig(EXTRA_DIR/"quality_distribution_counts.png", dpi=160)
        plt.close()

# (5) CASE STUDIES (annotate 3 validation images)
# Picks up to three validation rows using the ResNet50 linear head, re-runs
# every model on each picked image, and writes an annotated JPEG per case.
case_dir = EXTRA_DIR/"case_studies"
case_dir.mkdir(parents=True, exist_ok=True)

# Score the validation features with the ResNet50 linear head.
DF_VA_cases = DF_VA.copy()
Xs = sc_r.transform(Xva_r)
va_pred = clf_r.predict(Xs)
# Falls back to a confidence of 1.0 when the classifier has no predict_proba.
va_conf = clf_r.predict_proba(Xs).max(axis=1) if hasattr(clf_r,"predict_proba") else np.ones(len(va_pred))

# NOTE(review): assumes the rows of DF_VA are in the same order as Xva_r — confirm.
DF_VA_cases["pred_resnet_linear"] = va_pred
DF_VA_cases["conf_resnet_linear"] = va_conf
DF_VA_cases["correct"] = (DF_VA_cases["pred_resnet_linear"].values == DF_VA_cases["level"].values)

# c1: most confident correct row; c2: most confident wrong row;
# c3: least confident row overall. Duplicated images are dropped.
c1 = DF_VA_cases[DF_VA_cases["correct"]].sort_values("conf_resnet_linear", ascending=False).head(1)
c2 = DF_VA_cases[~DF_VA_cases["correct"]].sort_values("conf_resnet_linear", ascending=False).head(1)
c3 = DF_VA_cases.sort_values("conf_resnet_linear", ascending=True).head(1)
picked = pd.concat([c1,c2,c3], ignore_index=True).drop_duplicates(subset=["proc_path"]).head(3)

for i, row in picked.iterrows():
    imgp = row["proc_path"]
    true_y = int(row["level"])
    cr = float(row["conf_resnet_linear"])

    # Load the image and normalise it with MEAN/STD into a 1 x 3 x H x W batch.
    with Image.open(imgp) as im:
        im = ImageOps.exif_transpose(im).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        arr = np.asarray(im, dtype=np.float32)/255.0
    x = torch.from_numpy(np.transpose(arr,(2,0,1))).unsqueeze(0).to(device)
    x = (x - MEAN.to(device)) / STD.to(device)

    # Features from the three backbones; 4-D outputs are averaged over axes 2 and 3.
    with torch.no_grad():
        fr = resnet(x); fv = vit(x); fa = ae_backbone(x)
        if fr.dim()==4: fr = fr.mean((2,3))
        if fv.dim()==4: fv = fv.mean((2,3))
        if fa.dim()==4: fa = fa.mean((2,3))

    # Linear-head predictions recomputed from the image itself.
    pr2 = int(clf_r.predict(sc_r.transform(fr.cpu().numpy()))[0])
    pv2 = int(clf_v.predict(sc_v.transform(fv.cpu().numpy()))[0])
    pa2 = int(clf_ae.predict(sc_ae.transform(fa.cpu().numpy()))[0])

    # Fine-tuned model prediction and its softmax confidence.
    with torch.no_grad():
        ft_log = ft_model(x)
        ft_p = torch.softmax(ft_log, dim=1).cpu().numpy()[0]
        pft = int(np.argmax(ft_p))
        cft = float(np.max(ft_p))

    outp = case_dir / f"case_{i+1}.jpg"
    # NOTE(review): `cr` comes from the cached validation features (Xva_r),
    # while `pr2` is recomputed from the image loaded above.
    lines = [
        f"TRUE={true_y} | ResNetLin={pr2} (conf~{cr:.2f})",
        f"ViTLin={pv2} | AELin={pa2}",
        f"ResNetFT={pft} (conf~{cft:.2f})",
    ]
    annotate_case(imgp, lines, outp)

# (6) WORKFLOW DIAGRAM (draw.io)
# Writes a static mxfile document: vertex cells 2-11 are the labelled pipeline
# stages, edge cells e1-e9 connect them in sequence (2 -> 3 -> ... -> 11).
diagram_path = EXTRA_DIR/"workflow_diagram.drawio"
drawio_xml = """<mxfile host="app.diagrams.net">
  <diagram name="DR Workflow">
    <mxGraphModel dx="1000" dy="700" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="1100" pageHeight="850">
      <root>
        <mxCell id="0"/><mxCell id="1" parent="0"/>
        <mxCell id="2" value="Input Data&#10;(train_subset.zip + trainLabels.csv)" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="40" y="80" width="240" height="70" as="geometry"/></mxCell>
        <mxCell id="3" value="Preprocess&#10;crop_black + autocontrast + equalize&#10;center-crop + resize" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="320" y="80" width="260" height="90" as="geometry"/></mxCell>
        <mxCell id="4" value="Stratified Split&#10;train/val" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="620" y="90" width="180" height="60" as="geometry"/></mxCell>
        <mxCell id="5" value="Feature Backbones&#10;ResNet50 (ImageNet)&#10;ViT-B/16 (ImageNet)&#10;Improved AE Encoder" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="40" y="240" width="280" height="110" as="geometry"/></mxCell>
        <mxCell id="6" value="Linear Head&#10;StandardScaler + LogisticRegression" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="360" y="250" width="260" height="80" as="geometry"/></mxCell>
        <mxCell id="7" value="Evaluation&#10;Accuracy + QWK&#10;Confusion Matrix + ROC" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="660" y="250" width="240" height="90" as="geometry"/></mxCell>
        <mxCell id="8" value="Fine-tune ResNet50&#10;(Stage 1: partial unfreeze&#10;Stage 2: full unfreeze + early stop)" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="40" y="410" width="320" height="100" as="geometry"/></mxCell>
        <mxCell id="9" value="User Inference&#10;3-model predictions + confidence" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="400" y="420" width="240" height="80" as="geometry"/></mxCell>
        <mxCell id="10" value="Quality Metrics&#10;blur_var + brightness + exposure%" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="680" y="420" width="240" height="80" as="geometry"/></mxCell>
        <mxCell id="11" value="Consensus + All-Model Risk&#10;Majority vote + unanimous check" style="rounded=1;whiteSpace=wrap;html=1;" vertex="1" parent="1"><mxGeometry x="940" y="420" width="240" height="80" as="geometry"/></mxCell>
        <mxCell id="e1" style="endArrow=block;html=1;" edge="1" parent="1" source="2" target="3"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e2" style="endArrow=block;html=1;" edge="1" parent="1" source="3" target="4"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e3" style="endArrow=block;html=1;" edge="1" parent="1" source="4" target="5"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e4" style="endArrow=block;html=1;" edge="1" parent="1" source="5" target="6"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e5" style="endArrow=block;html=1;" edge="1" parent="1" source="6" target="7"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e6" style="endArrow=block;html=1;" edge="1" parent="1" source="7" target="8"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e7" style="endArrow=block;html=1;" edge="1" parent="1" source="8" target="9"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e8" style="endArrow=block;html=1;" edge="1" parent="1" source="9" target="10"><mxGeometry relative="1" as="geometry"/></mxCell>
        <mxCell id="e9" style="endArrow=block;html=1;" edge="1" parent="1" source="10" target="11"><mxGeometry relative="1" as="geometry"/></mxCell>
      </root>
    </mxGraphModel>
  </diagram>
</mxfile>
"""
# NOTE(review): cell 11 sits at x=940 with width 240, which extends past the
# declared pageWidth of 1100.
with open(diagram_path, "w") as f:
    f.write(drawio_xml)

# (7) REDUCTION IMPACT (from user inference CSV)
# Fraction of rows whose risk column equals "Yes", for every single model and
# for the consensus / unanimous columns, saved as a CSV and as a bar chart.
if qual_csv.exists():
    qdf = pd.read_csv(qual_csv)
    model_names = ["resnet50_linear","vit_b16_linear","autoencoder_linear"]

    # (label shown in the table, CSV column holding the Yes/No flag)
    _risk_sources = [(m, f"risk@{m}") for m in model_names]
    _risk_sources.append(("consensus_risk", "consensus_risk"))
    _risk_sources.append(("all_models_risk", "all_models_risk"))
    rates = [
        (_label, float((qdf[_column] == "Yes").mean()))
        for _label, _column in _risk_sources
        if _column in qdf.columns
    ]

    red = pd.DataFrame(rates, columns=["system","fraction_flagged_dr"])
    red.to_csv(EXTRA_DIR/"reduction_impact.csv", index=False)

    plt.figure(figsize=(7.4,4))
    plt.bar(red["system"], red["fraction_flagged_dr"])
    plt.title("Reduction Impact: Single Models vs Consensus vs Unanimous")
    plt.ylabel("Fraction flagged DR")
    plt.ylim(0,1)
    plt.xticks(rotation=12)
    plt.grid(axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(EXTRA_DIR/"reduction_impact.png", dpi=160)
    plt.close()

# Copy extras into public folder
# Files in sub-folders are flattened: "/" in the relative path becomes "_".
for p in EXTRA_DIR.glob("**/*"):
    if not p.is_file():
        continue
    if p.parent == EXTRA_DIR:
        relname = p.name
    else:
        relname = str(p.relative_to(EXTRA_DIR)).replace("/", "_")
    shutil.copy2(str(p), str(Path(FILES_DIR)/relname))

print("[extra] Extra artifacts saved to:", EXTRA_DIR)
print("[extra] Extra artifacts copied to public folder:", FILES_DIR)
# ==========================================================================================

# ------------------ Predict helpers ------------------
def _to_gray_224(path, size=IMAGE_SIZE):
    """Load ``path`` as a ``size`` x ``size`` grayscale float32 array.

    The image is EXIF-transposed and converted to mode "L" before resizing.
    """
    with Image.open(path) as source:
        upright = ImageOps.exif_transpose(source)
        gray = upright.convert("L").resize((size, size))
        return np.asarray(gray, dtype=np.float32)
def _conv2(img, k):
    """Slide kernel ``k`` over the 2-D array ``img`` with reflect padding.

    The kernel is applied as-is (not flipped): output pixel (i, j) is the sum
    of ``k`` multiplied element-wise with the kernel-sized window of the padded
    image whose top-left corner is (i, j).

    Args:
        img: 2-D NumPy array.
        k: 2-D NumPy kernel.

    Returns:
        float32 array with the same shape as ``img``.
    """
    kh, kw = k.shape
    ph, pw = kh // 2, kw // 2
    pad = np.pad(img, ((ph, ph), (pw, pw)), mode="reflect")
    H, W = img.shape
    # Accumulate one shifted copy of the padded image per kernel tap. This is
    # kh*kw whole-array operations instead of a Python loop over every pixel.
    acc = np.zeros((H, W), dtype=np.result_type(pad.dtype, k.dtype))
    for di in range(kh):
        for dj in range(kw):
            acc += pad[di:di + H, dj:dj + W] * k[di, dj]
    return acc.astype(np.float32, copy=False)
def quality_metrics(path):
    """Compute simple image-quality statistics and a coarse quality label.

    Returns a dict with the keys ``quality`` ("Good", "Usable" or "Poor"),
    ``blur_var`` (variance of the Laplacian response), ``brightness`` (mean
    gray level), ``contrast`` (gray-level standard deviation),
    ``pct_underexposed`` (fraction of pixels below 25) and ``pct_overexposed``
    (fraction of pixels above 230). If the image cannot be loaded, a fixed
    "Poor" record is returned instead of raising.
    """
    try:
        gray = _to_gray_224(path)
    except Exception:
        return dict(quality="Poor", blur_var=0.0, brightness=0.0, contrast=0.0,
                    pct_underexposed=1.0, pct_overexposed=0.0)

    laplacian = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
    brightness = float(gray.mean())
    contrast = float(gray.std())
    blur_var = float(_conv2(gray, laplacian).var())
    pct_under = float((gray < 25).mean())
    pct_over = float((gray > 230).mean())

    def _fails(min_brightness, min_blur, max_clipped):
        # True when any statistic violates the given limits.
        return (
            brightness < min_brightness
            or blur_var < min_blur
            or pct_under > max_clipped
            or pct_over > max_clipped
        )

    if _fails(35, 50, 0.30):
        quality = "Poor"
    elif _fails(60, 120, 0.10):
        quality = "Usable"
    else:
        quality = "Good"

    return {
        "quality": quality,
        "blur_var": blur_var,
        "brightness": brightness,
        "contrast": contrast,
        "pct_underexposed": pct_under,
        "pct_overexposed": pct_over,
    }

@torch.no_grad()
def predict_with(backbone, scaler, clf, img_paths, model_name):
    """Classify each image with ``backbone`` features and a linear head.

    Every image is loaded, resized to IMAGE_SIZE, normalised with MEAN/STD and
    passed through ``backbone``; 4-D outputs are averaged over axes 2 and 3.
    The features are scaled by ``scaler`` and classified by ``clf``.

    Returns a DataFrame with one row per image holding the file name,
    ``model_name``, predicted level, risk flag, confidence (1.0 when ``clf``
    has no ``predict_proba``) and the quality metrics of the image.
    """
    quality_keys = ("quality", "blur_var", "brightness", "contrast",
                    "pct_underexposed", "pct_overexposed")
    records = []
    for img_path in img_paths:
        with Image.open(img_path) as im:
            rgb = ImageOps.exif_transpose(im).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
            pixels = np.asarray(rgb, dtype=np.float32) / 255.0
        batch = torch.from_numpy(np.transpose(pixels, (2, 0, 1))).unsqueeze(0).to(device)
        batch = (batch - MEAN.to(device)) / STD.to(device)

        feats = backbone(batch)
        if feats.dim() == 4:
            feats = feats.mean((2, 3))
        scaled = scaler.transform(feats.cpu().numpy())

        pred = int(clf.predict(scaled)[0])
        if hasattr(clf, "predict_proba"):
            conf = float(clf.predict_proba(scaled).max(axis=1)[0])
        else:
            conf = 1.0

        # Binary mode flags class 1 only; otherwise any level >= 1 counts as risk.
        if BINARY_MODE:
            flagged = pred == 1
        else:
            flagged = pred >= 1

        record = {
            "file": Path(img_path).name,
            "model_name": model_name,
            "pred_level": pred,
            "risk": "Yes" if flagged else "No",
            "confidence": conf,
        }
        quality = quality_metrics(img_path)
        for key in quality_keys:
            record[key] = quality[key]
        records.append(record)
    return pd.DataFrame(records)

def run_comparative_inference_on_paths(paths):
    """Run the three linear-head models on ``paths`` and save comparison artifacts.

    Keeps only existing files with a known image extension, predicts with the
    ResNet50, ViT-B/16 and autoencoder heads, and writes under ROOT_OUT a long
    CSV (one row per image and model), a wide CSV (one row per image), a CSV of
    the rows every model flagged as risk, and two bar charts.

    Returns:
        ``(df_long, df_wide)``, or ``(None, None)`` when no valid image path
        was given.
    """
    exts={".jpg",".jpeg",".png",".tif",".tiff",".bmp"}
    # Resolve to absolute paths, drop duplicates and sort for a stable order.
    norm=[str(Path(p).resolve()) for p in paths if Path(p).is_file() and Path(p).suffix.lower() in exts]
    norm = sorted(list(dict.fromkeys(norm)))
    if not norm:
        print("[inference] No valid image paths."); return None, None

    print(f"[inference] Predicting {len(norm)} images across 3 models...")
    dfs=[]
    dfs.append(predict_with(resnet, sc_r, clf_r, norm, "resnet50_linear"))
    dfs.append(predict_with(vit,    sc_v, clf_v, norm, "vit_b16_linear"))
    dfs.append(predict_with(ae_backbone, sc_ae, clf_ae, norm, "autoencoder_linear"))

    df_long = pd.concat(dfs, ignore_index=True)
    # Long -> wide: one column per model for level, risk and confidence.
    # NOTE(review): rows are keyed by the file's base name only, so two inputs
    # sharing a base name would collide in these pivots — confirm callers
    # never pass such paths.
    piv_pred=df_long.pivot(index="file", columns="model_name", values="pred_level")
    piv_risk=df_long.pivot(index="file", columns="model_name", values="risk")
    piv_conf=df_long.pivot(index="file", columns="model_name", values="confidence")
    piv_pred.columns=[f"pred_level@{c}" for c in piv_pred.columns]
    piv_risk.columns=[f"risk@{c}" for c in piv_risk.columns]
    piv_conf.columns=[f"confidence@{c}" for c in piv_conf.columns]
    # Quality metrics are taken from one row per file (the first after sorting).
    qcols=["quality","blur_var","brightness","contrast","pct_underexposed","pct_overexposed"]
    qual=(df_long.sort_values(["file","model_name"]).drop_duplicates("file")[["file"]+qcols]).set_index("file")
    df_wide=qual.join([piv_pred,piv_risk,piv_conf], how="left").reset_index()

    # Consensus = first column of the row-wise mode across the per-model columns.
    lvl_cols=[c for c in df_wide.columns if c.startswith("pred_level@")]
    risk_cols=[c for c in df_wide.columns if c.startswith("risk@")]
    df_wide["consensus_level"]=df_wide[lvl_cols].mode(axis=1)[0]
    df_wide["consensus_risk"]=df_wide[risk_cols].mode(axis=1)[0]

    model_names = ["resnet50_linear","vit_b16_linear","autoencoder_linear"]
    all_risk_cols = [f"risk@{m}" for m in model_names if f"risk@{m}" in df_wide.columns]
    def all_yes(row):
        # "Yes" only when every available model column is present, non-null and "Yes".
        vals = [row[c] for c in all_risk_cols if c in row and pd.notna(row[c])]
        return "Yes" if (len(vals) == len(all_risk_cols) and all(v=="Yes" for v in vals)) else "No"
    df_wide["all_models_risk"] = df_wide.apply(all_yes, axis=1)

    out_long=Path(ROOT_OUT)/"inference_user_predictions_long.csv"
    out_wide=Path(ROOT_OUT)/"inference_user_comparison_wide.csv"
    df_long.to_csv(out_long, index=False); df_wide.to_csv(out_wide, index=False)
    print("Saved:", out_long, "\nSaved:", out_wide)

    # Subset of images that all models flagged.
    out_allrisk = Path(ROOT_OUT)/"inference_only_all_models_risk_yes.csv"
    df_wide[df_wide["all_models_risk"]=="Yes"].to_csv(out_allrisk, index=False)
    print("Saved:", out_allrisk)

    # Plots
    # Mean confidence per model, bars sorted from highest to lowest.
    conf = df_long.groupby("model_name", dropna=False)["confidence"].mean().reset_index()
    plt.figure(figsize=(7,4)); order = conf.sort_values("confidence", ascending=False)
    plt.bar(order["model_name"], order["confidence"]); plt.ylim(0,1)
    plt.title("User Inference: Mean Confidence per Model"); plt.ylabel("Mean confidence"); plt.grid(axis="y", alpha=0.25); plt.xticks(rotation=12)
    plt.tight_layout(); plt.savefig(Path(ROOT_OUT)/"plot_user_confidence_per_model.png", dpi=140); plt.close()

    # Fraction of images each model flags (class 1 in binary mode, level >= 1 otherwise).
    tmp = df_long.copy()
    tmp["is_dr"] = (tmp["pred_level"]>=1).astype(float) if not BINARY_MODE else (tmp["pred_level"]==1).astype(float)
    dr_rate = tmp.groupby("model_name", dropna=False)["is_dr"].mean().reset_index()
    plt.figure(figsize=(7,4)); order = dr_rate.sort_values("is_dr", ascending=False)
    plt.bar(order["model_name"], order["is_dr"]); plt.ylim(0,1)
    ttl = "User Inference: % Images Flagged DR" if BINARY_MODE else "User Inference: % Images Level ≥ 1"
    plt.title(ttl); plt.ylabel("Fraction"); plt.grid(axis="y", alpha=0.25); plt.xticks(rotation=12)
    plt.tight_layout(); plt.savefig(Path(ROOT_OUT)/"plot_user_dr_rate_per_model.png", dpi=140); plt.close()

    return df_long, df_wide

# ------------------ Upload images & run batch inference ------------------
# Optional step: ask for uploads through the Colab file picker, store them
# under ROOT_OUT with an "uploaded_" prefix, run the three-model comparison and
# write a small self-contained HTML report (images embedded as data URIs).
print("\n=== Upload images for batch inference (optional) ===")
user_paths=[]
try:
    from google.colab import files
    uploads = files.upload()
    for name, data in uploads.items():
        dst = Path(ROOT_OUT)/f"uploaded_{name}"
        with open(dst,"wb") as f: f.write(data)
        user_paths.append(str(dst))
except Exception as _upload_err:
    # Best effort: outside Colab (or when the upload fails) the step is
    # skipped, but the reason is reported instead of being hidden.
    print("[info] Upload skipped.", f"({type(_upload_err).__name__}: {_upload_err})")

if not user_paths:
    print("[batch] No images uploaded; skipping batch inference/report.")
else:
    import mimetypes
    from html import escape as _html_escape

    df_long, df_wide = run_comparative_inference_on_paths(user_paths)

    def _img_to_base64(path):
        """Return ``path`` as a data URI, or "" when the file cannot be read."""
        try:
            with open(path, "rb") as f:
                payload = base64.b64encode(f.read()).decode("ascii")
        except OSError:
            return ""
        # Use the real media type (jpeg, png, ...) instead of always claiming PNG.
        mime = mimetypes.guess_type(str(path))[0] or "image/png"
        return f"data:{mime};base64,{payload}"

    if df_wide is None:
        # The inference helper returns (None, None) when no upload was a
        # supported image file; there is nothing to report in that case.
        print("[batch] No valid images among the uploads; skipping report.")
    else:
        fn2path = {Path(p).name: p for p in user_paths}
        secs=[]
        for _,row in df_wide.iterrows():
            fn=row["file"]; src=fn2path.get(fn,""); b64=_img_to_base64(src) if src else ""
            # File names come from the user, so escape everything placed into the markup.
            item=f"""
        <section style="border:1px solid #ddd;border-radius:10px;padding:12px;margin:12px;">
          <h3 style="margin:0 0 8px 0;">{_html_escape(str(fn))}</h3>
          <img src="{b64}" style="width:240px;border:1px solid #eee;border-radius:8px"/>
          <div>Consensus level: <b>{_html_escape(str(row['consensus_level']))}</b> |
               Consensus risk: <b>{_html_escape(str(row['consensus_risk']))}</b> |
               All models risk: <b>{_html_escape(str(row['all_models_risk']))}</b></div>
        </section>"""
            secs.append(item)
        html=f"<html><head><meta charset=\"utf-8\"></head><body><h2>DR Batch Report</h2>{''.join(secs)}</body></html>"
        # Explicit encoding so non-ASCII file names are written consistently.
        with open(Path(ROOT_OUT)/"report_xai.html","w",encoding="utf-8") as f: f.write(html)

# ------------------ Copy artifacts to static + download ------------------
def copy_artifacts_to_static():
    """Copy the known report artifacts from ROOT_OUT into FILES_DIR.

    Artifacts that do not exist are silently skipped. The names of the copied
    files are printed, and the destination paths are returned as strings in
    the order of the artifact list.
    """
    artifact_names = (
        "inference_user_predictions_long.csv",
        "inference_user_comparison_wide.csv",
        "inference_only_all_models_risk_yes.csv",
        "summary_models.csv",
        "acc_by_model.png",
        "qwk_by_model.png",
        "plot_user_confidence_per_model.png",
        "plot_user_dr_rate_per_model.png",
        "report_xai.html",
    )
    copied = []
    for artifact in artifact_names:
        source = Path(ROOT_OUT) / artifact
        if not source.exists():
            continue
        target = Path(FILES_DIR) / source.name
        shutil.copy2(str(source), str(target))
        copied.append(str(target))

    print("\nArtifacts in /api/files/:")
    for item in copied:
        print(" -", Path(item).name)
    return copied

# Publish the artifacts, then (when AUTO_DOWNLOAD is set) trigger a browser
# download for each copied file.
copied = copy_artifacts_to_static()
if AUTO_DOWNLOAD:
    try:
        from google.colab import files as _files
        for fp in copied:
            print("[download]", Path(fp).name)
            _files.download(fp)
            time.sleep(0.2)  # short pause between successive download requests
    except Exception as e:
        # Best effort: any failure (including a missing google.colab module)
        # only skips the remaining downloads.
        print("[download] skipped:", e)

# ------------------ Persist models for Cell 2 ------------------
# Everything Cell 2 loads is written into MODELS_DIR: the autoencoder weights,
# one scaler/classifier pair per linear head, the fine-tuned ResNet50
# checkpoint and a JSON file with the shared configuration.
_models_dir = Path(MODELS_DIR)

torch.save(ae_model.state_dict(), str(_models_dir/"ae_model.pt"))
for _joblib_name, _sk_object in (
    ("resnet_scaler.joblib", sc_r),
    ("resnet_clf.joblib", clf_r),
    ("vit_scaler.joblib", sc_v),
    ("vit_clf.joblib", clf_v),
    ("ae_scaler.joblib", sc_ae),
    ("ae_clf.joblib", clf_ae),
):
    joblib.dump(_sk_object, str(_models_dir/_joblib_name))

# ✅ Save fine-tuned ResNet50 checkpoint (fixes Cell 2 FileNotFoundError)
_ft_checkpoint = {
    "state_dict": ft_model.state_dict(),
    "arch": "resnet50",
    "num_classes": num_classes,
}
torch.save(_ft_checkpoint, str(_models_dir/"resnet50_ft.pt"))

# Save config so Cell 2 can load seamlessly
cfg = dict(
    BINARY_MODE=BINARY_MODE,
    IMAGE_SIZE=IMAGE_SIZE,
    ROOT_OUT=ROOT_OUT,
    FILES_DIR=FILES_DIR,
    MODELS_DIR=MODELS_DIR,
    MEAN=[0.485,0.456,0.406],
    STD=[0.229,0.224,0.225],
)
with open(_models_dir/"ui_config.json","w") as f:
    json.dump(cfg, f, indent=2)

print("\n✅ Cell 1 finished. Models & artifacts saved.")
print("➡️ Now run **Cell 2** to launch the Gradio UI.")

# ==========================================================================================
# ===================== EXTRA RESULTS: CM + ROC + CONF + CASES + DIAGRAM + REDUCTION
# ==========================================================================================
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
import itertools

# Output folder for the extra evaluation artifacts; created if it is missing.
EXTRA_DIR = Path(ROOT_OUT) / "extra_results"
EXTRA_DIR.mkdir(parents=True, exist_ok=True)

def save_confusion_matrix(y_true, y_pred, labels, title, out_png):
    """Compute a confusion matrix, render it to ``out_png`` and return it.

    ``labels`` fixes both the row/column order of the matrix and the tick
    labels of the plot; cell values are formatted as integers.
    """
    counts = confusion_matrix(y_true, y_pred, labels=labels)

    figure, axes = plt.subplots(figsize=(5.8, 5.0))
    display_obj = ConfusionMatrixDisplay(confusion_matrix=counts, display_labels=labels)
    display_obj.plot(ax=axes, cmap=None, colorbar=True, values_format="d")
    axes.set_title(title)
    plt.tight_layout()
    figure.savefig(out_png, dpi=180)
    plt.close(figure)
    return counts

def save_roc_curve_binary(y_true_bin, y_score, title, out_png):
    """Plot a binary ROC curve to ``out_png`` and return its AUC.

    ``y_true_bin`` and ``y_score`` are forwarded to ``roc_curve``; the plot
    also shows a dashed chance diagonal and the AUC in the legend.
    """
    false_pos, true_pos, _ = roc_curve(y_true_bin, y_score)
    area = auc(false_pos, true_pos)

    figure, axes = plt.subplots(figsize=(6.0, 4.8))
    axes.plot(false_pos, true_pos, label=f"AUC = {area:.3f}")
    axes.plot([0, 1], [0, 1], linestyle="--", label="Chance")
    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1.02)
    axes.set_title(title)
    axes.set_xlabel("False Positive Rate")
    axes.set_ylabel("True Positive Rate")
    axes.grid(alpha=0.25)
    axes.legend(loc="lower right")
    plt.tight_layout()
    figure.savefig(out_png, dpi=180)
    plt.close(figure)
    return area

def safe_predict_linear(Xva, scaler, clf):
    """Scale ``Xva`` and predict with ``clf``, tolerating missing probabilities.

    Returns ``(pred, proba)``. ``proba`` is ``None`` when ``clf`` has no
    ``predict_proba`` method or when calling it raises.
    """
    scaled = scaler.transform(Xva)
    labels = clf.predict(scaled)
    if not hasattr(clf, "predict_proba"):
        return labels, None
    try:
        return labels, clf.predict_proba(scaled)
    except Exception:
        # Probabilities are optional; callers check for None.
        return labels, None

@torch.no_grad()
def predict_ft_probs(model, loader):
    """Evaluate ``model`` on ``loader`` and gather probabilities, predictions, labels.

    The model is switched to eval mode. For every batch the logits are turned
    into softmax probabilities and the arg-max class. The three results are
    returned as concatenated NumPy arrays.
    """
    model.eval()
    collected = {"probs": [], "preds": [], "labels": []}
    for inputs, targets in loader:
        scores = torch.softmax(model(inputs.to(device)), dim=1)
        collected["probs"].append(scores.cpu().numpy())
        collected["preds"].append(scores.argmax(dim=1).cpu().numpy())
        collected["labels"].append(np.asarray(targets))
    return (
        np.concatenate(collected["probs"]),
        np.concatenate(collected["preds"]),
        np.concatenate(collected["labels"]),
    )

# -------- (1) Compute VAL predictions for all models --------
if BINARY_MODE:
    labels_list = [0,1]
else:
    labels_list = list(range(num_classes))

pred_r, proba_r = safe_predict_linear(Xva_r, sc_r, clf_r)
pred_v, proba_v = safe_predict_linear(Xva_v, sc_v, clf_v)
pred_ae, proba_ae = safe_predict_linear(Xva_ae, sc_ae, clf_ae)

# Fine-tuned predictions (always probabilities via softmax)
ft_probs, ft_pred, ft_y = predict_ft_probs(ft_model, VA_LOADER)

# -------- (2) Confusion matrices (always saved) --------
# Linear heads are scored against yva, the fine-tuned model against the
# labels yielded by VA_LOADER.
for _cm_true, _cm_pred, _cm_title, _cm_file in (
    (yva, pred_r, "Confusion Matrix — ResNet50 Linear (Val)", "cm_resnet50_linear.png"),
    (yva, pred_v, "Confusion Matrix — ViT-B/16 Linear (Val)", "cm_vit_b16_linear.png"),
    (yva, pred_ae, "Confusion Matrix — Autoencoder Linear (Val)", "cm_autoencoder_linear.png"),
    (ft_y, ft_pred, "Confusion Matrix — ResNet50 Fine-tuned (Val)", "cm_resnet50_finetune.png"),
):
    save_confusion_matrix(_cm_true, _cm_pred, labels_list, _cm_title, EXTRA_DIR / _cm_file)

print("[extra] ✅ Confusion matrices saved to:", EXTRA_DIR)

# -------- (3) ROC curves (binary only, robust checks) --------
if not BINARY_MODE:
    print("[extra] ROC skipped (BINARY_MODE=False).")
else:
    y_true_bin = (yva == 1).astype(int)

    # A linear head without probabilities gets no curve and an AUC of None.
    auc_r = (
        save_roc_curve_binary(y_true_bin, proba_r[:,1],
                              "ROC — ResNet50 Linear (Val)",
                              EXTRA_DIR / "roc_resnet50_linear.png")
        if proba_r is not None else None
    )
    auc_v = (
        save_roc_curve_binary(y_true_bin, proba_v[:,1],
                              "ROC — ViT-B/16 Linear (Val)",
                              EXTRA_DIR / "roc_vit_b16_linear.png")
        if proba_v is not None else None
    )
    auc_ae = (
        save_roc_curve_binary(y_true_bin, proba_ae[:,1],
                              "ROC — Autoencoder Linear (Val)",
                              EXTRA_DIR / "roc_autoencoder_linear.png")
        if proba_ae is not None else None
    )
    auc_ft = save_roc_curve_binary((ft_y==1).astype(int), ft_probs[:,1],
                                   "ROC — ResNet50 Fine-tuned (Val)",
                                   EXTRA_DIR / "roc_resnet50_finetune.png")

    print("[extra] ✅ ROC saved. AUCs:",
          {"resnet_lin": auc_r, "vit_lin": auc_v, "ae_lin": auc_ae, "resnet_ft": auc_ft})

# -------- (4) Confidence comparison --------
def mean_max_conf(proba):
    """Return the mean, as a float, of each row's maximum value in ``proba``."""
    row_peaks = np.max(proba, axis=1)
    return float(row_peaks.mean())

# Mean confidence per model; linear heads without probabilities are left out.
conf_rows = [
    (_name, mean_max_conf(_proba))
    for _name, _proba in (
        ("resnet50_linear", proba_r),
        ("vit_b16_linear", proba_v),
        ("autoencoder_linear", proba_ae),
    )
    if _proba is not None
]
conf_rows.append(("resnet50_finetune", float(np.max(ft_probs, axis=1).mean())))

val_conf = pd.DataFrame(conf_rows, columns=["model","mean_max_conf"])
val_conf.to_csv(EXTRA_DIR/"val_confidence_comparison.csv", index=False)

plt.figure(figsize=(7.2,4.2))
plt.bar(val_conf["model"], val_conf["mean_max_conf"])
plt.ylabel("Mean max probability")
plt.title("Validation: Mean Confidence (Max Prob)")
plt.ylim(0,1)
plt.xticks(rotation=12)
plt.grid(axis="y", alpha=0.25)
plt.tight_layout()
plt.savefig(EXTRA_DIR/"val_confidence_comparison.png", dpi=180)
plt.close()

# -------- (5) Always list what actually exists (so you can confirm) --------
print("\n[extra] ✅ Files created in extra_results:")
for _pattern in ("*.png", "*.csv"):
    for p in sorted(EXTRA_DIR.glob(_pattern)):
        print(" -", p.name)

# -------- (6) Copy extras into public folder (your /api/files/) --------
# Only files directly inside EXTRA_DIR are copied; sub-folders are not visited.
for p in EXTRA_DIR.glob("*"):
    if not p.is_file():
        continue
    shutil.copy2(str(p), str(Path(FILES_DIR)/p.name))

print("\n[extra] ✅ Copied extra_results to public folder:", FILES_DIR)
# ==========================================================================================


from IPython.display import Image as IPyImage, display

# Inline preview of two of the generated plots. The ROC image is only written
# in binary mode, so each file is checked first and missing files are reported
# instead of breaking the end of the cell.
for _preview_name in ("cm_resnet50_linear.png", "roc_resnet50_linear.png"):
    _preview_path = EXTRA_DIR/_preview_name
    if _preview_path.exists():
        display(IPyImage(str(_preview_path)))
    else:
        print("[extra] Preview skipped (file not generated):", _preview_name)


In [None]:
# === DR — CELL 2 (One-and-done UI): load, predict once, auto-terminate ===
import os, json, time, threading, sys
from pathlib import Path
import numpy as np
from PIL import Image, ImageOps
import torch
import torch.nn as nn
from torchvision import models
import joblib

# ---------- Load config & common ----------
# Cell 2 rebuilds its settings from the ui_config.json file found in MODELS_DIR.
MODELS_DIR = "/content/dr_one_cell/models"
with open(Path(MODELS_DIR)/"ui_config.json") as f:
    cfg = json.load(f)

ROOT_OUT     = cfg["ROOT_OUT"]
FILES_DIR    = cfg["FILES_DIR"]
BINARY_MODE  = bool(cfg["BINARY_MODE"])
IMAGE_SIZE   = int(cfg["IMAGE_SIZE"])
# Per-channel normalisation constants, reshaped to (3, 1, 1) so they broadcast
# over the trailing two axes of an image tensor.
MEAN         = torch.tensor(cfg["MEAN"]).view(3,1,1)
STD          = torch.tensor(cfg["STD"]).view(3,1,1)
# Use the GPU when one is available, otherwise the CPU.
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Backbones ----------
def build_resnet50_feats():
    """Build a ResNet50 feature extractor on ``device`` in eval mode.

    The classification layer is replaced by ``nn.Identity`` so the model
    returns the features that feed it. ImageNet (IMAGENET1K_V2) weights are
    requested; if they cannot be obtained the function still returns a model,
    but with random weights, and emits a ``RuntimeWarning``.
    """
    try:
        m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    except Exception as err:
        import warnings

        # Keep the best-effort fallback, but make it visible: features from an
        # untrained backbone are unlikely to suit the saved scaler/classifier.
        warnings.warn(
            f"ResNet50 pretrained weights unavailable ({err!r}); falling back to "
            "randomly initialised weights - ResNet50 linear-head predictions "
            "are likely unreliable.",
            RuntimeWarning,
        )
        m = models.resnet50(weights=None)
    m.fc = nn.Identity()
    m.eval()
    return m.to(device)

def build_vit_b16_feats():
    """Build a ViT-B/16 feature extractor on ``device`` in eval mode.

    The ``heads`` module is replaced by ``nn.Identity`` so the model returns
    the features that feed it. ImageNet (IMAGENET1K_V1) weights are requested;
    if they cannot be obtained the function still returns a model, but with
    random weights, and emits a ``RuntimeWarning``.
    """
    try:
        m = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
    except Exception as err:
        import warnings

        # Keep the best-effort fallback, but make it visible: features from an
        # untrained backbone are unlikely to suit the saved scaler/classifier.
        warnings.warn(
            f"ViT-B/16 pretrained weights unavailable ({err!r}); falling back to "
            "randomly initialised weights - ViT linear-head predictions are "
            "likely unreliable.",
            RuntimeWarning,
        )
        m = models.vit_b_16(weights=None)
    m.heads = nn.Identity()
    m.eval()
    return m.to(device)

# Match Cell 1 AE encoder exactly
class ImprovedAE(nn.Module):
    """Convolutional autoencoder whose ``forward`` returns only the latent code.

    The encoder is four stride-2 conv stages (Conv2d, BatchNorm2d, GELU) with
    widths 3 -> 64 -> 128 -> 256 -> 512, followed by global average pooling and
    a linear projection to ``latent_dim``. The decoder layers are built so the
    module's parameter names match a checkpoint that contains them, but
    ``forward`` does not use them.
    """

    def __init__(self, latent_dim=384):
        super().__init__()
        widths = (3, 64, 128, 256, 512)

        encoder_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            encoder_layers.extend((
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ))
        self.enc_conv = nn.Sequential(*encoder_layers)
        self.enc_gap = nn.AdaptiveAvgPool2d(1)
        self.enc_fc = nn.Linear(512, latent_dim)

        # decoder (present in checkpoint)
        self.dec_fc = nn.Linear(latent_dim, 512*14*14)
        decoder_layers = []
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            decoder_layers.extend((
                nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ))
        # Final stage maps back to 3 channels without norm/activation.
        decoder_layers.append(
            nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1)
        )
        self.dec_deconv = nn.Sequential(*decoder_layers)

    def forward(self, x):
        pooled = self.enc_gap(self.enc_conv(x)).flatten(1)
        return self.enc_fc(pooled)

def build_ae_feats_from_ckpt(ckpt_path, latent_dim=384):
    """Load the autoencoder checkpoint and return only its encoder.

    Args:
        ckpt_path: Path to a state dict that loads strictly into `ImprovedAE`.
        latent_dim: Size of the encoder's output vector.

    Returns:
        An eval-mode module on `device` that maps images to latent codes
        using the encoder layers of the loaded autoencoder.
    """

    class EncOnly(nn.Module):
        """Wrapper that keeps just the encoder layers of a loaded autoencoder."""

        def __init__(self, ae):
            super().__init__()
            self.enc_conv = ae.enc_conv
            self.enc_gap = ae.enc_gap
            self.enc_fc = ae.enc_fc

        def forward(self, x):
            pooled = self.enc_gap(self.enc_conv(x))
            return self.enc_fc(pooled.flatten(1))

    autoencoder = ImprovedAE(latent_dim=latent_dim)
    state = torch.load(ckpt_path, map_location="cpu")
    # strict=True: every key in the checkpoint must match the architecture.
    autoencoder.load_state_dict(state, strict=True)

    encoder = EncOnly(autoencoder)
    encoder.eval()
    return encoder.to(device)

# Feature extractors (linear heads)
# Backbones are built once at cell execution and reused by the predict helpers below.
resnet = build_resnet50_feats()
vit    = build_vit_b16_feats()
ae_backbone = build_ae_feats_from_ckpt(Path(MODELS_DIR)/"ae_model.pt", latent_dim=384)

# One (scaler, classifier) pair per backbone, loaded from MODELS_DIR.
# NOTE(review): joblib.load unpickles arbitrary objects; only point MODELS_DIR
# at artifacts written by this pipeline.
sc_r   = joblib.load(Path(MODELS_DIR)/"resnet_scaler.joblib")
clf_r  = joblib.load(Path(MODELS_DIR)/"resnet_clf.joblib")
sc_v   = joblib.load(Path(MODELS_DIR)/"vit_scaler.joblib")
clf_v  = joblib.load(Path(MODELS_DIR)/"vit_clf.joblib")
sc_ae  = joblib.load(Path(MODELS_DIR)/"ae_scaler.joblib")
clf_ae = joblib.load(Path(MODELS_DIR)/"ae_clf.joblib")

# Fine-tuned models (OPTIONAL; load if present)
def load_ft_model(ckpt_path: Path):
    """Rebuild a fine-tuned classifier from a checkpoint file.

    The checkpoint is read as a mapping with an ``"arch"`` tag
    (``"resnet50"`` or ``"vit_b16"``), an optional ``"num_classes"``
    (defaults to 2) and a ``"state_dict"`` that is loaded strictly.

    Args:
        ckpt_path: Location of the checkpoint.

    Returns:
        The model in eval mode, moved to `device`.

    Raises:
        ValueError: If the checkpoint's arch tag is not recognised.
    """
    payload = torch.load(ckpt_path, map_location="cpu")
    arch_name = payload.get("arch", "")
    n_out = int(payload.get("num_classes", 2))

    # Reject unknown architectures before constructing anything.
    if arch_name not in ("resnet50", "vit_b16"):
        raise ValueError(f"Unknown arch in ckpt: {arch_name}")

    if arch_name == "resnet50":
        net = models.resnet50(weights=None)
        net.fc = nn.Linear(net.fc.in_features, n_out)
    else:
        net = models.vit_b_16(weights=None)
        net.heads.head = nn.Linear(net.heads.head.in_features, n_out)

    net.load_state_dict(payload["state_dict"], strict=True)
    net.eval()
    return net.to(device)

# Each fine-tuned model is loaded only when its checkpoint file exists in
# MODELS_DIR; otherwise the name is bound to None and that model is skipped.
resnet_ft = load_ft_model(Path(MODELS_DIR)/"resnet50_ft.pt") if Path(MODELS_DIR, "resnet50_ft.pt").exists() else None
vit_ft    = load_ft_model(Path(MODELS_DIR)/"vit_b16_ft.pt")  if Path(MODELS_DIR, "vit_b16_ft.pt").exists()  else None

# ---------- Prediction ----------
def _norm(x):
    """Standardise `x` with the module-level MEAN/STD, moved to `x`'s device."""
    target = x.device
    centered = x - MEAN.to(target)
    return centered / STD.to(target)

@torch.no_grad()
def _predict_with_linear_head(backbone, scaler, clf, x):
    """Run one backbone + scaler + classifier and summarise the first sample.

    Args:
        backbone: Module mapping the image batch to features.
        scaler: Fitted transformer exposing `transform`.
        clf: Fitted classifier exposing `predict` and `predict_proba`.
        x: Normalised image batch; only the first row of the result is used.

    Returns:
        Dict with `pred_level` (int), `risk` ("Yes"/"No") and `confidence`
        (the largest predicted class probability, as a float).
    """
    feats = backbone(x)
    # Collapse any remaining spatial dims if the backbone returns a 4-D map.
    if feats.dim() == 4:
        feats = feats.mean((2, 3))
    # Scale once and reuse for both the label and the probabilities.
    scaled = scaler.transform(feats.cpu().numpy())
    level = int(clf.predict(scaled)[0])
    confidence = float(clf.predict_proba(scaled).max(axis=1)[0])
    at_risk = (level >= 1) if not BINARY_MODE else (level == 1)
    return {"pred_level": level, "risk": "Yes" if at_risk else "No", "confidence": confidence}


@torch.no_grad()
def _predict_linear_backbones(x):
    """Predict with the three linear-head models (ResNet-50, ViT-B/16, autoencoder).

    Args:
        x: Normalised image batch; predictions are reported for its first sample.

    Returns:
        Dict keyed by model name (`resnet50_linear`, `vit_b16_linear`,
        `autoencoder_linear`, in that order), each mapping to a dict with
        `pred_level`, `risk` and `confidence`.
    """
    heads = (
        ("resnet50_linear", resnet, sc_r, clf_r),
        ("vit_b16_linear", vit, sc_v, clf_v),
        ("autoencoder_linear", ae_backbone, sc_ae, clf_ae),
    )
    return {
        name: _predict_with_linear_head(backbone, scaler, clf, x)
        for name, backbone, scaler, clf in heads
    }

@torch.no_grad()
def _predict_ft(model, x, tta=3, name="model_ft"):
    if model is None:
        return {}
    logits_sum = 0
    for t in range(tta):
        xi = x
        if t % 2 == 1: xi = torch.flip(xi, dims=[3])
        if t % 4 == 3: xi = torch.flip(xi, dims=[2])
        logits_sum = logits_sum + model(xi)
    logits = logits_sum / float(tta)
    probs = torch.softmax(logits, dim=1)
    conf, pred = probs.max(dim=1)
    pred = int(pred.item()); conf = float(conf.item())
    risk = "Yes" if (pred>=1 if not BINARY_MODE else pred==1) else "No"
    return {name: {"pred_level": pred, "risk": risk, "confidence": conf}}

@torch.no_grad()
def predict_one_pil(pil_img):
    """Run every available model on one PIL image and build a consensus.

    The image is EXIF-transposed, converted to RGB, resized to
    IMAGE_SIZE x IMAGE_SIZE, scaled to [0, 1], laid out channel-first and
    normalised before being passed to the models.

    Args:
        pil_img: Input PIL image.

    Returns:
        Dict with `consensus_level` and `consensus_risk` (most common value
        across models), `all_models_risk` ("Yes" only if every model says
        "Yes") and `per_model` (the individual predictions).
    """
    from collections import Counter

    # ---- preprocessing ----
    upright = ImageOps.exif_transpose(pil_img)
    resized = upright.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    channel_first = np.transpose(pixels, (2, 0, 1))
    batch = torch.from_numpy(channel_first).unsqueeze(0).to(device)
    batch = _norm(batch)

    # ---- per-model predictions (fine-tuned models only when loaded) ----
    per_model = {}
    per_model.update(_predict_linear_backbones(batch))
    if resnet_ft:
        per_model.update(_predict_ft(resnet_ft, batch, tta=3, name="resnet50_finetune"))
    if vit_ft:
        per_model.update(_predict_ft(vit_ft, batch, tta=3, name="vit_b16_finetune"))

    # ---- consensus ----
    level_votes = [entry["pred_level"] for entry in per_model.values()]
    risk_votes = [entry["risk"] for entry in per_model.values()]
    top_level = Counter(level_votes).most_common(1)[0][0]
    top_risk = Counter(risk_votes).most_common(1)[0][0]
    unanimous_yes = all(vote == "Yes" for vote in risk_votes)

    return {
        "consensus_level": int(top_level),
        "consensus_risk": top_risk,
        "all_models_risk": "Yes" if unanimous_yes else "No",
        "per_model": per_model,
    }

# ---------- Pretty formatting ----------
def preds_to_markdown(result: dict) -> str:
    """Render a prediction result as a Markdown report.

    Args:
        result: Dict with `consensus_level`, `consensus_risk`,
            `all_models_risk` and (optionally) `per_model`, a mapping of
            model name to `pred_level` / `risk` / `confidence`.

    Returns:
        Markdown with a consensus summary, a per-model table and, when a
        fine-tuned model was not loaded, a note listing what is missing.
    """
    pm = result.get("per_model", {})
    lines = [
        "### Consensus\n",
        f"- **Level**: `{result['consensus_level']}`",
        f"- **Referable DR risk**: **{result['consensus_risk']}**",
        # `all_models_risk` is "Yes" only when every model predicts risk, so a
        # unanimous "No" also yields "No". The previous label ("All models
        # agree on risk") made that case read as disagreement.
        f"- **All models flag risk**: **{result['all_models_risk']}**\n",
        "### Per-model predictions",
        "| Model | Level | Risk | Confidence |",
        "|---|---:|:---:|---:|",
    ]
    lines.extend(
        f"| `{name}` | {v['pred_level']} | {v['risk']} | {v['confidence']:.3f} |"
        for name, v in pm.items()
    )
    # Note if some FT models are missing
    optional_models = (("resnet50_finetune", resnet_ft), ("vit_b16_finetune", vit_ft))
    missing = [label for label, model in optional_models if model is None]
    if missing:
        lines.append(f"\n> _Note: Fine-tuned checkpoints not found for: {', '.join(missing)}. Showing available models only._")
    return "\n".join(lines)

# ---------- Gradio (auto-terminate after first prediction) ----------
# Import gradio; if that fails, pip-install a pinned version and import again.
try:
    import gradio
except Exception:
    import sys, subprocess
    # NOTE(review): the pip exit status is not checked; a failed install only
    # shows up as an import error on the `import gradio as gr` line below.
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "gradio==4.44.0"])
import gradio as gr

# When True, ui_fn schedules _graceful_exit after the first prediction.
AUTO_CLOSE_AFTER_PRED = True
# When True, _graceful_exit finishes with os._exit(0) as a last resort.
HARD_EXIT_FALLBACK    = True
# Mutable flag so ui_fn schedules the shutdown timer at most once.
_shutdown_started = {"done": False}
# Rebound to the Interface once it is created; _graceful_exit closes it.
demo = None

def _graceful_exit():
    """Shut the session down in escalating, best-effort steps.

    Order: close the Gradio server, ask Colab to unassign the runtime, shut
    down the IPython kernel, then (if HARD_EXIT_FALLBACK) call os._exit(0).
    A failure while closing the server is printed; failures in the later
    steps are swallowed so the remaining steps still get a chance to run.
    """
    # Step 1: stop the Gradio server if one was registered in `demo`.
    try:
        if demo is not None:
            demo.close()
            print("[gradio] Server closed.")
    except Exception as e:
        print("[gradio] close() failed:", e)
    # Flush before each drastic step so the log lines are not lost.
    sys.stdout.flush()
    time.sleep(0.5)
    # Step 2: Colab-only; the import fails elsewhere and the step is skipped.
    try:
        from google.colab import runtime
        print("[exit] Releasing Colab VM...")
        sys.stdout.flush()
        runtime.unassign()
    except Exception:
        pass
    # Step 3: stop the IPython kernel when running inside one.
    try:
        import IPython
        ip = IPython.get_ipython()
        if ip and hasattr(ip, "kernel"):
            print("[exit] Stopping IPython kernel.")
            sys.stdout.flush()
            ip.kernel.do_shutdown(restart=False)
    except Exception:
        pass
    # Step 4: last resort; os._exit terminates the process immediately.
    if HARD_EXIT_FALLBACK:
        try:
            print("[exit] Hard-exit fallback.")
            sys.stdout.flush()
            os._exit(0)
        except Exception:
            pass

def ui_fn(img):
    """Gradio handler: predict on the uploaded image and return Markdown.

    After the first prediction (when AUTO_CLOSE_AFTER_PRED is set) a timer is
    started that calls `_graceful_exit` 1.5 seconds later, leaving time for
    the response to be returned first.

    Args:
        img: PIL image from the UI, or None when nothing was uploaded.

    Returns:
        The Markdown report, or a short message when no image was given.
    """
    if img is None:
        return "No image provided."

    report = preds_to_markdown(predict_one_pil(img))

    if AUTO_CLOSE_AFTER_PRED:
        if not _shutdown_started["done"]:
            # Mark first so later requests do not schedule a second timer.
            _shutdown_started["done"] = True
            shutdown_timer = threading.Timer(1.5, _graceful_exit)
            shutdown_timer.start()
    return report

# Single-image UI: PIL image in, Markdown report out, flagging disabled.
iface = gr.Interface(
    fn=ui_fn,
    inputs=gr.Image(type="pil", label="Upload retina image"),
    outputs=gr.Markdown(label="Prediction"),
    title="DR Predictor — One-and-done UI",
    allow_flagging="never",
)

# serve static artifacts from Cell 1 (optional)
# Best-effort: any failure here is printed and the UI still launches.
# NOTE(review): this runs before launch(); if `iface.app` is not available
# yet, the except branch fires and nothing is mounted — confirm the mount
# actually takes effect.
try:
    from starlette.staticfiles import StaticFiles
    iface.app.mount("/api/files", StaticFiles(directory=FILES_DIR), name="files")
    print(f"[STATIC] Mounted {FILES_DIR} at /api/files")
except Exception as e:
    print("[STATIC] Mount skipped:", e)

# Publish the interface under `demo` so _graceful_exit can close it.
demo = iface
iface.launch(share=True, debug=True, prevent_thread_lock=True)
print("Ready. After your first prediction, this runtime will shut down automatically.")
