<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/MedMNIST_QNN_AllInOne.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pennylane medmnist scikit-learn matplotlib scikit-image
!pip install shap lime

Collecting pennylane
  Downloading pennylane-0.42.3-py3-none-any.whl.metadata (11 kB)
Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray<0.8,>=0.6.11 (from pennylane)
  Downloading autoray-0.7.2-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.42 (from pennylane)
  Downloading pennylane_lightning-0.42.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.42->pennylane)
  Downloading scipy_openblas3

# installs , imports, config, FAST_MODE, seeds, env log

In [2]:
# ===== SNIPPET 1 — Setup & Config =====
# (Optional) First-time installs in a fresh Colab:
# !pip install -q pennylane medmnist scikit-learn shap lime matplotlib scikit-image tqdm

import os, sys, time, json, math, random, pathlib, platform, warnings
from pathlib import Path

import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
warnings.filterwarnings("ignore", category=UserWarning, module="torch.utils.data.dataloader")

import pennylane as qml

# Optional libs (for later steps)
try:
    import shap; HAVE_SHAP=True
except Exception: HAVE_SHAP=False
try:
    from lime import lime_image; HAVE_LIME=True
except Exception: HAVE_LIME=False
try:
    from sklearn.metrics import f1_score, confusion_matrix, roc_auc_score
    HAVE_SK=True
except Exception:
    HAVE_SK=False

# ---------- Run & cache directories ----------
RUN_TS    = time.strftime("%Y%m%d_%H%M%S")
ROOT_DIR  = Path(f"./qnn_hybrid_runs/{RUN_TS}"); ROOT_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR = Path("./medmnist_cache"); CACHE_DIR.mkdir(parents=True, exist_ok=True)

# ---------- Dataset plan ----------
# Headline results target >92%: BloodMNIST + PneumoniaMNIST (hybrid CNN→QNN).
# Keep DermaMNIST as a tougher "viability" analysis set.
CORE_DATASETS     = ["BloodMNIST", "PneumoniaMNIST"]
VIABILITY_DATASET = "DermaMNIST"

# ---------- Device & loader defaults ----------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FAST_MODE   = not torch.cuda.is_available()
# Set workers=0 to avoid Colab's "can only test a child process" teardown spam
NUM_WORKERS = 0
PIN_MEMORY  = False

# ---------- Hybrid CNN→QNN architecture ----------
IMG_SIZE      = 28
STEM_CHANS    = [32, 64]   # widened stem (still light)
STEM_DROPOUT  = 0.10
FEAT_DIM      = 256
N_QUBITS      = 8
Q_DEPTH       = 2          # shallow to avoid barren plateaus
ENT_TYPE      = "ring"     # richer than chain, still cheap
REUPLOADS     = 3
Q_MULTI_OBS   = True       # measure <Z> and <X> per qubit (QMO readout)
QRG_RESIDUAL  = True       # quantum residual gating

# ---------- Training hyperparameters ----------
BATCH_SIZE   = 128 if FAST_MODE else 256
GLOBAL_SEED  = 42
AUG_SEED     = 1337
EPOCHS_CORE  = 45          # stronger schedule for core datasets
EPOCHS_VIA   = 20
LR           = 1e-3
WEIGHT_DECAY = 5e-4
LABEL_SMOOTH = 0.05
GRAD_CLIP    = 1.0

# Data augmentation
USE_MIXUP          = True
MIXUP_ALPHA        = 0.10
WARMUP_NO_MIXUP_EP = 3
USE_CUTMIX         = False
CUTMIX_ALPHA       = 0.20

# Quantum-centric regularizers (novelty)
ORTHO_LAMBDA = 1e-3   # ||P^T P − I||_F penalty on the projector
QVAR_LAMBDA  = 1e-3   # encourage variance on quantum features

# Early stopping
USE_EARLY_STOP = True
PATIENCE       = 10
MIN_DELTA      = 1e-4

# Calibration & test-time augmentation
DO_CALIBRATION = True
DO_TTA = True
TTA_N  = 8 if torch.cuda.is_available() else 4

# Snapshot ensemble of late checkpoints
SNAPSHOT_ENSEMBLE    = True
SNAPSHOT_START_EPOCH = max(4, EPOCHS_CORE // 2)
SNAPSHOT_GAP         = 1

# Reproducibility
def set_global_seed_all(seed: int):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
set_global_seed_all(GLOBAL_SEED)
if DEVICE.type == "cpu":
    try: torch.set_num_threads(2)
    except Exception: pass

# Helpful: AMP helpers (used in Snippet 3)
from contextlib import nullcontext
def amp_context():
    if torch.cuda.is_available():
        return torch.amp.autocast(device_type="cuda")
    return nullcontext()
def make_scaler():
    return torch.amp.GradScaler(device="cuda") if torch.cuda.is_available() else None

# Environment log
env = {
    "timestamp": RUN_TS, "python": sys.version, "platform": platform.platform(),
    "torch": torch.__version__, "cuda": torch.version.cuda if torch.version.cuda else None,
    "cudnn": torch.backends.cudnn.version(),
    "device": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"),
    "pennylane": qml.__version__, "have_shap": HAVE_SHAP, "have_lime": HAVE_LIME, "have_sklearn": HAVE_SK,
    "fast_mode": FAST_MODE,
    "img_size": IMG_SIZE, "stem_chans": STEM_CHANS, "feat_dim": FEAT_DIM,
    "n_qubits": N_QUBITS, "q_depth": Q_DEPTH, "ent": ENT_TYPE, "reuploads": REUPLOADS,
    "q_multi_obs": Q_MULTI_OBS, "qrg_resid": QRG_RESIDUAL,
    "batch_size": BATCH_SIZE, "epochs_core": EPOCHS_CORE, "epochs_via": EPOCHS_VIA
}
with open(ROOT_DIR/"env.json","w") as f: json.dump(env, f, indent=2)

print("[Env]", json.dumps(env, indent=2))
print(f"[Paths] ROOT_DIR={ROOT_DIR.resolve()}  CACHE_DIR={CACHE_DIR.resolve()}")
print("[Ready:1] Setup complete.")

[Env] {
  "timestamp": "20251011_162052",
  "python": "3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]",
  "platform": "Linux-6.6.97+-x86_64-with-glibc2.35",
  "torch": "2.8.0+cu126",
  "cuda": "12.6",
  "cudnn": 91002,
  "device": "Tesla T4",
  "pennylane": "0.42.3",
  "have_shap": true,
  "have_lime": true,
  "have_sklearn": true,
  "fast_mode": false,
  "img_size": 28,
  "stem_chans": [
    32,
    64
  ],
  "feat_dim": 256,
  "n_qubits": 8,
  "q_depth": 2,
  "ent": "ring",
  "reuploads": 3,
  "q_multi_obs": true,
  "qrg_resid": true,
  "batch_size": 256,
  "epochs_core": 45,
  "epochs_via": 20
}
[Paths] ROOT_DIR=/content/qnn_hybrid_runs/20251011_162052  CACHE_DIR=/content/medmnist_cache
[Ready:1] Setup complete.


# datasets, loaders, PCA (with progress), QNN builder


In [3]:
# ===== SNIPPET 2 — Data & Hybrid Model =====
# MedMNIST datasets
try:
    import medmnist
    from medmnist import INFO
    from medmnist.dataset import PathMNIST, DermaMNIST, BloodMNIST, PneumoniaMNIST
except Exception as e:
    raise RuntimeError("Please install medmnist: pip install medmnist") from e

DS_MAP = {
    "pathmnist": PathMNIST,
    "dermamnist": DermaMNIST,
    "bloodmnist": BloodMNIST,
    "pneumoniamnist": PneumoniaMNIST,
}

# ---------- label utilities ----------
def _single_label_to_index(y):
    import numpy as _np
    if torch.is_tensor(y):
        if y.ndim==0: return int(y.item())
        if y.ndim==1: return int(y[0].item())
        if y.ndim==2: return int(y.argmax().item())
    arr=_np.asarray(y)
    if arr.ndim==0: return int(arr.item())
    if arr.ndim==1: return int(arr[0].item())
    if arr.ndim==2: return int(arr.argmax())
    return int(arr.item())

def to_index(batch_y):
    import numpy as _np
    if isinstance(batch_y, torch.Tensor):
        if batch_y.ndim==2 and batch_y.size(1)>1: return batch_y.argmax(1).long()
        return batch_y.view(-1).long()
    arr=_np.asarray(batch_y)
    if arr.ndim==2 and arr.shape[1]>1: return torch.tensor(arr.argmax(1), dtype=torch.long)
    return torch.tensor(arr.reshape(-1), dtype=torch.long)

# ---------- dataset & transforms ----------
from torchvision.transforms import ColorJitter, RandomAffine

def get_datasets(name, cache_dir: Path):
    key = name.lower()
    info = INFO[key]; n_classes = len(info["label"])

    base_tf = [transforms.ToTensor()]
    tr_tf = transforms.Compose([
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.2),
        RandomAffine(degrees=12, translate=(0.05,0.05), scale=(0.95,1.05)),
        ColorJitter(0.05,0.05,0.05,0.02),
        *base_tf
    ])
    ev_tf = transforms.Compose(base_tf)

    DS = DS_MAP[key]
    tr = DS(root=str(cache_dir), split="train", download=True, transform=tr_tf, as_rgb=True)
    va = DS(root=str(cache_dir), split="val",   download=True, transform=ev_tf,  as_rgb=True)
    te = DS(root=str(cache_dir), split="test",  download=True, transform=ev_tf,  as_rgb=True)

    # dataset-driven mean/std on train
    loader = DataLoader(tr, batch_size=512, shuffle=False, num_workers=0, pin_memory=False)
    s = torch.zeros(3); ss=torch.zeros(3); n=0
    for x,_ in tqdm(loader, desc=f"[{name}] compute mean/std", leave=False):
        bs=x.size(0); n += bs*x.shape[2]*x.shape[3]
        s  += x.sum(dim=[0,2,3]); ss += (x**2).sum(dim=[0,2,3])
    mean = (s/n).tolist()
    std  = torch.sqrt((ss/n) - torch.tensor(mean)**2).tolist()
    norm = transforms.Normalize(mean=mean, std=std)
    tr.transform = transforms.Compose([*tr.transform.transforms, norm])
    va.transform = transforms.Compose([*va.transform.transforms, norm])
    te.transform = transforms.Compose([*te.transform.transforms, norm])

    # per-class counts
    def cls_counts(ds, k):
        c = np.zeros(k, dtype=int)
        dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
        for _,y in tqdm(dl, total=len(ds), desc=f"[{name}] class counts", leave=False):
            y = y[0] if isinstance(y, torch.Tensor) and y.ndim>0 else y
            c[_single_label_to_index(y)] += 1
        return c.tolist()

    dist = {"train": cls_counts(tr, n_classes), "val": cls_counts(va, n_classes), "test": cls_counts(te, n_classes)}
    return tr, va, te, n_classes, mean, std, info, dist

