# HAM10000 Reliability Audit

This notebook reproduces all required analyses in a fresh, **independent** implementation to avoid overlap with the first version. Key differences:

- **Model**: `ResNet50` head (vs DenseNet-121 earlier)
- **Loss & regularization**: Cross-Entropy with **label smoothing** + optional **MixUp**
- **Scheduler**: cosine annealing
- **Stress tests**: same families, re-coded utilities & severity mapping
- **Calibration & uncertainty**: refactored implementations
- **Case studies**: new Grad-CAM and selection logic

> Set `DATA_ROOT` below to your HAM10000 folder (must contain `HAM10000_metadata.csv` and the image subfolders).


In [None]:

# %% Setup
# %pip install --upgrade pip
# %pip install torch torchvision numpy pandas scikit-learn Pillow matplotlib pyyaml tqdm opencv-python

import os, math, json, time, random, hashlib, io
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance, ImageFilter
import matplotlib.pyplot as plt

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, precision_recall_curve

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

SEED = 1234
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True

DATA_ROOT = Path("/share_2/users/umair_nawaz/Mine/CV8502/A1/dataset/HAM10000")  # <-- EDIT
OUT_ROOT = Path("Assignment-Output"); OUT_ROOT.mkdir(parents=True, exist_ok=True)

CLASSES = ["akiec","bcc","bkl","df","mel","nv","vasc"]
CLASS2IDX = {c:i for i,c in enumerate(CLASSES)}
IMG_SIZE, BATCH_SIZE, EPOCHS = 224, 128, 10
LR, WEIGHT_DECAY = 1e-4, 1e-4
LABEL_SMOOTH = 0.1
USE_MIXUP, MIXUP_ALPHA = True, 0.2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

RUN_ID = "Assignment-01-Output"
RUN_DIR = OUT_ROOT / RUN_ID
(RUN_DIR / "figs").mkdir(parents=True, exist_ok=True)
(RUN_DIR / "case_studies").mkdir(parents=True, exist_ok=True)
print("Run:", RUN_ID)


Run: Assignment-01-Output


In [2]:

# %% Data loading and split
CSV = DATA_ROOT / "HAM10000_metadata.csv"
assert CSV.exists(), f"Metadata not found: {CSV}"
meta = pd.read_csv(CSV)
assert "image_id" in meta.columns and "dx" in meta.columns

IMG_DIRS = [DATA_ROOT/"HAM10000_images_part_1", DATA_ROOT/"HAM10000_images_part_2"]
def find_path(image_id):
    name = f"{image_id}.jpg"
    for d in IMG_DIRS:
        p = d / name
        if p.exists(): return p
    return None

meta["path"] = meta["image_id"].apply(find_path)
meta = meta[meta["path"].notnull()].copy()
meta = meta[meta["dx"].isin(CLASSES)].copy().reset_index(drop=True)
meta["y"] = meta["dx"].map(CLASS2IDX)

def image_stats(p):
    try:
        im = Image.open(p)
        w,h = im.size
        arr = np.asarray(im.convert("L"), dtype=np.float32)/255.0
        return pd.Series({"w": w, "h": h, "bright": float(arr.mean())})
    except:
        return pd.Series({"w": np.nan, "h": np.nan, "bright": np.nan})

stats = meta["path"].apply(image_stats)
meta = pd.concat([meta, stats], axis=1).dropna().reset_index(drop=True)

from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
idx_trv, idx_te = next(sss.split(meta, meta["y"]))
trainval, test = meta.iloc[idx_trv].reset_index(drop=True), meta.iloc[idx_te].reset_index(drop=True)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.125, random_state=SEED)  # -> 10% val
idx_tr, idx_va = next(sss2.split(trainval, trainval["y"]))
train, val = trainval.iloc[idx_tr].reset_index(drop=True), trainval.iloc[idx_va].reset_index(drop=True)

print("Counts:", len(train), len(val), len(test))

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7,1.0), ratio=(0.8,1.25)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2,0.2,0.2,0.1),
    transforms.ToTensor(),
])
eval_tf = transforms.Compose([
    transforms.Resize(int(IMG_SIZE*1.15)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
])

class HamDataset(Dataset):
    def __init__(self, df, tfm): self.df, self.tfm = df.reset_index(drop=True), tfm
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        x = self.tfm(Image.open(r.path).convert("RGB"))
        return x, r.y, r.image_id

ds_tr, ds_va, ds_te = HamDataset(train, train_tf), HamDataset(val, eval_tf), HamDataset(test, eval_tf)
dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
dl_te = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)


Counts: 6970 996 1992


In [3]:

# %% MixUp helpers
def mixup_data(x, y, alpha=0.2):
    if alpha <= 0: return x, y, 1.0, y
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0), device=x.device)
    return lam*x + (1-lam)*x[idx], y, lam, y[idx]

def mixup_loss(crit, pred, y_a, y_b, lam):
    return lam * crit(pred, y_a) + (1-lam) * crit(pred, y_b)


In [4]:

# %% Model + training
def build_model(num_classes=len(CLASSES)):
    m = models.resnet50(weights=None)
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

model = build_model().to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
crit = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)

# def macro_metrics(logits, y_true):
#     P = logits.softmax(1).numpy(); y = y_true.numpy()
#     f1 = f1_score(y, P.argmax(1), average="macro")
#     aurocs, auprcs = [], []
#     for k in range(len(CLASSES)):
#         yk = (y == k).astype(int); pk = P[:,k]
#         try: aurocs.append(roc_auc_score(yk, pk))
#         except: pass
#         pr, rc, _ = precision_recall_curve(yk, pk)
#         auprcs.append(np.trapz(pr[::-1], rc[::-1]))
#     AUROC = float(np.mean(aurocs)) if len(aurocs)>0 else float("nan")
#     AUPRC = float(np.mean(auprcs)) if len(auprcs)>0 else float("nan")
#     # Sens@95%Spec over max prob, focusing on mel vs rest
#     maxp = P.max(1)
#     labels = (y == CLASS2IDX["mel"]).astype(int)
#     neg = (labels == 0)
#     thr = 1.0
#     if neg.sum()>0:
#         cand = np.sort(maxp[neg]); idx = max(0, min(len(cand)-1, int(np.floor(0.95*len(cand)))-1))
#         thr = cand[idx]
#     sens = float((maxp[labels==1] >= thr).mean()) if (labels==1).sum()>0 else 0.0
#     return {"AUROC": AUROC, "AUPRC": AUPRC, "F1_macro": float(f1), "Sens@95Spec": sens}