def make_loaders(tr, va, te):
    def seed_worker(worker_id):
        seed = AUG_SEED + worker_id
        np.random.seed(seed); random.seed(seed)
    g = torch.Generator(); g.manual_seed(AUG_SEED)
    tr_loader = DataLoader(tr, batch_size=BATCH_SIZE, shuffle=True,
                           num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                           worker_init_fn=seed_worker, generator=g, persistent_workers=False)
    va_loader = DataLoader(va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    te_loader = DataLoader(te, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    return tr_loader, va_loader, te_loader

# ---------- CNN stem (lightweight; QNN is the star) ----------
class CNNStem(nn.Module):
    def __init__(self, chans=STEM_CHANS, feat_dim=FEAT_DIM, dropout=STEM_DROPOUT):
        super().__init__()
        c1,c2 = chans
        self.net = nn.Sequential(
            nn.Conv2d(3, c1, 3, padding=1), nn.BatchNorm2d(c1), nn.ReLU(inplace=True),
            nn.Conv2d(c1, c1, 3, padding=1), nn.BatchNorm2d(c1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2), nn.Dropout(dropout),

            nn.Conv2d(c1, c2, 3, padding=1), nn.BatchNorm2d(c2), nn.ReLU(inplace=True),
            nn.Conv2d(c2, c2, 3, padding=1), nn.BatchNorm2d(c2), nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Linear(c2, feat_dim)
    def forward(self, x):
        z = self.net(x).view(x.size(0), -1)
        return self.fc(z)

# ---------- TCQP-Orth: trainable projector to qubit angles ----------
class ProjectToQubits(nn.Module):
    def __init__(self, in_dim=FEAT_DIM, n_qubits=N_QUBITS):
        super().__init__()
        self.lin = nn.Linear(in_dim, n_qubits)
        nn.init.orthogonal_(self.lin.weight)
        nn.init.zeros_(self.lin.bias)
        self.norm = nn.LayerNorm(n_qubits)
        self.scale = nn.Parameter(torch.ones(1))  # learnable scale
    def forward(self, f):
        qin = self.lin(f)
        qin = self.norm(qin)
        qin = torch.tanh(qin) * self.scale
        return qin

def projector_orthogonality_loss(module: ProjectToQubits):
    W = module.lin.weight  # [n_qubits, in_dim]
    gram = W @ W.t()
    I = torch.eye(gram.size(0), device=gram.device)
    return F.mse_loss(gram, I)

# ---------- QNN with QMO readout (⟨Z⟩ & ⟨X⟩), reuploads, shallow entanglement ----------
def build_qnode_hybrid(n_qubits, layers, entanglement, reuploads=1, multi_obs=True):
    dev = qml.device("default.qubit", wires=n_qubits)

    def entangle():
        if entanglement=="chain":
            for i in range(n_qubits-1): qml.CNOT(wires=[i, i+1])
        elif entanglement=="ring":
            for i in range(n_qubits-1): qml.CNOT(wires=[i, i+1]); qml.CNOT(wires=[n_qubits-1, 0])
        elif entanglement=="all":
            for i in range(n_qubits):
                for j in range(i+1, n_qubits):
                    qml.CNOT(wires=[i, j])
        else:
            raise ValueError("Unknown entanglement")

    @qml.qnode(dev, interface="torch", diff_method="backprop")
    def qnode(inputs, theta_rot_shared, theta_phase_shared):
        for _ in range(reuploads):
            qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")
            for _ in range(layers):
                for i in range(n_qubits):
                    qml.RY(theta_rot_shared[i], wires=i)
                    qml.RZ(theta_phase_shared[i], wires=i)
                entangle()
        if multi_obs:
            obs = []
            for i in range(n_qubits): obs.append(qml.expval(qml.PauliZ(i)))
            for i in range(n_qubits): obs.append(qml.expval(qml.PauliX(i)))
            return obs
        else:
            return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

    feat_out = n_qubits*2 if multi_obs else n_qubits
    weight_shapes = {"theta_rot_shared": (n_qubits,), "theta_phase_shared": (n_qubits,)}
    return qnode, weight_shapes, feat_out

# ---------- Hybrid CNN→QNN model (QRG readout fusion) ----------
class HybridCnnQnn(nn.Module):
    def __init__(self, n_classes,
                 n_qubits=N_QUBITS, depth=Q_DEPTH, ent=ENT_TYPE,
                 reuploads=REUPLOADS, multi_obs=Q_MULTI_OBS):
        super().__init__()
        self.stem = CNNStem(STEM_CHANS, FEAT_DIM, STEM_DROPOUT)
        self.project = ProjectToQubits(FEAT_DIM, n_qubits)

        qnode, weight_shapes, qfeat = build_qnode_hybrid(
            n_qubits, depth, ent, reuploads=reuploads, multi_obs=multi_obs
        )
        self.q = qml.qnn.TorchLayer(qnode, weight_shapes)
        for p in self.q.parameters():
            nn.init.normal_(p, mean=0.0, std=1e-2)

        self.qfeat_dim = qfeat
        self.qrg_gate = nn.Sequential(nn.Linear(n_qubits, 1), nn.Sigmoid()) if QRG_RESIDUAL else None
        head_in = qfeat + (n_qubits if QRG_RESIDUAL else 0)
        self.head = nn.Sequential(
            nn.LayerNorm(head_in),
            nn.Linear(head_in, max(64, n_classes*8)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.10),
            nn.Linear(max(64, n_classes*8), n_classes)
        )

    def forward(self, x):
        f = self.stem(x)               # [B, FEAT_DIM]
        qin = self.project(f)          # [B, n_qubits]
        qout = self.q(qin).to(qin.dtype)  # [B, qfeat]
        if self.qrg_gate is not None:
            g = self.qrg_gate(qin)     # [B,1]
            fused = torch.cat([qout, g*qin], dim=1)
        else:
            fused = qout
        return self.head(fused)

print(f"[Ready:2] Data & Hybrid model ready. MedMNIST v{getattr(medmnist, '__version__', '?')}")

[Ready:2] Data & Hybrid model ready. MedMNIST v3.0.2


In [4]:
# ===== SNIPPET 3 — Training & Eval Utils =====
import json, math

# ---- Metrics ----
def epoch_metrics(logits: torch.Tensor, yy: torch.Tensor, n_classes: int):
    y_pred = logits.argmax(1).cpu().numpy()
    y_true = yy.cpu().numpy()
    acc = float((y_pred == y_true).mean())
    if HAVE_SK:
        from sklearn.metrics import f1_score, roc_auc_score
        f1m = float(f1_score(y_true, y_pred, average="macro"))
        probs = torch.softmax(logits, 1).detach().cpu().numpy()
        try:
            auroc = float(roc_auc_score(y_true, probs, multi_class="ovr", average="macro")) if n_classes>2 \
                    else float(roc_auc_score(y_true, probs[:,1]))
        except Exception:
            auroc = float("nan")
    else:
        f1m, auroc = float("nan"), float("nan")
    return acc, f1m, auroc

def ece_mce(probs: torch.Tensor, labels: torch.Tensor, n_bins: int = 15):
    ret = probs.max(dim=1)
    confs, preds = ret.values, ret.indices
    bins = torch.linspace(0,1,n_bins+1, device=probs.device)
    ece = torch.zeros((), device=probs.device); mce=torch.zeros((), device=probs.device)
    for i in range(n_bins):
        lo, hi = bins[i], bins[i+1]
        mask = (confs > lo) & (confs <= (hi if i<n_bins-1 else hi+1e-8))
        if mask.any():
            acc  = (preds[mask]==labels[mask]).float().mean()
            conf = confs[mask].mean()
            gap  = (conf-acc).abs()
            ece += gap * (mask.float().mean())
            mce = torch.maximum(mce, gap)
    return float(ece.item()), float(mce.item())

def plot_reliability(probs: torch.Tensor, labels: torch.Tensor, path):
    bins = torch.linspace(0,1,16, device=probs.device)
    ret = probs.max(dim=1); confs, preds = ret.values, ret.indices
    xs, ys = [], []
    for i in range(15):
        lo, hi = bins[i], bins[i+1]
        mask = (confs > lo) & (confs <= (hi if i<14 else hi+1e-8))
        if mask.any():
            xs.append(confs[mask].mean().item())
            ys.append((preds[mask]==labels[mask]).float().mean().item())
    plt.figure(figsize=(4,4)); plt.plot([0,1],[0,1], linestyle="--")
    plt.plot(xs, ys, marker="o"); plt.xlabel("Confidence"); plt.ylabel("Accuracy"); plt.title("Reliability")
    plt.tight_layout(); plt.savefig(path, dpi=300); plt.close()

class TemperatureScaler(nn.Module):
    def __init__(self): super().__init__(); self.t=nn.Parameter(torch.ones(1))
    def forward(self, logits): return logits / self.t.clamp_min(1e-4)
    def fit(self, logits: torch.Tensor, labels: torch.Tensor, lr=0.1, steps=200):
        labels = labels.long()
        opt = torch.optim.LBFGS([self.t], lr=lr, max_iter=steps)
        def closure():
            opt.zero_grad()
            loss = F.cross_entropy(self.forward(logits), labels); loss.backward()
            return loss
        opt.step(closure)

# ---- Augmentations ----
def mixup_data(x, y, alpha=0.1):
    if alpha<=0: return x,y,1.0, torch.arange(x.size(0), device=x.device)
    lam = np.random.beta(alpha, alpha); idx = torch.randperm(x.size(0), device=x.device)
    return lam*x + (1-lam)*x[idx], (y, y[idx]), lam, idx

def cutmix_data(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha); idx = torch.randperm(x.size(0), device=x.device)
    W,H = x.size(3), x.size(2)
    cut_rat = math.sqrt(1-lam); w,h=int(W*cut_rat), int(H*cut_rat)
    cx,cy = np.random.randint(W), np.random.randint(H)
    x1,y1 = np.clip(cx-w//2,0,W), np.clip(cy-h//2,0,H)
    x2,y2 = np.clip(cx+w//2,0,W), np.clip(cy+h//2,0,H)
    x_new = x.clone(); x_new[:,:,y1:y2,x1:x2] = x[idx][:,:,y1:y2,x1:x2]
    lam_adj = 1 - ((x2-x1)*(y2-y1)/(W*H))
    return x_new, (y, y[idx]), lam_adj

# ---- TTA ----
def tta_predict(model, x, n=TTA_N):
    model.eval(); out=[]
    with torch.no_grad():
        for i in range(n):
            xt = x.clone()
            if i%4==1: xt = torch.flip(xt, dims=[3])
            elif i%4==2: xt = torch.flip(xt, dims=[2])
            elif i%4==3: xt = torch.rot90(xt, k=1, dims=[2,3])
            l = model(xt); out.append(torch.softmax(l,1))
    return torch.stack(out,0).mean(0)

# ---- Speed ----
def measure_inference_ms_per_img(model, loader, warmup=2, batches=8):
    model.eval(); it=iter(loader)
    # warmup
    for _ in range(warmup):
        try: x,_ = next(it)
        except StopIteration: it=iter(loader); x,_ = next(it)
        _ = model(x.to(DEVICE))
    # timed
    total=0; t1=None; t2=None
    for _ in range(batches):
        try: x,_ = next(it)
        except StopIteration: it=iter(loader); x,_ = next(it)
        x=x.to(DEVICE)
        if torch.cuda.is_available(): torch.cuda.synchronize()
        t_start=time.time(); _=model(x)
        if torch.cuda.is_available(): torch.cuda.synchronize()
        t_end=time.time()
        if t1 is None: t1=t_start
        t2=t_end; total += x.size(0)
    elapsed=(t2-t1) if (t1 is not None and t2 is not None) else 0.0
    return float((elapsed/max(total,1))*1000.0)

# ---- Class weights helper ----
def class_weights_from_counts(counts):
    w = np.array(counts, dtype=np.float64); w = w.sum() / np.clip(w, 1, None)
    w = w / w.mean()
    return w.tolist()

# ---- Training loop (AMP, early stop, snapshots, TCQP-Orth, QVAR) ----
FREEZE_STEM_EPOCHS = 5  # let QNN learn first

def _set_requires_grad(module, flag: bool):
    for p in module.parameters(): p.requires_grad_(flag)

def train_model(n_classes, loaders, epochs, tag, class_weights=None):
    tr_loader, va_loader, _ = loaders
    model = HybridCnnQnn(n_classes).to(DEVICE)
    _set_requires_grad(model.stem, False)  # freeze CNN stem first
    opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
    # Scheduler (prefer OneCycle; fallback to Cosine)
    try:
        sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=LR, epochs=epochs, steps_per_epoch=max(1,len(tr_loader)))
    except Exception:
        try:
            sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=(epochs//2 if epochs>10 else 5))
        except TypeError:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs, eta_min=1e-5)
    scaler = make_scaler()

    cw = torch.tensor(class_weights, dtype=torch.float32, device=DEVICE) if class_weights is not None else None
    ce = lambda logits, target: F.cross_entropy(logits, target, weight=cw, label_smoothing=LABEL_SMOOTH)

    best=-1.0; best_state=None; no_improve=0
    tag = Path(tag); tag.parent.mkdir(parents=True, exist_ok=True)
    log_path = tag.parent/f"{tag.name}_train_log.csv"
    with open(log_path,"w") as f: f.write("epoch,split,loss,acc,f1,auroc\n")

    snapshots=[]
    for ep in range(1, epochs+1):
        if ep == FREEZE_STEM_EPOCHS+1:
            _set_requires_grad(model.stem, True)  # unfreeze CNN
            opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
            # reattach scheduler if the previous one needs steps_per_epoch
            try:
                sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=LR, epochs=epochs-ep+1, steps_per_epoch=max(1,len(tr_loader)))
            except Exception:
                pass

        model.train(); tot=0.0; n=0; lg=[]; yy=[]
        pbar = tqdm(tr_loader, desc=f"[{tag.name}] Train ep {ep}/{epochs}", leave=False)
        for x,y in pbar:
            x=x.to(DEVICE); y=to_index(y).to(DEVICE)
            use_mix = (ep > WARMUP_NO_MIXUP_EP) and USE_MIXUP and (not USE_CUTMIX)

            with amp_context():
                if use_mix:
                    x_aug, ypair, lam, _ = mixup_data(x, y, MIXUP_ALPHA)
                    logits = model(x_aug)
                    loss = lam*ce(logits, ypair[0]) + (1-lam)*ce(logits, ypair[1])
                elif USE_CUTMIX:
                    x_aug, ypair, lam = cutmix_data(x, y, CUTMIX_ALPHA)
                    logits = model(x_aug)
                    loss = lam*ce(logits, ypair[0]) + (1-lam)*ce(logits, ypair[1])
                else:
                    logits = model(x)
                    loss = ce(logits, y)

                # TCQP-Orth penalty
                loss += ORTHO_LAMBDA * projector_orthogonality_loss(model.project)
                # QVAR: encourage variance on quantum features (avoid collapse)
                with torch.no_grad():
                    f = model.stem(x); qin = model.project(f); qout = model.q(qin).to(qin.dtype)
                var = qout.var(dim=0).mean()
                loss += QVAR_LAMBDA * F.relu(0.05 - var)

            opt.zero_grad()
            if scaler is not None:
                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                scaler.step(opt); scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                opt.step()
            try: sched.step()
            except Exception: pass

            tot += float(loss.item()) * x.size(0); n += x.size(0)
            lg.append(logits.detach().cpu()); yy.append(y.detach().cpu())
            pbar.set_postfix({"loss": f"{loss.item():.3f}"})

        tr_logits = torch.cat(lg); tr_y = torch.cat(yy)
        tr_acc, tr_f1, tr_auc = epoch_metrics(tr_logits, tr_y, n_classes)
        with open(log_path,"a") as f: f.write(f"{ep},train,{tot/max(n,1):.6f},{tr_acc:.6f},{tr_f1:.6f},{tr_auc:.6f}\n")

        # ---- validation ----
        model.eval(); lg=[]; yy=[]; tot=0.0; n=0
        for x,y in tqdm(va_loader, desc=f"[{tag.name}] Val   ep {ep}/{epochs}", leave=False):
            x=x.to(DEVICE); y=to_index(y).to(DEVICE)
            with torch.no_grad(), amp_context():
                logits = model(x); loss = F.cross_entropy(logits, y)
            tot += float(loss.item())*x.size(0); n+=x.size(0)
            lg.append(logits.cpu()); yy.append(y.cpu())
        va_logits = torch.cat(lg); va_y = torch.cat(yy)
        va_acc, va_f1, va_auc = epoch_metrics(va_logits, va_y, n_classes)
        with open(log_path,"a") as f: f.write(f"{ep},val,{tot/max(n,1):.6f},{va_acc:.6f},{va_f1:.6f},{va_auc:.6f}\n")

        # snapshots
        if SNAPSHOT_ENSEMBLE and ep >= SNAPSHOT_START_EPOCH and ((ep-SNAPSHOT_START_EPOCH)%SNAPSHOT_GAP==0):
            p = tag.parent / f"{tag.name}_snapshot_ep{ep}.pt"
            torch.save(model.state_dict(), p); snapshots.append(str(p))

        # early stop
        if va_acc > best + MIN_DELTA:
            best=va_acc; best_state={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}; no_improve=0
        else:
            no_improve += 1
        tqdm.write(f"[{tag.name}] ep{ep:02d} val_acc={va_acc:.3f} (best {best:.3f})")
        if USE_EARLY_STOP and no_improve>=PATIENCE:
            tqdm.write(f"[{tag.name}] Early stop at epoch {ep} (patience {PATIENCE})")
            break

    if best_state: model.load_state_dict(best_state)
    torch.save(model.state_dict(), tag.parent/f"{tag.name}_best.pt")
    return model, snapshots

# ---- Test evaluation helpers ----
def evaluate_model(model, loaders, n_classes, use_tta=False):
    _, _, te_loader = loaders
    model.eval(); lg=[]; yy=[]
    for x,y in tqdm(te_loader, desc="[Eval] test", leave=False):
        x=x.to(DEVICE); y=to_index(y).to(DEVICE)
        if use_tta and DO_TTA:
            p = tta_predict(model, x, n=TTA_N); l = torch.log(p + 1e-8)
        else:
            l = model(x)
        lg.append(l.cpu()); yy.append(y.cpu())
    logits = torch.cat(lg); y = torch.cat(yy)
    acc,f1m,aucm = epoch_metrics(logits, y, n_classes)
    probs = torch.softmax(logits,1); ece,mce = ece_mce(probs, y, 15)
    return {"acc":acc,"f1":f1m,"auroc":aucm,"ece":ece,"mce":mce,"logits":logits,"y":y,"probs":probs}

@torch.no_grad()
def evaluate_snapshot_ensemble(snapshot_paths, base_model_ctor, loaders, n_classes, use_tta=False):
    _, _, te_loader = loaders
    probs_all=[]; y_all=[]
    for x,y in tqdm(te_loader, desc="[SnapshotEnsemble] test", leave=False):
        x=x.to(DEVICE); y=to_index(y).to(DEVICE)
        p_accum=None
        for ckpt in snapshot_paths:
            m = base_model_ctor().to(DEVICE)
            sd = torch.load(ckpt, map_location=DEVICE)
            m.load_state_dict(sd); m.eval()
            p = tta_predict(m, x, n=TTA_N) if (use_tta and DO_TTA) else torch.softmax(m(x),1)
            p_accum = p if p_accum is None else (p_accum + p)
        p_accum = p_accum / max(1, len(snapshot_paths))
        probs_all.append(p_accum.cpu()); y_all.append(y.cpu())
    probs = torch.cat(probs_all); y = torch.cat(y_all)
    logits = torch.log(probs + 1e-8)
    acc,f1m,aucm = epoch_metrics(logits, y, n_classes); ece,mce = ece_mce(probs, y, 15)
    return {"acc":acc,"f1":f1m,"auroc":aucm,"ece":ece,"mce":mce}

print("[Ready:3] Training & Eval utilities ready.")


[Ready:3] Training & Eval utilities ready.


# metrics, calibration, TTA, timing, training/eval with tqdm + early stopping

In [5]:
# ===== HARD RESET MATPLOTLIB + SAFE PLOTTING SHIM =====
# 1) Close any open figures and purge all matplotlib modules from sys.modules
import sys, gc
mods = [m for m in list(sys.modules.keys()) if m.startswith('matplotlib')]
for m in mods:
    try:
        del sys.modules[m]
    except KeyError:
        pass
gc.collect()

# 2) Re-import matplotlib cleanly with a non-interactive backend to avoid auto-draw hooks
import importlib
import matplotlib as mpl
mpl.use('Agg')  # must be set BEFORE importing pyplot
import matplotlib.pyplot as plt

# 3) Keep plotting simple & robust
plt.ioff()  # disable interactive draws
mpl.rcParams.update({
    'figure.dpi': 100,
    'savefig.dpi': 300,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.titlesize': 10,
    'axes.labelsize': 9,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    # avoid outlines/hatches entirely
    'patch.linewidth': 0.0,
    'hatch.color': 'none',
    'hatch.linewidth': 0.0,
})

print("[Matplotlib] fully reset with 'Agg' backend & non-interactive mode.")

# 4) Safe 'bar chart' renderer that uses a tiny image instead of plt.bar
import numpy as np
from pathlib import Path

def bars_as_image(values, labels, title, save_path, bar_rgb=(0.18,0.18,0.18)):
    values = [float(v) for v in values]
    n = max(1, len(values))
    W = 50*n + 60; H = 220
    img = np.ones((H, W, 3), dtype=np.float32)
    vmax = max(values) if values else 1.0
    for i, v in enumerate(values):
        frac = 0.0 if vmax <= 0 else max(0.0, min(1.0, v / vmax))
        x0 = 30 + 50*i; x1 = x0 + 36
        y1 = H - 35; y0 = int(y1 - frac*(H-70))
        img[y0:y1, x0:x1, :] = bar_rgb
    # draw with imshow + text (no patches/linewidths/hatches touched)
    plt.figure(figsize=(max(4.8, n*0.9), 3.2))
    plt.imshow(img); plt.axis("off"); plt.title(title)
    for i, lab in enumerate(labels):
        lx = 30 + 50*i + 18
        plt.text(lx, H-20, lab, ha="center", va="top", fontsize=8, color="black")
        plt.text(lx, H-42, f"{values[i]:.3f}", ha="center", va="bottom", fontsize=8, color="black")
    plt.tight_layout()
    save_path = Path(save_path)
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"[bars_as_image] saved → {save_path}")

# 5) Replace the cross-dataset bar helper to use our image-based renderer
def bar_cross(key, title, fname):
    # uses global all_summaries when Snippet 4 calls it
    vals = [float(s.get(key, float("nan"))) for s in all_summaries]
    labs = [s["dataset"] for s in all_summaries]
    bars_as_image(vals, labs, title, ROOT_DIR/fname)

print("[Shim] bars_as_image + bar_cross() installed (no use of plt.bar).")


[Matplotlib] fully reset with 'Agg' backend & non-interactive mode.
[Shim] bars_as_image + bar_cross() installed (no use of plt.bar).


In [7]:
# ===== SNIPPET 4 — BUG-PROOF RUNNER (no Matplotlib; all figures via PIL/Numpy) =====
import os, json, time, math, numpy as np, torch
from pathlib import Path

# --- Safe image IO (PIL first, fallback to imageio) ---
try:
    from PIL import Image, ImageDraw, ImageFont
    HAVE_PIL = True
except Exception:
    HAVE_PIL = False
try:
    import imageio.v3 as iio
    HAVE_IIO = True
except Exception:
    HAVE_IIO = False

def _to_uint8(img01):
    x = np.clip(np.asarray(img01), 0, 1)
    return (x*255.0 + 0.5).astype(np.uint8)

def save_png(arr01, path):
    path = Path(path)
    arr8 = _to_uint8(arr01)
    if HAVE_PIL:
        Image.fromarray(arr8).save(path)
    elif HAVE_IIO:
        iio.imwrite(path, arr8)
    else:
        raise RuntimeError("No image backend (Pillow or imageio) available to save images.")
    return str(path)

def _draw_text(draw, xy, text, fill=(0,0,0), size=12, anchor=None):
    if not HAVE_PIL: return
    try:
        font = ImageFont.load_default()
        draw.text(xy, text, fill=fill, font=font, anchor=anchor)
    except Exception:
        draw.text(xy, text, fill=fill)

# --- Renderers (no mpl) ---
def bars_as_image(values, labels, title, save_path, bar_rgb=(46/255, 64/255, 83/255)):
    values = [float(v) for v in values]
    n = max(1, len(values))
    W = 60 + 60*n; H = 260
    img = np.ones((H, W, 3), dtype=np.float32)  # white
    vmax = max(values) if values else 1.0
    for i, v in enumerate(values):
        frac = 0.0 if vmax <= 0 else max(0.0, min(1.0, v / vmax))
        x0 = 40 + 60*i; x1 = x0 + 40
        y1 = H-50; y0 = int(y1 - frac*(H-100))
        img[y0:y1, x0:x1, :] = bar_rgb
        # write value
        if HAVE_PIL:
            im = Image.fromarray(_to_uint8(img))
            d  = ImageDraw.Draw(im)
            _draw_text(d, (x0+20, y0-10), f"{v:.3f}", fill=(0,0,0), size=12, anchor="ms")
            _draw_text(d, (x0+20, H-35), labels[i], fill=(0,0,0), size=12, anchor="ms")
            _draw_text(d, (W//2, 16), title, fill=(0,0,0), size=12, anchor="ms")
            img = np.asarray(im).astype(np.float32)/255.0
    save_png(img, save_path)

def reliability_plot_image(probs: torch.Tensor, labels: torch.Tensor, path, n_bins=15):
    # compute bin stats
    ret = probs.max(dim=1)
    confs = ret.values.detach().cpu().numpy()
    preds = ret.indices.detach().cpu().numpy()
    y = labels.cpu().numpy()
    bins = np.linspace(0,1,n_bins+1)
    xs, ys = [], []
    for i in range(n_bins):
        lo, hi = bins[i], bins[i+1]
        m = (confs > lo) & (confs <= (hi if i<n_bins-1 else hi+1e-8))
        if m.sum() == 0: continue
        xs.append(confs[m].mean()); ys.append((preds[m]==y[m]).mean())
    # canvas
    W,H = 420, 340
    img = np.ones((H,W,3), dtype=np.float32)
    if HAVE_PIL:
        im = Image.fromarray(_to_uint8(img)); d = ImageDraw.Draw(im)
        # axes
        left, top, right, bottom = 60, 40, W-20, H-60
        # border
        d.rectangle([left, top, right, bottom], outline=(0,0,0), width=1)
        # diagonal
        d.line([left, bottom, right, top], fill=(180,0,0), width=2)
        # ticks
        for t in np.linspace(0,1,6):
            x = left + int(t*(right-left)); y = bottom - int(t*(bottom-top))
            d.line([x, bottom, x, bottom+5], fill=(0,0,0))
            d.line([left-5, y, left, y], fill=(0,0,0))
            _draw_text(d, (x, bottom+12), f"{t:.1f}", anchor="mt")
            _draw_text(d, (left-12, y), f"{t:.1f}", anchor="rm")
        # points and polyline
        pts=[]
        for x,yv in zip(xs, ys):
            X = left + int(x*(right-left)); Y = bottom - int(yv*(bottom-top))
            r=3
            d.ellipse([X-r,Y-r,X+r,Y+r], fill=(30,30,30))
            pts.append((X,Y))
        if len(pts)>1:
            d.line(pts, fill=(30,30,30), width=2)
        _draw_text(d, (W//2, 18), "Reliability Diagram", anchor="ms")
        im = im.convert("RGB"); im.save(path)
    else:
        save_png(img, path)

def confusion_matrix_image(y_true, y_pred, labels, path):
    y_true = np.asarray(y_true); y_pred = np.asarray(y_pred)
    K = len(labels)
    # compute cm
    cm = np.zeros((K,K), dtype=np.int64)
    for t,p in zip(y_true, y_pred): cm[int(t), int(p)] += 1
    cm_norm = cm / np.maximum(1, cm.sum(axis=1, keepdims=True))
    # color map (blue)
    Hc, Wc = 30*K, 30*K
    img = np.ones((Hc+60, Wc+120, 3), dtype=np.float32)
    for i in range(K):
        for j in range(K):
            v = float(cm_norm[i,j])
            col = np.array([0.8-0.6*v, 0.9-0.6*v, 1.0])  # light→blue
            y0 = 40 + 30*i; x0 = 80 + 30*j
            img[y0:y0+30, x0:x0+30, :] = col
    if HAVE_PIL:
        im = Image.fromarray(_to_uint8(img)); d = ImageDraw.Draw(im)
        # grid & labels
        for i in range(K+1):
            y = 40 + 30*i
            d.line([80, y, 80+30*K, y], fill=(0,0,0))
            x = 80 + 30*i
            d.line([x, 40, x, 40+30*K], fill=(0,0,0))
        for i,l in enumerate(labels):
            _draw_text(d, (60, 55+30*i), str(l), anchor="rm")        # true axis
            _draw_text(d, (95+30*i, 35), str(l), anchor="mb")        # pred axis
        _draw_text(d, (80+(30*K)//2, 18), "Confusion Matrix (Test)", anchor="ms")
        im.save(path)
    else:
        save_png(img, path)

def overlay_heat(img01, heat01, alpha=0.5):
    img = np.clip(img01,0,1)
    heat = np.clip(heat01,0,1)
    # simple jet-ish (red) overlay
    color = np.zeros_like(img); color[...,0] = heat  # red channel
    return np.clip((1-alpha)*img + alpha*color, 0, 1)

def make_xai_panel(img01, sal, smg, ig, occ, save_path, title=None):
    # img01 HxWx3, sal/smg/ig/occ HxW in 0..1
    H,W,_ = img01.shape
    tiles = []
    def put(t):
        tiles.append(overlay_heat(img01, t, 0.5))
    tiles.append(img01)
    put(sal); put(smg); put(ig); put(occ)
    # concat horizontally with labels
    pad = 10
    canvas = np.ones((H+40, 5*W + 6*pad, 3), dtype=np.float32)
    names = ["Image","Saliency","SmoothGrad","Integrated Gradients","Occlusion"]
    if HAVE_PIL:
        im = Image.fromarray(_to_uint8(canvas)); d = ImageDraw.Draw(im)
        x = pad
        for k,tile in enumerate(tiles):
            tile8 = _to_uint8(tile)
            im.paste(Image.fromarray(tile8), (x, 30))
            _draw_text(d, (x+W//2, 18), names[k], anchor="ms")
            x += W + pad
        if title:
            _draw_text(d, (im.size[0]//2, im.size[1]-10), title, anchor="ms")
        im.save(save_path)
    else:
        save_png(canvas, save_path)

def shap_bar_image(values, labels, title, save_path):
    bars_as_image([float(v) for v in values], labels, title, save_path)

# --- Safer single-worker loaders (avoid multiprocess finalizer warnings) ---
def make_loaders_single(tr, va, te):
    def seed_worker(worker_id):
        seed = AUG_SEED + worker_id
        np.random.seed(seed); random.seed(seed)
    g = torch.Generator(); g.manual_seed(AUG_SEED)
    tr_loader = DataLoader(tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                           pin_memory=False, worker_init_fn=seed_worker, generator=g)
    va_loader = DataLoader(va, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
    te_loader = DataLoader(te, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
    return tr_loader, va_loader, te_loader

# --- Helper (no-mpl) to dump cross-dataset comparisons later ---
def bar_cross_image(all_summaries, key, title, fname):
    labels = [s["dataset"] for s in all_summaries]
    vals   = [float(s.get(key, float("nan"))) for s in all_summaries]
    bars_as_image(vals, labels, title, ROOT_DIR/fname)

# --- Eval helper that never plots with mpl ---
def eval_split_logits(model, loader):
    model.eval(); lg=[]; yy=[]
    with torch.no_grad():
        for x,y in loader:
            x=x.to(DEVICE); y=to_index(y).to(DEVICE)
            lg.append(model(x).cpu()); yy.append(y.cpu())
    return torch.cat(lg), torch.cat(yy)

# --- MAIN per-dataset runner (all images PIL-based) ---
def run_one_dataset(name):
    ds_dir = ROOT_DIR / name; ds_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n[Core] Preparing {name} …")
    tr, va, te, K, MEAN, STD, INFO_META, DIST = get_datasets(name, CACHE_DIR)
    with open(ds_dir/"class_distribution.json","w") as f: json.dump(DIST, f, indent=2)
    with open(ds_dir/"dataset_meta.json","w") as f:
        json.dump({"dataset": name, "n_classes": K, "labels": INFO_META["label"],
                   "medmnist_version": getattr(medmnist,"__version__","?"),
                   "mean": MEAN, "std": STD, "cache_dir": str(CACHE_DIR.resolve())}, f, indent=2)

    loaders = make_loaders_single(tr, va, te)  # <— single-worker stable loaders
    cweights = class_weights_from_counts(DIST["train"])

    # ----- train or load cached -----
    main_tag = (ds_dir / "hybrid_main")
    model, snapshots = train_model(K, loaders, EPOCHS_CORE if name in CORE_DATASETS else EPOCHS_VIA,
                                   main_tag, class_weights=cweights)

    # ----- evaluate -----
    test_plain = evaluate_model(model, loaders, K, use_tta=False)
    test_tta   = evaluate_model(model, loaders, K, use_tta=True)
    print(f"[{name}] Test plain: acc={test_plain['acc']:.3f} f1={test_plain['f1']:.3f} auroc={test_plain['auroc']:.3f} ece={test_plain['ece']:.3f}")
    print(f"[{name}] Test  TTA : acc={test_tta['acc']:.3f} f1={test_tta['f1']:.3f} auroc={test_tta['auroc']:.3f} ece={test_tta['ece']:.3f}")

    # ----- calibration -----
    calib={}
    if DO_CALIBRATION:
        val_logits, val_y = eval_split_logits(model, loaders[1])
        scaler = TemperatureScaler().to(DEVICE); scaler.fit(val_logits.to(DEVICE), val_y.to(DEVICE), lr=0.1, steps=200)
        t_logits = test_plain["logits"].to(DEVICE)
        p_before = torch.softmax(t_logits,1).cpu(); p_after = torch.softmax(scaler(t_logits),1).cpu()
        e_b,m_b = ece_mce(p_before, test_plain["y"], 15); e_a,m_a = ece_mce(p_after, test_plain["y"], 15)
        calib={"T": float(scaler.t.item()), "ECE_before": e_b, "MCE_before": m_b, "ECE_after": e_a, "MCE_after": m_a}
        with open(ds_dir/"calibration.json","w") as f: json.dump(calib, f, indent=2)
        # reliability diagrams
        reliability_plot_image(p_before, test_plain["y"], ds_dir/"reliability_before.png")
        reliability_plot_image(p_after,  test_plain["y"], ds_dir/"reliability_after.png")
        print(f"[{name}] Calibrated: T={calib['T']:.3f}  ECE {e_b:.3f}→{e_a:.3f}")

    # ----- confusion matrix (image) -----
    labels_txt = [str(v) for v in INFO_META["label"].values()] if isinstance(INFO_META["label"], dict) else [str(v) for v in INFO_META["label"]]
    y_true = test_plain["y"].numpy(); y_pred = test_plain["logits"].argmax(1).numpy()
    confusion_matrix_image(y_true, y_pred, labels_txt, ds_dir/"confusion_matrix_test.png")

    # ----- speed -----
    ms_img = measure_inference_ms_per_img(model, loaders[2], warmup=2, batches=8)
    with open(ds_dir/"timing_main.json","w") as f: json.dump({"ms_per_img": ms_img}, f, indent=2)

    # ----- XAI (Saliency, SmoothGrad, IG, Occlusion) + SHAP + LIME -----
    def unnorm_local(x):
        m = torch.tensor(MEAN, device=x.device).view(1,-1,1,1)
        s = torch.tensor(STD, device=x.device).view(1,-1,1,1)
        return (x*s + m).clamp(0,1)

    def pick_correct_sample(model, loader, max_batches=30):
        model.eval()
        with torch.no_grad():
            for b,(x,y) in enumerate(loader):
                x=x.to(DEVICE); y=to_index(y).to(DEVICE)
                p = model(x).argmax(1); m = p.eq(y)
                if m.any():
                    i = int(torch.nonzero(m)[0]); return x[i:i+1], int(y[i].item()), int(p[i].item())
                if b>=max_batches: break
        return None,None,None

    x1, y_true1, y_pred1 = pick_correct_sample(model, loaders[2])
    if x1 is not None:
        def grad_saliency(model, x, t):
            xi = x.clone().detach().requires_grad_(True)
            l = model(xi); s = l[0, t]; model.zero_grad(); s.backward()
            g = xi.grad.detach().abs().max(1)[0]; g = (g - g.min())/(g.max()-g.min()+1e-8)
            return g[0].cpu().numpy()
        def smoothgrad(model, x, t, n=25, noise_std=0.1, seed=202):
            rng = torch.Generator(device=x.device).manual_seed(seed); acc = torch.zeros_like(x)
            for _ in range(n):
                noise = torch.normal(0, noise_std, size=x.shape, generator=rng, device=x.device)
                xi = (x+noise).clamp(0,1).detach().requires_grad_(True)
                l = model(xi); s = l[0, t]; model.zero_grad(); s.backward()
                acc += xi.grad.detach().abs()
            g = (acc/n).max(1)[0]; g=(g-g.min())/(g.max()-g.min()+1e-8); return g[0].cpu().numpy()
        def integrated_gradients(model, x, t, steps=32):
            baseline = torch.zeros_like(x); grads=[]
            for i in range(1, steps+1):
                xi = (baseline + i/steps*(x-baseline)).detach().requires_grad_(True)
                l = model(xi); s = l[0, t]; model.zero_grad(); s.backward()
                grads.append(xi.grad.detach())
            avg = torch.stack(grads,0).mean(0)
            ig = ((x-baseline)*avg).abs().max(1)[0]
            ig = (ig - ig.min())/(ig.max()-ig.min()+1e-8)
            return ig[0].cpu().numpy()
        def occlusion_map(model, x, t, win=6, stride=4):
            model.eval(); x0=x.clone(); H,W=x0.shape[2],x0.shape[3]
            base = torch.softmax(model(x0),1)[0, t].item()
            heat = torch.zeros((H,W))
            for yy in range(0,H-win+1,stride):
                for xx in range(0,W-win+1,stride):
                    x2=x0.clone(); x2[:,:,yy:yy+win,xx:xx+win]=0.0
                    p = torch.softmax(model(x2),1)[0, t].item()
                    heat[yy:yy+win,xx:xx+win] += (base-p)
            heat = heat.numpy(); heat=(heat-heat.min())/(heat.max()-heat.min()+1e-8)
            return heat

        img = unnorm_local(x1).detach().cpu().numpy()[0].transpose(1,2,0)
        sal = grad_saliency(model, x1, y_pred1)
        smg = smoothgrad(model, x1, y_pred1)
        ig  = integrated_gradients(model, x1, y_pred1)
        occ = occlusion_map(model, x1, y_pred1)
        make_xai_panel(img, sal, smg, ig, occ, ds_dir/"xai_panels.png", title=f"{name} (T={y_true1}, P={y_pred1})")

        # ----- SHAP on quantum inputs -----
        if HAVE_SHAP:
            try:
                model.eval()
                # background from val
                bg=[]
                for xb,_ in loaders[1]:
                    xb=xb.to(DEVICE)
                    with torch.no_grad():
                        f = model.stem(xb); qin = model.project(f)
                    bg.append(qin[:64].detach().cpu().numpy()); break
                bg = np.concatenate(bg,0)

                def prob_from_qinput(qin_np):
                    qin = torch.tensor(qin_np, dtype=torch.float32, device=DEVICE)
                    with torch.no_grad():
                        qout = model.q(qin).to(qin.dtype)
                        fused = torch.cat([qout, model.qrg_gate(qin)*qin], dim=1) if model.qrg_gate is not None else qout
                        logits = model.head(fused); p = torch.softmax(logits,1).cpu().numpy()
                    return p

                expl = shap.KernelExplainer(prob_from_qinput, bg)
                with torch.no_grad():
                    z_sample = model.project(model.stem(x1)).detach().cpu().numpy()
                shap_vals = expl.shap_values(z_sample, nsamples=200)
                sv = shap_vals[y_pred1][0] if isinstance(shap_vals, list) else shap_vals[0]
                order = np.argsort(np.abs(sv))[::-1]
                vals = list(np.abs(sv)[order]); labs = [f"q{i+1}" for i in order]
                shap_bar_image(vals, labs, f"SHAP |value| per quantum input — {name}", ds_dir/"shap_bar.png")
            except Exception as e:
                with open(ds_dir/"shap_error.txt","w") as f: f.write(str(e))

        # ----- LIME (image) -----
        if HAVE_LIME:
            try:
                from lime import lime_image
                def classifier_fn(imgs_np):
                    x = torch.tensor(imgs_np.transpose(0,3,1,2), dtype=torch.float32)
                    m = torch.tensor(MEAN).view(1,-1,1,1); s = torch.tensor(STD).view(1,-1,1,1)
                    x = (x - m)/s
                    with torch.no_grad():
                        p = torch.softmax(model(x.to(DEVICE)),1).cpu().numpy()
                    return p
                img_vis = img.copy()
                explainer = lime_image.LimeImageExplainer()
                explanation = explainer.explain_instance(img_vis, classifier_fn, top_labels=1, hide_color=0, num_samples=1000)
                temp, mask = explanation.get_image_and_mask(label=y_pred1, positive_only=True, num_features=6, hide_rest=False)
                save_png(np.clip(temp.astype(np.float32)/255.0,0,1), ds_dir/"lime.png")
            except Exception as e:
                with open(ds_dir/"lime_error.txt","w") as f: f.write(str(e))

    # ----- snapshot ensemble (metrics only; plots via image bars) -----
    snapshot_paths = sorted([str(p) for p in ds_dir.glob("hybrid_main_snapshot_ep*.pt")])
    if SNAPSHOT_ENSEMBLE and snapshot_paths:
        def ctor(): return HybridCnnQnn(K).to(DEVICE)
        ens_plain = evaluate_snapshot_ensemble(snapshot_paths, ctor, loaders, K, use_tta=False)
        ens_tta   = evaluate_snapshot_ensemble(snapshot_paths, ctor, loaders, K, use_tta=True)
        with open(ds_dir/"snapshot_ensemble_summary.json","w") as f:
            json.dump({"plain": ens_plain, "tta": ens_tta}, f, indent=2)
        bars_as_image([test_plain["acc"], ens_plain["acc"]], ["Single","SnapshotEns"],
                      f"{name}: Accuracy — Plain", ds_dir/"ensemble_vs_single_plain.png")
        bars_as_image([test_tta["acc"], ens_tta["acc"]], ["Single","SnapshotEns"],
                      f"{name}: Accuracy — TTA", ds_dir/"ensemble_vs_single_tta.png")

    # ----- summary json -----
    main_summary = {
        "dataset": name,
        "acc": test_plain["acc"], "f1": test_plain["f1"], "auroc": test_plain["auroc"], "ece": test_plain["ece"],
        "acc_tta": test_tta["acc"], "f1_tta": test_tta["f1"], "auroc_tta": test_tta["auroc"], "ece_tta": test_tta["ece"],
        **({} if not calib else {"T": calib["T"], "ECE_after": calib["ECE_after"]}),
        "ms_per_img": ms_img
    }
    with open(ds_dir/"main_summary.json","w") as f:
        json.dump({k: float(v) if isinstance(v,(int,float)) else v for k,v in main_summary.items()}, f, indent=2)
    print(f"[{name}] DONE → {ds_dir}")
    return main_summary

# ---------- Run the plan (same datasets as before) ----------
RUN_DATASETS = CORE_DATASETS + [VIABILITY_DATASET]
all_summaries=[]
for ds in RUN_DATASETS:
    s = run_one_dataset(ds); all_summaries.append(s)

with open(ROOT_DIR/"cross_dataset_summary.json","w") as f: json.dump(all_summaries, f, indent=2)

# Cross-dataset comparisons — using our image bars
bar_cross_image(all_summaries, "acc",   "Hybrid CNN→QNN — Accuracy across datasets",   "cross_acc.png")
bar_cross_image(all_summaries, "f1",    "Hybrid CNN→QNN — Macro-F1 across datasets",   "cross_f1.png")
bar_cross_image(all_summaries, "auroc", "Hybrid CNN→QNN — Macro-AUROC across datasets","cross_auroc.png")
bar_cross_image(all_summaries, "ece",   "Hybrid CNN→QNN — ECE across datasets",        "cross_ece.png")

# LaTeX table (unchanged)
rows=[]
for s in all_summaries:
    rows.append(" & ".join([s["dataset"], f"{s['acc']:.3f}", f"{s['f1']:.3f}", f"{s['auroc']:.3f}", f"{s['ece']:.3f}", f"{s['ms_per_img']:.2f}"]) + " \\\\")
table = "\\begin{tabular}{lccccc}\n\\toprule\nDataset & Acc & Macro-F1 & Macro-AUROC & ECE & ms/img \\\\\n\\midrule\n" + \
        "\n".join(rows) + "\n\\bottomrule\n\\end{tabular}\n"
with open(ROOT_DIR/"table_cross_datasets.tex","w") as f: f.write(table)

print("\n[ALL DONE] Outputs under:", ROOT_DIR.resolve())
print("Per-dataset figs (no-mpl): confusion_matrix_test.png, reliability_*.png, xai_panels.png, shap_bar.png, lime.png, ensemble_vs_single_*.png, cross_*.png")



[Core] Preparing BloodMNIST …


[BloodMNIST] compute mean/std:   0%|          | 0/24 [00:00<?, ?it/s]

[BloodMNIST] class counts:   0%|          | 0/11959 [00:00<?, ?it/s]

[BloodMNIST] class counts:   0%|          | 0/1712 [00:00<?, ?it/s]

[BloodMNIST] class counts:   0%|          | 0/3421 [00:00<?, ?it/s]

[hybrid_main] Train ep 1/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 1/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep01 val_acc=0.162 (best 0.162)


[hybrid_main] Train ep 2/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 2/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep02 val_acc=0.159 (best 0.162)


[hybrid_main] Train ep 3/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 3/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep03 val_acc=0.162 (best 0.162)


[hybrid_main] Train ep 4/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 4/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep04 val_acc=0.221 (best 0.221)


[hybrid_main] Train ep 5/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 5/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep05 val_acc=0.334 (best 0.334)


[hybrid_main] Train ep 6/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 6/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep06 val_acc=0.437 (best 0.437)


[hybrid_main] Train ep 7/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 7/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep07 val_acc=0.562 (best 0.562)


[hybrid_main] Train ep 8/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 8/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep08 val_acc=0.669 (best 0.669)


[hybrid_main] Train ep 9/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 9/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep09 val_acc=0.716 (best 0.716)


[hybrid_main] Train ep 10/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 10/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep10 val_acc=0.758 (best 0.758)


[hybrid_main] Train ep 11/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 11/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep11 val_acc=0.756 (best 0.758)


[hybrid_main] Train ep 12/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 12/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep12 val_acc=0.771 (best 0.771)


[hybrid_main] Train ep 13/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 13/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep13 val_acc=0.783 (best 0.783)


[hybrid_main] Train ep 14/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 14/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep14 val_acc=0.783 (best 0.783)


[hybrid_main] Train ep 15/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 15/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep15 val_acc=0.796 (best 0.796)


[hybrid_main] Train ep 16/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 16/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep16 val_acc=0.801 (best 0.801)


[hybrid_main] Train ep 17/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 17/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep17 val_acc=0.849 (best 0.849)


[hybrid_main] Train ep 18/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 18/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep18 val_acc=0.808 (best 0.849)


[hybrid_main] Train ep 19/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 19/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep19 val_acc=0.858 (best 0.858)


[hybrid_main] Train ep 20/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 20/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep20 val_acc=0.850 (best 0.858)


[hybrid_main] Train ep 21/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 21/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep21 val_acc=0.880 (best 0.880)


[hybrid_main] Train ep 22/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 22/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep22 val_acc=0.918 (best 0.918)


[hybrid_main] Train ep 23/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 23/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep23 val_acc=0.888 (best 0.918)


[hybrid_main] Train ep 24/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 24/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep24 val_acc=0.881 (best 0.918)


[hybrid_main] Train ep 25/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 25/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep25 val_acc=0.859 (best 0.918)


[hybrid_main] Train ep 26/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 26/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep26 val_acc=0.898 (best 0.918)


[hybrid_main] Train ep 27/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 27/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep27 val_acc=0.905 (best 0.918)


[hybrid_main] Train ep 28/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 28/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep28 val_acc=0.905 (best 0.918)


[hybrid_main] Train ep 29/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 29/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep29 val_acc=0.907 (best 0.918)


[hybrid_main] Train ep 30/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 30/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep30 val_acc=0.912 (best 0.918)


[hybrid_main] Train ep 31/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 31/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep31 val_acc=0.925 (best 0.925)


[hybrid_main] Train ep 32/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 32/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep32 val_acc=0.938 (best 0.938)


[hybrid_main] Train ep 33/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 33/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep33 val_acc=0.896 (best 0.938)


[hybrid_main] Train ep 34/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 34/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep34 val_acc=0.929 (best 0.938)


[hybrid_main] Train ep 35/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 35/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep35 val_acc=0.938 (best 0.938)


[hybrid_main] Train ep 36/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 36/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep36 val_acc=0.943 (best 0.943)


[hybrid_main] Train ep 37/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 37/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep37 val_acc=0.936 (best 0.943)


[hybrid_main] Train ep 38/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 38/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep38 val_acc=0.932 (best 0.943)


[hybrid_main] Train ep 39/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 39/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep39 val_acc=0.943 (best 0.943)


[hybrid_main] Train ep 40/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 40/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep40 val_acc=0.936 (best 0.943)


[hybrid_main] Train ep 41/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 41/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep41 val_acc=0.931 (best 0.943)


[hybrid_main] Train ep 42/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 42/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep42 val_acc=0.938 (best 0.943)


[hybrid_main] Train ep 43/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 43/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep43 val_acc=0.936 (best 0.943)


[hybrid_main] Train ep 44/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 44/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep44 val_acc=0.937 (best 0.943)


[hybrid_main] Train ep 45/45:   0%|          | 0/47 [00:00<?, ?it/s]

[hybrid_main] Val   ep 45/45:   0%|          | 0/7 [00:00<?, ?it/s]

[hybrid_main] ep45 val_acc=0.933 (best 0.943)


[Eval] test:   0%|          | 0/14 [00:00<?, ?it/s]

[Eval] test:   0%|          | 0/14 [00:00<?, ?it/s]

[BloodMNIST] Test plain: acc=0.934 f1=0.924 auroc=0.995 ece=0.087
[BloodMNIST] Test  TTA : acc=0.932 f1=0.921 auroc=0.995 ece=0.091
[BloodMNIST] Calibrated: T=0.636  ECE 0.087→0.013


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

[SnapshotEnsemble] test:   0%|          | 0/14 [00:00<?, ?it/s]

[SnapshotEnsemble] test:   0%|          | 0/14 [00:00<?, ?it/s]

[BloodMNIST] DONE → qnn_hybrid_runs/20251011_162052/BloodMNIST

[Core] Preparing PneumoniaMNIST …


100%|██████████| 4.17M/4.17M [00:06<00:00, 633kB/s]


[PneumoniaMNIST] compute mean/std:   0%|          | 0/10 [00:00<?, ?it/s]

[PneumoniaMNIST] class counts:   0%|          | 0/4708 [00:00<?, ?it/s]

[PneumoniaMNIST] class counts:   0%|          | 0/524 [00:00<?, ?it/s]

[PneumoniaMNIST] class counts:   0%|          | 0/624 [00:00<?, ?it/s]

[hybrid_main] Train ep 1/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 1/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep01 val_acc=0.355 (best 0.355)


[hybrid_main] Train ep 2/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 2/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep02 val_acc=0.700 (best 0.700)


[hybrid_main] Train ep 3/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 3/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep03 val_acc=0.845 (best 0.845)


[hybrid_main] Train ep 4/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 4/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep04 val_acc=0.828 (best 0.845)


[hybrid_main] Train ep 5/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 5/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep05 val_acc=0.836 (best 0.845)


[hybrid_main] Train ep 6/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 6/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep06 val_acc=0.851 (best 0.851)


[hybrid_main] Train ep 7/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 7/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep07 val_acc=0.891 (best 0.891)


[hybrid_main] Train ep 8/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 8/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep08 val_acc=0.895 (best 0.895)


[hybrid_main] Train ep 9/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 9/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep09 val_acc=0.918 (best 0.918)


[hybrid_main] Train ep 10/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 10/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep10 val_acc=0.903 (best 0.918)


[hybrid_main] Train ep 11/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 11/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep11 val_acc=0.929 (best 0.929)


[hybrid_main] Train ep 12/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 12/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep12 val_acc=0.933 (best 0.933)


[hybrid_main] Train ep 13/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 13/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep13 val_acc=0.912 (best 0.933)


[hybrid_main] Train ep 14/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 14/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep14 val_acc=0.752 (best 0.933)


[hybrid_main] Train ep 15/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 15/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep15 val_acc=0.920 (best 0.933)


[hybrid_main] Train ep 16/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 16/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep16 val_acc=0.643 (best 0.933)


[hybrid_main] Train ep 17/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 17/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep17 val_acc=0.941 (best 0.941)


[hybrid_main] Train ep 18/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 18/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep18 val_acc=0.950 (best 0.950)


[hybrid_main] Train ep 19/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 19/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep19 val_acc=0.948 (best 0.950)


[hybrid_main] Train ep 20/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 20/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep20 val_acc=0.935 (best 0.950)


[hybrid_main] Train ep 21/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 21/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep21 val_acc=0.926 (best 0.950)


[hybrid_main] Train ep 22/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 22/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep22 val_acc=0.838 (best 0.950)


[hybrid_main] Train ep 23/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 23/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep23 val_acc=0.817 (best 0.950)


[hybrid_main] Train ep 24/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 24/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep24 val_acc=0.933 (best 0.950)


[hybrid_main] Train ep 25/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 25/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep25 val_acc=0.943 (best 0.950)


[hybrid_main] Train ep 26/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 26/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep26 val_acc=0.842 (best 0.950)


[hybrid_main] Train ep 27/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 27/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep27 val_acc=0.782 (best 0.950)


[hybrid_main] Train ep 28/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 28/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep28 val_acc=0.964 (best 0.964)


[hybrid_main] Train ep 29/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 29/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep29 val_acc=0.903 (best 0.964)


[hybrid_main] Train ep 30/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 30/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep30 val_acc=0.823 (best 0.964)


[hybrid_main] Train ep 31/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 31/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep31 val_acc=0.735 (best 0.964)


[hybrid_main] Train ep 32/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 32/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep32 val_acc=0.830 (best 0.964)


[hybrid_main] Train ep 33/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 33/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep33 val_acc=0.933 (best 0.964)


[hybrid_main] Train ep 34/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 34/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep34 val_acc=0.933 (best 0.964)


[hybrid_main] Train ep 35/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 35/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep35 val_acc=0.958 (best 0.964)


[hybrid_main] Train ep 36/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 36/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep36 val_acc=0.952 (best 0.964)


[hybrid_main] Train ep 37/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 37/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep37 val_acc=0.937 (best 0.964)


[hybrid_main] Train ep 38/45:   0%|          | 0/19 [00:00<?, ?it/s]

[hybrid_main] Val   ep 38/45:   0%|          | 0/3 [00:00<?, ?it/s]

[hybrid_main] ep38 val_acc=0.906 (best 0.964)
[hybrid_main] Early stop at epoch 38 (patience 10)


[Eval] test:   0%|          | 0/3 [00:00<?, ?it/s]

[Eval] test:   0%|          | 0/3 [00:00<?, ?it/s]

[PneumoniaMNIST] Test plain: acc=0.825 f1=0.788 auroc=0.959 ece=0.094
[PneumoniaMNIST] Test  TTA : acc=0.845 f1=0.818 auroc=0.958 ece=0.060
[PneumoniaMNIST] Calibrated: T=0.529  ECE 0.094→0.120


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

[SnapshotEnsemble] test:   0%|          | 0/3 [00:00<?, ?it/s]

[SnapshotEnsemble] test:   0%|          | 0/3 [00:00<?, ?it/s]

[PneumoniaMNIST] DONE → qnn_hybrid_runs/20251011_162052/PneumoniaMNIST

[Core] Preparing DermaMNIST …


100%|██████████| 19.7M/19.7M [00:01<00:00, 10.6MB/s]


[DermaMNIST] compute mean/std:   0%|          | 0/14 [00:00<?, ?it/s]

[DermaMNIST] class counts:   0%|          | 0/7007 [00:00<?, ?it/s]

[DermaMNIST] class counts:   0%|          | 0/1003 [00:00<?, ?it/s]

[DermaMNIST] class counts:   0%|          | 0/2005 [00:00<?, ?it/s]

[hybrid_main] Train ep 1/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 1/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep01 val_acc=0.076 (best 0.076)


[hybrid_main] Train ep 2/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 2/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep02 val_acc=0.034 (best 0.076)


[hybrid_main] Train ep 3/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 3/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep03 val_acc=0.037 (best 0.076)


[hybrid_main] Train ep 4/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 4/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep04 val_acc=0.065 (best 0.076)


[hybrid_main] Train ep 5/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 5/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep05 val_acc=0.174 (best 0.174)


[hybrid_main] Train ep 6/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 6/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep06 val_acc=0.301 (best 0.301)


[hybrid_main] Train ep 7/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 7/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep07 val_acc=0.383 (best 0.383)


[hybrid_main] Train ep 8/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 8/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep08 val_acc=0.394 (best 0.394)


[hybrid_main] Train ep 9/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 9/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep09 val_acc=0.469 (best 0.469)


[hybrid_main] Train ep 10/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 10/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep10 val_acc=0.346 (best 0.469)


[hybrid_main] Train ep 11/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 11/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep11 val_acc=0.403 (best 0.469)


[hybrid_main] Train ep 12/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 12/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep12 val_acc=0.494 (best 0.494)


[hybrid_main] Train ep 13/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 13/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep13 val_acc=0.457 (best 0.494)


[hybrid_main] Train ep 14/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 14/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep14 val_acc=0.463 (best 0.494)


[hybrid_main] Train ep 15/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 15/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep15 val_acc=0.491 (best 0.494)


[hybrid_main] Train ep 16/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 16/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep16 val_acc=0.473 (best 0.494)


[hybrid_main] Train ep 17/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 17/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep17 val_acc=0.519 (best 0.519)


[hybrid_main] Train ep 18/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 18/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep18 val_acc=0.479 (best 0.519)


[hybrid_main] Train ep 19/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 19/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep19 val_acc=0.519 (best 0.519)


[hybrid_main] Train ep 20/20:   0%|          | 0/28 [00:00<?, ?it/s]

[hybrid_main] Val   ep 20/20:   0%|          | 0/4 [00:00<?, ?it/s]

[hybrid_main] ep20 val_acc=0.516 (best 0.519)


[Eval] test:   0%|          | 0/8 [00:00<?, ?it/s]

[Eval] test:   0%|          | 0/8 [00:00<?, ?it/s]

[DermaMNIST] Test plain: acc=0.523 f1=0.375 auroc=0.867 ece=0.197
[DermaMNIST] Test  TTA : acc=0.525 f1=0.377 auroc=0.867 ece=0.194
[DermaMNIST] Calibrated: T=0.723  ECE 0.197→0.148


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

[DermaMNIST] DONE → qnn_hybrid_runs/20251011_162052/DermaMNIST

[ALL DONE] Outputs under: /content/qnn_hybrid_runs/20251011_162052
Per-dataset figs (no-mpl): confusion_matrix_test.png, reliability_*.png, xai_panels.png, shap_bar.png, lime.png, ensemble_vs_single_*.png, cross_*.png