def macro_metrics(logits, y_true):
    import numpy as np, torch
    # Coerce to torch tensors if needed
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)
    if isinstance(y_true, np.ndarray):
        y_true = torch.from_numpy(y_true)

    P = logits.softmax(1).cpu().numpy()
    y = y_true.cpu().numpy()

    from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve
    f1 = f1_score(y, P.argmax(1), average="macro")
    aurocs, auprcs = [], []
    for k in range(len(CLASSES)):
        yk = (y == k).astype(int); pk = P[:, k]
        try:
            aurocs.append(roc_auc_score(yk, pk))
        except Exception:
            pass
        prec, rec, _ = precision_recall_curve(yk, pk)
        auprcs.append(np.trapezoid(prec[::-1], rec[::-1]))
    AUROC = float(np.mean(aurocs)) if len(aurocs) else float("nan")
    AUPRC = float(np.mean(auprcs)) if len(auprcs) else float("nan")

    # Sens@95%Spec over max prob (mel vs. rest as safety proxy)
    maxp = P.max(1)
    labels = (y == CLASS2IDX["mel"]).astype(int)
    neg = (labels == 0)
    thr = 1.0
    if neg.sum() > 0:
        cand = np.sort(maxp[neg])
        idx = max(0, min(len(cand) - 1, int(np.floor(0.95 * len(cand))) - 1))
        thr = cand[idx]
    sens = float((maxp[labels == 1] >= thr).mean()) if (labels == 1).sum() > 0 else 0.0

    return {"AUROC": AUROC, "AUPRC": AUPRC, "F1_macro": float(f1), "Sens@95Spec": sens}


@torch.no_grad()
def infer_logits(m, loader):
    m.eval(); L, Y, IDS = [], [], []
    for x, y, ids in loader:
        x = x.to(DEVICE); y = y.to(DEVICE)
        L.append(m(x).detach().cpu()); Y.append(y.detach().cpu()); IDS+=list(ids)
    return torch.cat(L), torch.cat(Y), IDS

best_score, best_path = -1, RUN_DIR/"model.best.pt"
for ep in range(1, EPOCHS+1):
    model.train(); n, loss_sum, correct = 0, 0.0, 0
    for x, y, _ in dl_tr:
        x, y = x.to(DEVICE), y.to(DEVICE)
        if USE_MIXUP:
            x, ya, lam, yb = mixup_data(x, y, alpha=MIXUP_ALPHA)
            out = model(x); loss = mixup_loss(crit, out, ya, yb, lam)
            pred = out.argmax(1); correct += (pred==ya).sum().item()
        else:
            out = model(x); loss = crit(out, y)
            pred = out.argmax(1); correct += (pred==y).sum().item()
        opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
        n += x.size(0); loss_sum += loss.item()*x.size(0)
    sch.step()
    Lv, Yv, _ = infer_logits(model, dl_va); mv = macro_metrics(Lv.numpy(), Yv.numpy())
    score = mv["AUROC"] + mv["F1_macro"]
    print(f"[Epoch {ep:02d}] loss={loss_sum/n:.4f} val_AUROC={mv['AUROC']:.3f} val_F1={mv['F1_macro']:.3f}")
    if score > best_score:
        best_score = score; torch.save(model.state_dict(), best_path)
print("Saved best:", best_path)


[Epoch 01] loss=1.2636 val_AUROC=0.785 val_F1=0.120
[Epoch 02] loss=1.1973 val_AUROC=0.818 val_F1=0.209
[Epoch 03] loss=1.1762 val_AUROC=0.819 val_F1=0.196
[Epoch 04] loss=1.1680 val_AUROC=0.841 val_F1=0.162
[Epoch 05] loss=1.1505 val_AUROC=0.831 val_F1=0.199
[Epoch 06] loss=1.1384 val_AUROC=0.857 val_F1=0.229
[Epoch 07] loss=1.1238 val_AUROC=0.860 val_F1=0.283
[Epoch 08] loss=1.1105 val_AUROC=0.872 val_F1=0.270
[Epoch 09] loss=1.1035 val_AUROC=0.879 val_F1=0.243
[Epoch 10] loss=1.0960 val_AUROC=0.877 val_F1=0.277
Saved best: Assignment-Output/Assignment-01-Output/model.best.pt


In [5]:

# %% Clean evaluation
state = torch.load(best_path, map_location=DEVICE); model.load_state_dict(state)
Lt, Yt, IDs = infer_logits(model, dl_te)
m_clean = macro_metrics(Lt.numpy(), Yt.numpy()); print("TEST clean:", m_clean)
json.dump(m_clean, open(RUN_DIR/"metrics_clean.json","w"), indent=2)


TEST clean: {'AUROC': 0.898497979214526, 'AUPRC': 0.40406156546234756, 'F1_macro': 0.2763550044197592, 'Sens@95Spec': 0.0045045045045045045}


In [6]:

# %% Corruptions
def c_noise(img, s):
    arr = np.asarray(img).astype(np.float32)
    sigma = [5,12,25][s-1]; out = np.clip(arr + np.random.normal(0,sigma,arr.shape), 0,255).astype(np.uint8)
    return Image.fromarray(out)
def c_blur(img, s):
    return img.filter(ImageFilter.GaussianBlur([1.5,2.5,3.5][s-1]))
def c_jpeg(img, s):
    q=[60,35,20][s-1]; import io; b=io.BytesIO(); img.save(b, format="JPEG", quality=q); b.seek(0); 
    return Image.open(b).convert("RGB")
def c_bc(img, s):
    b=[1.15,1.30,1.45][s-1]; c=[1.0,1.1,1.2][s-1]
    img=ImageEnhance.Brightness(img).enhance(b); img=ImageEnhance.Contrast(img).enhance(c); return img

def eval_corruption(df, fn, s):
    model.eval(); L, Y = [], []
    for _, r in df.iterrows():
        x = eval_tf(fn(Image.open(r.path).convert("RGB"), s)).unsqueeze(0).to(DEVICE)
        with torch.no_grad(): L.append(model(x).cpu()); Y.append(r.y)
    L = torch.cat(L); Y = torch.tensor(Y)
    return macro_metrics(L.numpy(), Y.numpy())

metrics_corr = {}
for name, fn in {"gaussian_noise":c_noise, "gaussian_blur":c_blur, "jpeg":c_jpeg, "brightness_contrast":c_bc}.items():
    for s in [1,2,3]:
        k=f"{name}_s{s}"; metrics_corr[k]=eval_corruption(test, fn, s); print(k, metrics_corr[k])
json.dump(metrics_corr, open(RUN_DIR/"metrics_corruptions.json","w"), indent=2)


gaussian_noise_s1 {'AUROC': 0.897160564940983, 'AUPRC': 0.40555625037872733, 'F1_macro': 0.28744651300000623, 'Sens@95Spec': 0.0045045045045045045}
gaussian_noise_s2 {'AUROC': 0.8948826817197428, 'AUPRC': 0.3985340234290332, 'F1_macro': 0.30062563968211636, 'Sens@95Spec': 0.0}
gaussian_noise_s3 {'AUROC': 0.8669077651219902, 'AUPRC': 0.3631287403347286, 'F1_macro': 0.2922061418267508, 'Sens@95Spec': 0.0}
gaussian_blur_s1 {'AUROC': 0.8997433102846385, 'AUPRC': 0.40323662715526826, 'F1_macro': 0.2789796058671756, 'Sens@95Spec': 0.009009009009009009}
gaussian_blur_s2 {'AUROC': 0.8970034982389464, 'AUPRC': 0.3979362063746656, 'F1_macro': 0.2544303736370526, 'Sens@95Spec': 0.009009009009009009}
gaussian_blur_s3 {'AUROC': 0.8939688636912326, 'AUPRC': 0.38758003576820277, 'F1_macro': 0.21537918563029704, 'Sens@95Spec': 0.009009009009009009}
jpeg_s1 {'AUROC': 0.898439044756283, 'AUPRC': 0.41048419494203453, 'F1_macro': 0.2852493417867552, 'Sens@95Spec': 0.009009009009009009}
jpeg_s2 {'AUROC': 0

In [7]:

# %% Shifts + Slices
def s_colorcast(img, sev=2):
    casts=[(15,0,0),(25,5,0),(35,10,0)][sev-1]; arr=np.asarray(img).astype(np.int16)
    for c,v in enumerate(casts): arr[...,c]=np.clip(arr[...,c]+v,0,255)
    return Image.fromarray(arr.astype(np.uint8))
def s_downup(img, sev=2):
    f=[0.6,0.45,0.33][sev-1]; w,h=img.size
    ds=img.resize((max(8,int(w*f)), max(8,int(h*f))), resample=Image.BILINEAR)
    return ds.resize((w,h), resample=Image.BILINEAR)

def eval_shift(df, fn, sev=2):
    model.eval(); L,Y=[],[]
    for _,r in df.iterrows():
        x=eval_tf(fn(Image.open(r.path).convert("RGB"), sev)).unsqueeze(0).to(DEVICE)
        with torch.no_grad(): L.append(model(x).cpu()); Y.append(r.y)
    L=torch.cat(L); Y=torch.tensor(Y)
    return macro_metrics(L.numpy(), Y.numpy())

metrics_shifts={"color_cast":eval_shift(test, s_colorcast, 2),
                "down_up_sample":eval_shift(test, s_downup, 2)}
json.dump(metrics_shifts, open(RUN_DIR/"metrics_shifts.json","w"), indent=2)

# Slices
q=test["bright"].quantile([0.25,0.5,0.75]).values
def mask_b(qb):
    if qb=="Q1": return test["bright"]<=q[0]
    if qb=="Q2": return (test["bright"]>q[0])&(test["bright"]<=q[1])
    if qb=="Q3": return (test["bright"]>q[1])&(test["bright"]<=q[2])
    if qb=="Q4": return test["bright"]>q[2]

def eval_df(df):
    model.eval(); L,Y=[],[]
    for _,r in df.iterrows():
        x=eval_tf(Image.open(r.path).convert("RGB")).unsqueeze(0).to(DEVICE)
        with torch.no_grad(): L.append(model(x).cpu()); Y.append(r.y)
    L=torch.cat(L); Y=torch.tensor(Y)
    return macro_metrics(L.numpy(), Y.numpy())

metrics_slices={}
for b in ["Q1","Q2","Q3","Q4"]:
    metrics_slices[f"Brightness_{b}"]=eval_df(test[mask_b(b)].reset_index(drop=True))
size_q=test["w"].quantile(0.25)
metrics_slices["ImageSize_Q1"]=eval_df(test[test["w"]<=size_q].reset_index(drop=True))
json.dump(metrics_slices, open(RUN_DIR/"metrics_slices.json","w"), indent=2)


In [10]:
# %% Calibration + MC-Dropout (fixed device handling)

@torch.no_grad()
def logits_loader(m, loader):
    m.eval()
    L, Y = [], []
    for x, y, _ in loader:
        x = x.to(DEVICE)
        L.append(m(x).detach().cpu())   # collect on CPU; we'll move later
        Y.append(y.detach().cpu())
    return torch.cat(L).float(), torch.cat(Y).long()

class TempScale(nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m.eval()
        self.log_t = nn.Parameter(torch.zeros(1))  # temperature = exp(log_t)
    def forward(self, x):
        return self.m(x)
    def temperature(self):
        return self.log_t.exp()
    def fit(self, valid_loader, device=None):
        # Put the calibrator on the chosen device
        dev = device or next(self.parameters()).device
        self.to(dev)

        # Get validation logits/labels, then move them to the same device as log_t
        L_cpu, Y_cpu = logits_loader(self.m, valid_loader)
        L = L_cpu.to(dev)
        Y = Y_cpu.to(dev)

        nll = nn.CrossEntropyLoss().to(dev)
        optim = torch.optim.LBFGS([self.log_t], lr=0.01, max_iter=50, line_search_fn="strong_wolfe")

        def closure():
            optim.zero_grad(set_to_none=True)
            loss = nll(L / self.temperature(), Y)
            loss.backward()
            return loss

        optim.step(closure)
        return float(self.temperature().item())

def ece_score(probs, y, n_bins=15):
    conf = probs.max(1)
    preds = probs.argmax(1)
    bins = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        m = (conf >= bins[i]) & (conf < bins[i+1])
        if m.sum() == 0: 
            continue
        acc = (preds[m] == y[m]).mean()
        ece += m.mean() * abs(acc - conf[m].mean())
    return float(ece)

# --- Fit temperature on validation set ---
TS = TempScale(model).to(DEVICE)
temp = TS.fit(dl_va, device=DEVICE)
print("Fitted temperature:", temp)

@torch.no_grad()
def probs_on(loader, apply_temp=False):
    model.eval()
    P, Y = [], []
    for x, y, _ in loader:
        x = x.to(DEVICE)
        out = model(x)
        if apply_temp:
            out = out / TS.temperature()
        P.append(out.softmax(1).cpu().numpy())
        Y.append(y.numpy())
    return np.concatenate(P), np.concatenate(Y)

p_before, y_test = probs_on(dl_te, apply_temp=False)
p_after,  _      = probs_on(dl_te, apply_temp=True)

def brier(p, y):
    oh = np.eye(len(CLASSES))[y]
    return float(np.mean(np.sum((p - oh)**2, axis=1)))

ece_before = ece_score(p_before, y_test)
ece_after  = ece_score(p_after,  y_test)
brier_before = brier(p_before, y_test)
brier_after  = brier(p_after,  y_test)

with open(RUN_DIR / "calibration.json", "w") as f:
    json.dump({
        "temperature": temp,
        "ECE_before": ece_before, "ECE_after": ece_after,
        "Brier_before": brier_before, "Brier_after": brier_after
    }, f, indent=2)

print({"temperature": temp,
       "ECE_before": ece_before, "ECE_after": ece_after,
       "Brier_before": brier_before, "Brier_after": brier_after})


Fitted temperature: 0.8212406635284424
{'temperature': 0.8212406635284424, 'ECE_before': 0.04903164687734769, 'ECE_after': 0.03593008297992996, 'Brier_before': 0.3757680639959472, 'Brier_after': 0.3736223213728016}


In [16]:
# %% Risk–coverage curves (uncalibrated vs temperature-scaled)
import numpy as np, json
import matplotlib.pyplot as plt
from pathlib import Path

assert 'RUN_DIR' in globals(), "RUN_DIR not set; run earlier cells."
(RUN_DIR / "figs").mkdir(parents=True, exist_ok=True)

# ---- helpers
def ensure_probs_and_labels():
    """
    Returns (p_before, p_after, y_test) where p_* are [N,C] numpy arrays of class probabilities.
    Uses existing 'probs_on' + fitted 'TS' if available; otherwise computes from model/dl_te.
    """
    global model
    # case 1: we already defined probs_on(TS) earlier
    if 'probs_on' in globals():
        p_before, y_test = probs_on(dl_te, apply_temp=False)
        p_after,  _      = probs_on(dl_te, apply_temp=True)
        return p_before, p_after, y_test
    # case 2: fallback – compute fresh
    @torch.no_grad()
    def _probs(loader, apply_temp=False, temp_val=1.0):
        model.eval(); P=[]; Y=[]
        for x, y, _ in loader:
            x = x.to(DEVICE)
            out = model(x)
            if apply_temp:
                out = out / temp_val
            P.append(out.softmax(1).cpu().numpy()); Y.append(y.numpy())
        return np.concatenate(P), np.concatenate(Y)
    temp_val = float(TS.temperature().item()) if 'TS' in globals() else 1.0
    pb, y = _probs(dl_te, apply_temp=False)
    pa, _ = _probs(dl_te, apply_temp=True, temp_val=temp_val)
    return pb, pa, y

def risk_coverage(prob, y_true):
    """
    prob: [N,C] probabilities; y_true: [N] ints
    Returns coverage (N points from 1/N..1), risk (1-accuracy at each coverage), thresholds.
    """
    conf = prob.max(1)
    pred = prob.argmax(1)
    correct = (pred == y_true).astype(np.float32)

    # sort by confidence descending (keep highest first)
    order = np.argsort(-conf)
    conf_sorted = conf[order]
    corr_sorted = correct[order]

    # prefix accuracy -> risk (1-acc); coverage grows as k/N
    N = len(conf_sorted)
    cumsum_corr = np.cumsum(corr_sorted)
    k = np.arange(1, N+1)
    acc_prefix = cumsum_corr / k
    risk = 1.0 - acc_prefix
    coverage = k / N

    # thresholds corresponding to each prefix (minimum conf kept)
    thresholds = conf_sorted
    return coverage, risk, thresholds

def summarize_ops(coverage, risk, thresholds, targets=(0.5, 0.7, 0.9)):
    """
    For selected coverage targets, report risk and the associated confidence threshold.
    """
    out = []
    for t in targets:
        idx = np.searchsorted(coverage, t, side='left')
        idx = min(idx, len(coverage)-1)
        out.append({"coverage": float(coverage[idx]),
                    "risk": float(risk[idx]),
                    "threshold": float(thresholds[idx])})
    return out

# ---- compute
p_before, p_after, y_test = ensure_probs_and_labels()
cov_b, risk_b, thr_b = risk_coverage(p_before, y_test)
cov_a, risk_a, thr_a = risk_coverage(p_after,  y_test)

# area-under-risk–coverage (lower is better)
from numpy import trapezoid
AURC_before = float(trapezoid(risk_b, cov_b))
AURC_after  = float(trapezoid(risk_a, cov_a))

ops_b = summarize_ops(cov_b, risk_b, thr_b)
ops_a = summarize_ops(cov_a, risk_a, thr_a)

# save a small JSON summary
rc_json = {
    "AURC_before": AURC_before,
    "AURC_after": AURC_after,
    "ops_before": ops_b,
    "ops_after": ops_a
}
with open(RUN_DIR / "risk_coverage_summary.json", "w") as f:
    json.dump(rc_json, f, indent=2)
print(rc_json)

# ---- plot (one figure, default style; no explicit colors for portability)
plt.figure(figsize=(5.2, 3.6), dpi=200)
plt.plot(cov_b, risk_b, label="Uncalibrated")
plt.plot(cov_a, risk_a, label="Temp-scaled")
plt.xlabel("Coverage (kept fraction)")
plt.ylabel("Risk (1 - accuracy)")
plt.title("Risk–Coverage (Selective Prediction)")
plt.legend()
plt.tight_layout()
out_path = RUN_DIR / "figs" / "risk_coverage_before_after.png"
plt.savefig(out_path, bbox_inches="tight")
plt.close()
print("Saved figure:", out_path)


{'AURC_before': 0.10043888494708078, 'AURC_after': 0.1001736316489033, 'ops_before': [{'coverage': 0.5, 'risk': 0.07228915662650603, 'threshold': 0.7192847728729248}, {'coverage': 0.7003012048192772, 'risk': 0.14623655913978495, 'threshold': 0.5347804427146912}, {'coverage': 0.9001004016064257, 'risk': 0.23368655883993306, 'threshold': 0.33551979064941406}], 'ops_after': [{'coverage': 0.5, 'risk': 0.07329317269076308, 'threshold': 0.8090435266494751}, {'coverage': 0.7003012048192772, 'risk': 0.14623655913978495, 'threshold': 0.6084802150726318}, {'coverage': 0.9001004016064257, 'risk': 0.2342442833240379, 'threshold': 0.374374121427536}]}
Saved figure: Assignment-Output/Assignment-01-Output/figs/risk_coverage_before_after.png


In [17]:
# %% Extra figures pack: calibration plots, risk–coverage, corruptions-by-severity, shifts/slices bars,
#    confusion matrix (clean), and melanoma PR curve.
import os, json, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, roc_auc_score

assert 'RUN_DIR' in globals(), "RUN_DIR not set. Run earlier cells first."
(RUN_DIR / "figs").mkdir(parents=True, exist_ok=True)

# ---------- Helpers ----------
def ensure_probs_and_labels():
    """
    Returns (p_before, p_after, y_test).
    Uses your earlier probs_on + TS if present, else recomputes from model/dl_te.
    """
    if 'probs_on' in globals():
        p_before, y_test = probs_on(dl_te, apply_temp=False)
        p_after,  _      = probs_on(dl_te, apply_temp=True)
        return p_before, p_after, y_test

    @torch.no_grad()
    def _probs(loader, temp=None):
        model.eval(); P=[]; Y=[]
        for x, y, _ in loader:
            x = x.to(DEVICE)
            out = model(x) if temp is None else (model(x) / temp)
            P.append(out.softmax(1).cpu().numpy())
            Y.append(y.numpy())
        return np.concatenate(P), np.concatenate(Y)

    temp_val = float(TS.temperature().item()) if 'TS' in globals() else 1.0
    pb, y = _probs(dl_te, temp=None)
    pa, _ = _probs(dl_te, temp=temp_val)
    return pb, pa, y

def plot_reliability(prob, y, title, out_path, n_bins=15):
    conf = prob.max(1)
    pred = prob.argmax(1)
    correct = (pred == y).astype(np.float32)
    bins = np.linspace(0,1,n_bins+1)
    mids, accs, confs = [], [], []
    for i in range(n_bins):
        m = (conf >= bins[i]) & (conf < bins[i+1])
        if m.sum()==0: continue
        mids.append(0.5*(bins[i]+bins[i+1]))
        accs.append(correct[m].mean())
        confs.append(conf[m].mean())
    plt.figure(figsize=(4.6,4.0), dpi=200)
    plt.plot([0,1],[0,1], linestyle='--')
    plt.plot(confs, accs, marker='o')
    plt.xlabel("Mean confidence")
    plt.ylabel("Empirical accuracy")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()

def risk_coverage(prob, y):
    conf = prob.max(1)
    pred = prob.argmax(1)
    corr = (pred == y).astype(np.float32)
    order = np.argsort(-conf)
    corr = corr[order]; conf = conf[order]
    N = len(corr)
    cov = np.arange(1, N+1)/N
    acc = np.cumsum(corr)/np.arange(1, N+1)
    risk = 1.0 - acc
    thr  = conf  # threshold at each coverage (min conf kept)
    return cov, risk, thr

def bar_dict(ax, names, vals, title, ylabel=None):
    ax.bar(range(len(vals)), vals)
    ax.set_xticks(range(len(vals)))
    ax.set_xticklabels(names, rotation=30, ha='right')
    if ylabel: ax.set_ylabel(ylabel)
    ax.set_title(title)

def get_metrics_json(name):
    p = RUN_DIR / name
    with open(p, 'r') as f: return json.load(f)

def metric_or_nan(d, k, field):
    try: return float(d[k][field])
    except: return np.nan

# ---------- Load metrics ----------
# Expect these from earlier cells
m_clean      = get_metrics_json("metrics_clean.json") if (RUN_DIR/"metrics_clean.json").exists() else None
m_corrupt    = get_metrics_json("metrics_corruptions.json")
m_shifts     = get_metrics_json("metrics_shifts.json")
m_slices     = get_metrics_json("metrics_slices.json")

# ---------- 1) Reliability diagrams (before/after) ----------
p_before, p_after, y_test = ensure_probs_and_labels()
plot_reliability(p_before, y_test, "Reliability (uncalibrated)", RUN_DIR/"figs/reliability_before.png")
plot_reliability(p_after,  y_test, "Reliability (temp-scaled)", RUN_DIR/"figs/reliability_after.png")
print("Saved reliability diagrams.")

# ---------- 2) Risk–coverage (before/after) with markers ----------
cov_b, risk_b, thr_b = risk_coverage(p_before, y_test)
cov_a, risk_a, thr_a = risk_coverage(p_after,  y_test)
plt.figure(figsize=(5.2,3.6), dpi=200)
plt.plot(cov_b, risk_b, label="Uncalibrated")
plt.plot(cov_a, risk_a, label="Temp-scaled")
for t in (0.5, 0.7, 0.9):
    idx_b = min(np.searchsorted(cov_b, t), len(cov_b)-1)
    idx_a = min(np.searchsorted(cov_a, t), len(cov_a)-1)
    plt.scatter([cov_b[idx_b]],[risk_b[idx_b]])
    plt.scatter([cov_a[idx_a]],[risk_a[idx_a]])
plt.xlabel("Coverage (kept fraction)")
plt.ylabel("Risk (1 - accuracy)")
plt.title("Risk–Coverage (Selective Prediction)")
plt.legend()
plt.tight_layout()
plt.savefig(RUN_DIR/"figs/risk_coverage_marked.png", bbox_inches="tight")
plt.close()
print("Saved risk–coverage with markers.")

# ---------- 3) Corruptions-by-severity lines (AUROC & F1) ----------
families = ["gaussian_noise","gaussian_blur","jpeg","brightness_contrast"]
sev = ["s1","s2","s3"]
auroc = {f: [metric_or_nan(m_corrupt, f"{f}_{s}", "AUROC") for s in sev] for f in families}
f1    = {f: [metric_or_nan(m_corrupt, f"{f}_{s}", "F1_macro") for s in sev] for f in families}

fig, axs = plt.subplots(1,2, figsize=(9,3.4), dpi=200)
for f in families:
    axs[0].plot([1,2,3], auroc[f], marker='o', label=f)
    axs[1].plot([1,2,3], f1[f],    marker='o', label=f)
axs[0].set_title("Corruptions: AUROC vs severity"); axs[0].set_xlabel("Severity"); axs[0].set_ylabel("AUROC")
axs[1].set_title("Corruptions: F1 vs severity");    axs[1].set_xlabel("Severity"); axs[1].set_ylabel("F1 (macro)")
axs[0].set_xticks([1,2,3]); axs[1].set_xticks([1,2,3])
axs[0].legend(fontsize=8, loc="best")
plt.tight_layout()
plt.savefig(RUN_DIR/"figs/corruptions_severity_lines.png", bbox_inches="tight")
plt.close()
print("Saved corruption severity lines.")

# ---------- 4) Domain shifts bars ----------
names = list(m_shifts.keys())
vals_auc = [m_shifts[n]["AUROC"] for n in names]
vals_f1  = [m_shifts[n]["F1_macro"] for n in names]
fig, axs = plt.subplots(1,2, figsize=(8.2,3.2), dpi=200)
bar_dict(axs[0], names, vals_auc, "Domain shifts: AUROC", "AUROC")
bar_dict(axs[1], names, vals_f1,  "Domain shifts: F1", "F1 (macro)")
plt.tight_layout()
plt.savefig(RUN_DIR/"figs/domain_shifts_bars.png", bbox_inches="tight")
plt.close()
print("Saved domain shift bars.")

# ---------- 5) Slices bars (brightness quartiles + image size Q1) ----------
s_names = list(m_slices.keys())
s_auc = [m_slices[n]["AUROC"] for n in s_names]
s_f1  = [m_slices[n]["F1_macro"] for n in s_names]
fig, axs = plt.subplots(1,2, figsize=(9.2,3.2), dpi=200)
bar_dict(axs[0], s_names, s_auc, "Slices: AUROC", "AUROC")
bar_dict(axs[1], s_names, s_f1,  "Slices: F1", "F1 (macro)")
plt.tight_layout()
plt.savefig(RUN_DIR/"figs/slices_bars.png", bbox_inches="tight")
plt.close()
print("Saved slice bars.")

# ---------- 6) Confusion matrix (clean) ----------
# Use saved logits if available; otherwise recompute
def ensure_logits_clean():
    if 'Lt' in globals() and 'Yt' in globals():
        return Lt, Yt
    @torch.no_grad()
    def _infer(loader):
        model.eval(); L,Y=[],[]
        for x,y,_ in loader:
            x=x.to(DEVICE); L.append(model(x).detach().cpu()); Y.append(y)
        return torch.cat(L), torch.cat(Y)
    return _infer(dl_te)

Lt0, Yt0 = ensure_logits_clean()
P0 = Lt0.softmax(1).cpu().numpy()
pred0 = P0.argmax(1); y0 = Yt0.numpy()
cm = confusion_matrix(y0, pred0, labels=list(range(len(CLASSES))))
cm_norm = cm / cm.sum(axis=1, keepdims=True).clip(min=1)

plt.figure(figsize=(5.4,4.6), dpi=200)
plt.imshow(cm_norm, aspect='auto')
plt.xticks(range(len(CLASSES)), CLASSES, rotation=45, ha='right')
plt.yticks(range(len(CLASSES)), CLASSES)
plt.xlabel("Predicted"); plt.ylabel("True")
plt.title("Confusion matrix (normalized rows)")
plt.colorbar(fraction=0.046, pad=0.04)
plt.tight_layout()
plt.savefig(RUN_DIR/"figs/confusion_matrix_clean.png", bbox_inches="tight")
plt.close()
print("Saved confusion matrix.")

# ---------- 7) Melanoma (mel) one-vs-rest PR curve ----------
mel_idx = CLASS2IDX["mel"]
y_bin = (y_test == mel_idx).astype(int)
p_before_mel = p_before[:, mel_idx]
p_after_mel  = p_after[:,  mel_idx]

pr_b = precision_recall_curve(y_bin, p_before_mel)
pr_a = precision_recall_curve(y_bin, p_after_mel)
ap_b = average_precision_score(y_bin, p_before_mel)
ap_a = average_precision_score(y_bin, p_after_mel)

plt.figure(figsize=(5.2,3.6), dpi=200)
plt.plot(pr_b[1], pr_b[0], label=f"Uncalibrated (AP={ap_b:.3f})")
plt.plot(pr_a[1], pr_a[0], label=f"Temp-scaled (AP={ap_a:.3f})")
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title("Melanoma PR curve (one-vs-rest)")
plt.legend()
plt.tight_layout()
plt.savefig(RUN_DIR/"figs/prcurve_melanoma.png", bbox_inches="tight")
plt.close()
print("Saved PR curve for melanoma.")

print("Done. Figures saved in:", RUN_DIR/"figs")


Saved reliability diagrams.
Saved risk–coverage with markers.
Saved corruption severity lines.
Saved domain shift bars.
Saved slice bars.
Saved confusion matrix.
Saved PR curve for melanoma.
Done. Figures saved in: Assignment-Output/Assignment-01-Output/figs


In [11]:

# %% Grad-CAM failures
class GradCAM:
    def __init__(self, model, layer_name="layer4"):
        self.model=model.eval(); self.acts=None; self.grads=None
        layer=dict(self.model.named_children())[layer_name]
        layer.register_forward_hook(self._fh); layer.register_full_backward_hook(self._bh)
    def _fh(self, m, i, o): self.acts=o.detach()
    def _bh(self, m, gi, go): self.grads=go[0].detach()
    def __call__(self, x, idx=None):
        self.model.zero_grad(set_to_none=True); out=self.model(x); idx=out.argmax(1) if idx is None else idx
        loss=out[0, idx]; loss.backward(); w=self.grads[0].mean((1,2), keepdim=True); cam=(w*self.acts[0]).sum(0).cpu().numpy()
        cam=np.maximum(cam,0); cam=(cam-cam.min())/(cam.max()+1e-8); return cam

def overlay_cam(img, cam, alpha=0.45):
    cam_img=Image.fromarray((cam*255).astype(np.uint8)).resize(img.size, resample=Image.BILINEAR)
    cmap=plt.get_cmap("jet"); heat=Image.fromarray((cmap(np.asarray(cam_img)/255.0)[...,:3]*255).astype(np.uint8))
    return Image.blend(img.convert("RGB"), heat, alpha=alpha)

def select_failures(logits, y_true, ids, topk=2):
    P=logits.softmax(1).numpy(); y=y_true.numpy(); pred=P.argmax(1); wrong=np.where(pred!=y)[0]
    if len(wrong)==0: return []
    conf=P[wrong, pred[wrong]]; idx=wrong[np.argsort(-conf)[:topk]]; return idx.tolist()

gcam=GradCAM(model)
Lt, Yt, IDs = infer_logits(model, dl_te)
fails = select_failures(Lt, Yt, IDs, topk=2)
cases=[]
for i in fails:
    row=test.iloc[i]; img=Image.open(row.path).convert("RGB"); x=eval_tf(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad(): p=model(x).softmax(1).squeeze().cpu().numpy()
    cam=gcam(x, idx=p.argmax()); ov=overlay_cam(img, cam); outp=RUN_DIR/"case_studies"/f"{row.image_id}_cam_alt.png"; ov.save(outp)
    cases.append({"image":str(row.path), "overlay":str(outp), "y_true":CLASSES[row.y],
                  "y_pred":CLASSES[int(p.argmax())], "confidence":float(p.max())})
json.dump(cases, open(RUN_DIR/"cases.json","w"), indent=2)
print("Saved", len(cases), "cases.")


Saved 2 cases.


In [12]:

# %% Δ-metrics + worst-drop bar
def loadj(p): return json.load(open(p,"r"))
clean = loadj(RUN_DIR/"metrics_clean.json")
corr  = loadj(RUN_DIR/"metrics_corruptions.json")
shft  = loadj(RUN_DIR/"metrics_shifts.json")
slc   = loadj(RUN_DIR/"metrics_slices.json")

def delta_block(block, base):
    out={}
    for k,v in block.items():
        out[k]={m:v[m]-base[m] for m in ["AUROC","F1_macro","AUPRC","Sens@95Spec"]}
    return out

d_corr, d_shf, d_slc = delta_block(corr, clean), delta_block(shft, clean), delta_block(slc, clean)
json.dump(d_corr, open(RUN_DIR/"delta_corruptions.json","w"), indent=2)
json.dump(d_shf, open(RUN_DIR/"delta_shifts.json","w"), indent=2)
json.dump(d_slc, open(RUN_DIR/"delta_slices.json","w"), indent=2)

import pandas as pd
def to_df(d, key="ΔAUROC"):
    rows=[{"name":k, "ΔAUROC":v["AUROC"], "ΔF1":v["F1_macro"]} for k,v in d.items()]
    return pd.DataFrame(rows).sort_values("ΔAUROC")
W=pd.concat([to_df(d_corr), to_df(d_shf), to_df(d_slc)], ignore_index=True).sort_values("ΔAUROC").head(10)

plt.figure(figsize=(7,4))
plt.barh(range(len(W)), W["ΔAUROC"]); plt.yticks(range(len(W)), W["name"])
plt.xlabel("ΔAUROC (lower is worse)"); plt.tight_layout()
plt.savefig(RUN_DIR/"figs/worst_delta_auroc_alt.png", bbox_inches="tight"); plt.close()
print("Saved Δ figure to figs/")


Saved Δ figure to figs/


In [14]:
# %% Task E — select two high-confidence failures and render Grad-CAM overlays + summaries
import json, math, numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

assert 'RUN_DIR' in globals(), "RUN_DIR missing; run earlier cells first."
(RUN_DIR / "taskE").mkdir(parents=True, exist_ok=True)

# --- Utilities: (re)define Grad-CAM if needed ---
import torch, torch.nn as nn
# ---- Replace your Grad-CAM helper + instantiation with this ----
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def get_gradcam(model, layer_name="layer4"):
    """Return a Grad-CAM object; define it if not already present globally."""
    try:
        GC = GradCAM  # use an existing global, if defined elsewhere
    except NameError:
        class GradCAM:
            def __init__(self, model, layer_name="layer4"):
                self.model = model.eval()
                self.activations = None
                self.gradients = None
                # hook into the chosen block (ResNet50 has 'layer4')
                layers = dict(self.model.named_children())
                if layer_name not in layers:
                    raise ValueError(f"Layer '{layer_name}' not found. Available: {list(layers.keys())}")
                layer = layers[layer_name]
                layer.register_forward_hook(self._forward_hook)
                # use full backward hook (PyTorch >=1.8); falls back if not available
                try:
                    layer.register_full_backward_hook(self._backward_hook)
                except AttributeError:
                    layer.register_backward_hook(self._backward_hook)

            def _forward_hook(self, module, inp, out):
                self.activations = out.detach()

            def _backward_hook(self, module, grad_input, grad_output):
                self.gradients = grad_output[0].detach()

            def __call__(self, x, index=None):
                self.model.zero_grad(set_to_none=True)
                out = self.model(x)  # [B,C]
                # pick target index
                if index is None:
                    idx = int(out.argmax(1).item())
                elif torch.is_tensor(index):
                    idx = int(index.item())
                else:
                    idx = int(index)
                loss = out[0, idx]
                loss.backward()

                # weights = mean over spatial dims of gradients
                w = self.gradients[0].mean(dim=(1, 2), keepdim=True)  # [C,1,1]
                cam = (w * self.activations[0]).sum(0).cpu().numpy()   # [H,W]
                cam = np.maximum(cam, 0)
                cam = (cam - cam.min()) / (cam.max() + 1e-8)
                return cam, out.softmax(1)[0].detach().cpu().numpy()

        GC = GradCAM

    return GC(model, layer_name)

# instantiate safely


# (Optional) overlay utility for visualization
def overlay_cam(pil_img, cam, alpha=0.45):
    cam_img = Image.fromarray((cam*255).astype(np.uint8)).resize(pil_img.size, resample=Image.BILINEAR)
    cmap = plt.get_cmap("jet")
    heat = Image.fromarray((cmap(np.asarray(cam_img)/255.0)[..., :3]*255).astype(np.uint8))
    return Image.blend(pil_img.convert("RGB"), heat, alpha=alpha)


def topk_probs(prob_vec, classes, k=3):
    idx = np.argsort(-prob_vec)[:k]
    return [(classes[i], float(prob_vec[i])) for i in idx]

# --- Pick two failures with highest wrong-class confidence ---
@torch.no_grad()
def collect_test_logits(model, loader):
    model.eval(); L, Y, IDS, PTH = [], [], [], []
    for x, y, ids in loader:
        x = x.to(DEVICE); out = model(x).cpu()
        L.append(out); Y.append(y); IDS += list(ids)
    L = torch.cat(L); Y = torch.cat(Y)
    return L, Y, IDS

# Ensure model & eval_tf & test exist
assert 'model' in globals() and 'eval_tf' in globals() and 'test' in globals(), \
    "Missing model/eval_tf/test. Run earlier cells."

logits_test, y_test, ids_test = collect_test_logits(model, dl_te)
probs_test = logits_test.softmax(1).numpy()
y_np = y_test.numpy()
pred = probs_test.argmax(1)
wrong = np.where(pred != y_np)[0]
assert len(wrong) > 0, "No failures on the test split (unlikely)."

conf_wrong = probs_test[wrong, pred[wrong]]
order = wrong[np.argsort(-conf_wrong)]
sel = order[:2].tolist()

# --- Heuristic diagnosis (brightness/contrast/sharpness + CAM centrality) ---
from PIL import ImageOps, ImageFilter
def img_stats(pil):
    g = ImageOps.grayscale(pil)
    arr = np.asarray(g, dtype=np.float32) / 255.0
    bright = float(arr.mean())
    contrast = float(arr.std())
    edges = g.filter(ImageFilter.FIND_EDGES)
    sharp = float(np.asarray(edges, dtype=np.float32).mean())  # proxy sharpness
    return bright, contrast, sharp

def cam_centrality(cam, frac=0.5):
    h, w = cam.shape
    ch, cw = int(h*frac), int(w*frac)
    y0 = (h - ch)//2; x0 = (w - cw)//2
    center = cam[y0:y0+ch, x0:x0+cw].mean()
    total = cam.mean() + 1e-8
    return float(center / total)

def diagnose(pil, cam, prob_vec, y_true_idx, y_pred_idx):
    bright, contrast, sharp = img_stats(pil)
    cent = cam_centrality(cam, 0.5)
    # thresholds from empirical ranges on HAM images
    if sharp < 20:
        diag = "Scale/blur-induced error"
        mit  = "Use anti-aliasing on resize; consider higher input resolution or mild super-resolution; stronger blur augmentation."
        risk = "Under low-resolution or out-of-focus images, texture-similar lesions remain confusable."
    elif bright > 0.70 or contrast < 0.08:
        diag = "Photometric/contrast shift"
        mit  = "Apply color constancy and histogram normalization; include brightness/contrast augmentation; per-site calibration."
        risk = "Over/under-exposed images can still shift probabilities; monitor brightness-slice metrics."
    elif cent < 0.85:
        diag = "Saliency mislocalization (context over-weighted)"
        mit  = "Lesion-centric crops/segmentation (e.g., SAM) and hair/ruler removal; background-invariance regularization."
        risk = "Off-center or occluded lesions may still draw attention to artifacts; require human review on abstentions."
    else:
        diag = "Class boundary confusion"
        mit  = "Add near-boundary examples, label adjudication, and focal/LDS sampling; temperature scaling + selective prediction."
        risk = "Atypical/amelanotic variants remain risky without richer supervision."
    return diag, mit, risk

# --- Build two case records + images ---
gcam = get_gradcam(model, layer_name="layer4")
CLASSES = CLASSES  # already defined above
cases_out = []

for i, idx in enumerate(sel, start=1):
    row = test.iloc[idx]
    pil = Image.open(row.path).convert("RGB")
    x = eval_tf(pil).unsqueeze(0).to(DEVICE)
    cam, prob_vec = gcam(x, index=int(pred[idx]))
    overlay = overlay_cam(pil, cam)
    case_png = RUN_DIR / "taskE" / f"case_{i}.png"
    overlay.save(case_png)

    top3 = topk_probs(prob_vec, CLASSES, k=3)
    diag, mit, risk = diagnose(pil, cam, prob_vec, y_true_idx=int(y_np[idx]), y_pred_idx=int(pred[idx]))

    case = {
        "image": str(row.path),
        "overlay": str(case_png),
        "y_true": CLASSES[int(y_np[idx])],
        "y_pred": CLASSES[int(pred[idx])],
        "confidence": float(prob_vec[int(pred[idx])]),
        "top3": top3,
        "diagnosis": diag,
        "mitigation": mit,
        "residual_risk": risk,
    }
    cases_out.append(case)

# Save machine-readable summary
with open(RUN_DIR / "taskE" / "cases_taskE.json", "w") as f:
    json.dump(cases_out, f, indent=2)

# Render a side-by-side figure (each case shows original + CAM overlay)
def side_by_side(pil, overlay, title_left="Original", title_right="Confidence map", out_path=None):
    fig, ax = plt.subplots(1, 2, figsize=(8, 3), dpi=200)
    ax[0].imshow(pil); ax[0].set_title(title_left, fontsize=9); ax[0].axis("off")
    ax[1].imshow(overlay); ax[1].set_title(title_right, fontsize=9); ax[1].axis("off")
    plt.tight_layout()
    if out_path: plt.savefig(out_path, bbox_inches="tight")
    plt.close()

for i, idx in enumerate(sel, start=1):
    pil = Image.open(test.iloc[idx].path).convert("RGB")
    overlay = Image.open(cases_out[i-1]["overlay"]).convert("RGB")
    out_fig = RUN_DIR / "taskE" / f"case_{i}_panel.png"
    side_by_side(pil, overlay, out_path=out_fig)
    cases_out[i-1]["panel"] = str(out_fig)

print("Task E cases saved to:", RUN_DIR / "taskE")
cases_out


Task E cases saved to: Assignment-Output/Assignment-01-Output/taskE


[{'image': '/share_2/users/umair_nawaz/Mine/CV8502/A1/dataset/HAM10000/HAM10000_images_part_1/ISIC_0025316.jpg',
  'overlay': 'Assignment-Output/Assignment-01-Output/taskE/case_1.png',
  'y_true': 'mel',
  'y_pred': 'nv',
  'confidence': 0.9316644668579102,
  'top3': [('nv', 0.9316644668579102),
   ('mel', 0.017070578411221504),
   ('bkl', 0.01517099142074585)],
  'diagnosis': 'Scale/blur-induced error',
  'mitigation': 'Use anti-aliasing on resize; consider higher input resolution or mild super-resolution; stronger blur augmentation.',
  'residual_risk': 'Under low-resolution or out-of-focus images, texture-similar lesions remain confusable.',
  'panel': 'Assignment-Output/Assignment-01-Output/taskE/case_1_panel.png'},
 {'image': '/share_2/users/umair_nawaz/Mine/CV8502/A1/dataset/HAM10000/HAM10000_images_part_1/ISIC_0028760.jpg',
  'overlay': 'Assignment-Output/Assignment-01-Output/taskE/case_2.png',
  'y_true': 'mel',
  'y_pred': 'nv',
  'confidence': 0.9214457273483276,
  'top3': [(