In [1]:

import torch
import torch.nn as nn
import timm
from google.colab import files

# Class labels, listed in the index order of the classifier's output logits
# (the head defined below has len(CLASSES) outputs).
# NOTE(review): presumably this must match the label order used during training — confirm.
CLASSES = ["B", "A", "M"]

class EfficientNetB0_3Way(nn.Module):
    """
    EfficientNet-B0 feature extractor (`self.backbone`) followed by one linear
    layer (`self.cls`) that emits a logit per entry of CLASSES.

    Args:
      drop_rate:      forwarded to timm.create_model as `drop_rate`.
      drop_path_rate: forwarded to timm.create_model as `drop_path_rate`.
      pretrained:     forwarded to timm.create_model as `pretrained`.
    """

    def __init__(self, drop_rate=0.1, drop_path_rate=0.25, pretrained=True):
        super().__init__()
        backbone_kwargs = dict(
            pretrained=pretrained,
            num_classes=0,  # headless: the backbone returns penultimate features
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
        )
        self.backbone = timm.create_model("efficientnet_b0", **backbone_kwargs)
        # Input width of the head is whatever the backbone reports (1280 for b0)
        self.cls = nn.Linear(self.backbone.num_features, len(CLASSES))

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for a batch `x`."""
        return self.cls(self.backbone(x))



# Hard-coded path of the checkpoint file inside the Colab runtime
ckpt_path = '/content/efficientnet_b0_3way_best (3).pt'

# ---- Load checkpoint ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EfficientNetB0_3Way(pretrained=False)

# Deserialize onto CPU first; the model is moved to `device` after loading
obj = torch.load(ckpt_path, map_location="cpu")

# Accept three layouts: {"model": sd}, {"state_dict": sd}, or a bare state dict
if not isinstance(obj, dict):
    raise ValueError("Unsupported checkpoint format.")
if "model" in obj:
    sd = obj["model"]
elif "state_dict" in obj:
    sd = obj["state_dict"]
else:
    sd = obj

# Clean up common key prefixes
def clean_state_dict(sd):
    """
    Return a copy of state dict `sd` whose keys are normalised for this model.

    Applied to every key, in this order:
      1. a leading "module." is stripped (the prefix typically added when a
         model is wrapped for multi-GPU training);
      2. a leading "classifier." (older head naming) is renamed to "cls.".

    Only key *prefixes* are rewritten. A "classifier." that appears deeper in
    a key (e.g. "backbone.classifier.weight") is left alone, so parameters of
    nested sub-modules are never silently renamed.

    Args:
      sd: mapping of parameter name -> value.

    Returns:
      A new dict with the cleaned keys and the original values.
    """
    clean = {}
    for k, v in sd.items():
        # Strip the wrapper prefix first so the head rename sees the real top-level name
        if k.startswith("module."):
            k = k[len("module."):]
        if k.startswith("classifier."):
            k = "cls." + k[len("classifier."):]
        clean[k] = v
    return clean

# Normalise key names, then load non-strictly: with strict=False, mismatched keys
# do not raise, so the missing/unexpected lists printed below are the only signal
# that some weights were not loaded.
sd = clean_state_dict(sd)
missing, unexpected = model.load_state_dict(sd, strict=False)
# Move to the selected device and switch to eval mode for inference
model.to(device).eval()

print("\n================ Sanity Check ================")
print(f"Missing keys: {len(missing)}")
print(f"Unexpected keys: {len(unexpected)}")
print("First few missing keys:", missing[:10])
print("First few unexpected keys:", unexpected[:10])
print("Model loaded on:", device)
print("✅ If missing/unexpected ≈ 0–2, your model is correct and ready.")




Missing keys: 0
Unexpected keys: 0
First few missing keys: []
First few unexpected keys: []
Model loaded on: cuda
✅ If missing/unexpected ≈ 0–2, your model is correct and ready.


In [3]:
from google.colab import drive
# Mount Google Drive at /content/drive; the config paths used later in this notebook live under it
drive.mount("/content/drive")

Mounted at /content/drive


In [5]:
# ============================================================
# Baselines with Option A fix + Train artifacts caching to Drive
# ============================================================

import os, json
from pathlib import Path
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image

# -----------------------------
# CONFIG
# -----------------------------
# CSV listing every patch (columns used: patch_path, roi_id, split, label)
SPLITS_CSV   = Path("/content/drive/MyDrive/BRACS/splits.csv")
# Fallback root used when a patch_path from the CSV does not exist as written
PATCH_ROOT   = Path("/content/drive/MyDrive/BRACS/ROIPatches")
ARTIFACTS_ROOT = Path("/content/drive/MyDrive/BRACS/baseline_artifacts")  # cache root on Drive

EVAL_SPLITS  = ["test"]   # evaluation splits to run
# Label order defines the class index (B=0, A=1, M=2) used for targets and confusion matrices
CLASS_ORDER  = ["B", "A", "M"]
CLASS_TO_IDX = {c: i for i, c in enumerate(CLASS_ORDER)}
NUM_CLASSES  = len(CLASS_ORDER)
IMG_SIZE     = 224   # patches are resized to IMG_SIZE x IMG_SIZE before the model
BATCH_SIZE   = 128
NUM_WORKERS  = 2

# Cache/save behavior
# NOTE: the train-specific flags only take effect when "train" is listed in EVAL_SPLITS.
CACHE_TRAIN_ARTIFACTS   = False   # save train features/probabilities to Drive
REUSE_TRAIN_ARTIFACTS   = True   # if saved artifacts exist for train, load instead of recomputing
CACHE_OTHER_SPLITS      = False  # set True if you also want to cache val/test similarly

# Device (re-derived here so this cell does not rely on the value from the model-loading cell)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sanity: make sure the uploaded `model` exists.
# Referencing `model` raises NameError when the loading cell has not been run;
# that is converted into a clearer RuntimeError. As a side effect the call also
# puts the model in eval mode.
try:
    _ = model.eval()
except NameError as e:
    raise RuntimeError("`model` not found. Please run the upload/sanity-check cell first so `model` is defined.") from e

# -----------------------------
# UTILS: Autocast context manager (no-op on CPU)
# -----------------------------
def get_autocast():
    """
    Return a context manager to wrap the model forward pass.

    With CUDA available this is `torch.amp.autocast('cuda')` (mixed precision).
    Otherwise it is the standard-library no-op context manager, so callers can
    always write `with get_autocast(): ...`. Entering the CPU fallback yields None.
    """
    # Local import: keeps the function self-contained within this cell
    from contextlib import nullcontext

    if torch.cuda.is_available():
        return torch.amp.autocast('cuda')
    return nullcontext()

# -----------------------------
# DATA
# -----------------------------
def eval_transforms(img_size=224):
    """
    Build the deterministic evaluation preprocessing pipeline.

    Steps: resize to (img_size, img_size), convert to a tensor, then normalise
    each channel with fixed mean/std constants.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)

class PatchCSVDataset(Dataset):
    """
    Patch-level dataset backed by the rows of a CSV that belong to one split.

    Required CSV columns: patch_path, roi_id, split, label (B/A/M).
    If patch_path does not exist as-is, PATCH_ROOT/split/roi_id/filename is tried.

    Each item is a tuple (x, y, roi_id):
      x      -- the transformed RGB image
      y      -- integer class index from CLASS_TO_IDX
      roi_id -- the ROI identifier as a string

    Raises:
      AssertionError: if the CSV file or a required column is missing.
      ValueError: if a row of the selected split has a label not in CLASS_TO_IDX.
    """
    def __init__(self, csv_path: Path, split: str, transform=None):
        assert csv_path.exists(), f"Missing CSV: {csv_path}"
        df = pd.read_csv(csv_path)

        # Validate the schema BEFORE any column is used, so a malformed CSV
        # fails with the descriptive message instead of a bare KeyError.
        for col in ["patch_path", "roi_id", "split", "label"]:
            assert col in df.columns, f"Missing column: {col}"

        df = df[df["split"] == split].copy()

        # Unknown labels map to NaN; fail now rather than deep inside a
        # DataLoader worker when int(NaN) is attempted in __getitem__.
        df["y"] = df["label"].map(CLASS_TO_IDX)
        unknown = df.loc[df["y"].isna(), "label"].unique().tolist()
        if unknown:
            raise ValueError(
                f"Unknown label(s) {unknown} in split '{split}'; expected one of {list(CLASS_TO_IDX)}"
            )
        df["y"] = df["y"].astype("int64")

        self.df = df.reset_index(drop=True)
        self.transform = transform or eval_transforms(IMG_SIZE)
        self.split = split

    def __len__(self):
        """Number of patches in the selected split."""
        return len(self.df)

    def __getitem__(self, idx: int):
        """Load, transform and return patch `idx` as (x, y, roi_id)."""
        row = self.df.iloc[idx]
        p = Path(row["patch_path"])
        if not p.exists():
            # fallback: PATCH_ROOT/<split>/<roi_id>/<filename>
            p = PATCH_ROOT / str(row["split"]) / str(row["roi_id"]) / Path(row["patch_path"]).name

        # Context manager closes the underlying file handle promptly;
        # convert() returns an independent, fully loaded image.
        with Image.open(p) as raw:
            img = raw.convert("RGB")
        x = self.transform(img)
        y = int(row["y"])
        roi_id = str(row["roi_id"])
        return x, y, roi_id

def make_loader(split: str):
    """
    Build the evaluation dataset and its DataLoader for `split`.

    Returns:
      (dataset, loader) -- the loader iterates in CSV order (no shuffling).
    """
    dataset = PatchCSVDataset(SPLITS_CSV, split, transform=eval_transforms(IMG_SIZE))
    loader_opts = dict(
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return dataset, DataLoader(dataset, **loader_opts)

# -----------------------------
# COLLECTORS (shared once per split)
# -----------------------------
@torch.inference_mode()
def collect_patch_outputs_by_roi(dloader, model, use_amp=True):
    """
    Run `model` over every batch of `dloader` and group patch outputs by ROI.

    Args:
      dloader: iterable yielding (images, labels, roi_ids) batches.
      model:   module mapping an image batch to class logits.
      use_amp: when True, the forward pass runs under get_autocast() (mixed
               precision on CUDA, no-op on CPU); when False, autocast is not
               used at all.

    Returns:
      roi_probs: dict[roi_id] -> tensor (N_patches, C) on CPU (float32)
      roi_preds: dict[roi_id] -> np.array (N_patches,)
      roi_true:  dict[roi_id] -> int (label of the first patch seen for that ROI)
    """
    from contextlib import nullcontext

    by_roi_probs = defaultdict(list)
    by_roi_preds = defaultdict(list)
    roi_true = {}

    # Honour use_amp: fall back to a no-op context when AMP is disabled
    ac = get_autocast() if use_amp else nullcontext()
    dev = next(model.parameters()).device

    for xb, yb, roi_ids in tqdm(dloader, desc="Collecting patch outputs"):
        xb = xb.to(dev, non_blocking=True)
        yb = yb.numpy()
        with ac:
            logits = model(xb)
            probs = torch.softmax(logits, dim=1)
        preds = probs.argmax(dim=1).detach().cpu().numpy()
        probs = probs.detach().float().cpu()  # ensure float32 on CPU
        for i, roi in enumerate(roi_ids):
            by_roi_probs[roi].append(probs[i].unsqueeze(0))
            by_roi_preds[roi].append(int(preds[i]))
            # The ROI label is taken from its first patch only
            if roi not in roi_true:
                roi_true[roi] = int(yb[i])

    roi_probs = {k: torch.cat(v, dim=0) for k, v in by_roi_probs.items()}
    roi_preds = {k: np.array(v, dtype=np.int64) for k, v in by_roi_preds.items()}
    return roi_probs, roi_preds, roi_true

@torch.inference_mode()
def collect_penultimate_by_roi(dloader, model, use_amp=True):
    """
    Run `model.backbone` over every batch of `dloader` and group the resulting
    penultimate features by ROI.

    Args:
      dloader: iterable yielding (images, labels, roi_ids) batches.
      model:   module exposing a `.backbone` that maps images to (B, D) features.
      use_amp: when True, the forward pass runs under get_autocast() (mixed
               precision on CUDA, no-op on CPU); when False, autocast is not
               used at all.

    Returns:
      roi_feats: dict[roi_id] -> tensor (N_patches, D) on CPU (may be fp16 from AMP; cast by the consumer)
      roi_true:  dict[roi_id] -> int (label of the first patch seen for that ROI)
    """
    from contextlib import nullcontext

    by_roi_feats = defaultdict(list)
    roi_true = {}

    # Honour use_amp: fall back to a no-op context when AMP is disabled
    ac = get_autocast() if use_amp else nullcontext()
    dev = next(model.parameters()).device

    for xb, yb, roi_ids in tqdm(dloader, desc="Collecting penultimate features"):
        xb = xb.to(dev, non_blocking=True)
        yb = yb.numpy()
        with ac:
            feats = model.backbone(xb)  # (B, D)
        feats = feats.detach().cpu()   # keep dtype as produced by AMP; cast when feeding FC
        for i, roi in enumerate(roi_ids):
            by_roi_feats[roi].append(feats[i].unsqueeze(0))
            # The ROI label is taken from its first patch only
            if roi not in roi_true:
                roi_true[roi] = int(yb[i])

    roi_feats = {k: torch.cat(v, dim=0) for k, v in by_roi_feats.items()}
    return roi_feats, roi_true

# -----------------------------
# BASELINES
# -----------------------------
def baseline_patch_majority(roi_preds, roi_true, **kwargs):
    """
    ROI prediction = most frequent patch-level prediction.

    Ties are resolved in favour of the smallest class index.

    Args:
      roi_preds: dict[roi_id] -> array of per-patch predicted class indices.
      roi_true:  dict[roi_id] -> true class index.
      **kwargs:  ignored (keeps all baselines call-compatible).

    Returns:
      (y_true, y_pred) as numpy arrays, in the iteration order of `roi_preds`.
    """
    labels, winners = [], []
    for roi_id in roi_preds:
        votes = Counter(roi_preds[roi_id].tolist())
        # Highest count first; among equal counts, the lowest class index
        winner = min(votes, key=lambda cls_idx: (-votes[cls_idx], cls_idx))
        winners.append(winner)
        labels.append(roi_true[roi_id])
    return np.array(labels), np.array(winners)

def baseline_patch_mean_prob(roi_probs, roi_true, **kwargs):
    """
    ROI prediction = argmax of the patch probabilities averaged over the ROI.

    Args:
      roi_probs: dict[roi_id] -> tensor (N_patches, C) of class probabilities.
      roi_true:  dict[roi_id] -> true class index.
      **kwargs:  ignored (keeps all baselines call-compatible).

    Returns:
      (y_true, y_pred) as numpy arrays, in the iteration order of `roi_probs`.
    """
    labels, winners = [], []
    for roi_id in roi_probs:
        averaged = roi_probs[roi_id].mean(dim=0)  # (C,)
        winners.append(int(averaged.argmax().item()))
        labels.append(roi_true[roi_id])
    return np.array(labels), np.array(winners)

def baseline_roi_penultimate_mean(roi_feats, roi_true, fc_layer, **kwargs):
    """
    ROI-Classifier-Penultimate-Mean (Option A fix):
      - Average penultimate features per ROI → (1, D)
      - Cast to FC layer's device & dtype
      - Classify with trained FC layer

    Args:
      roi_feats: dict[roi_id] -> tensor (N_patches, D); may be fp16/bf16 when
                 the features were collected under autocast.
      roi_true:  dict[roi_id] -> true class index.
      fc_layer:  trained classification layer applied to the averaged feature.
      **kwargs:  ignored (keeps all baselines call-compatible).

    Returns:
      (y_true, y_pred) as numpy arrays, in the iteration order of `roi_feats`.
    """
    y_true, y_pred = [], []
    fc_device = next(fc_layer.parameters()).device
    fc_dtype  = next(fc_layer.parameters()).dtype
    # Pure inference: do not record an autograd graph through the FC layer
    with torch.no_grad():
        for roi, feats in roi_feats.items():
            # Average in float32 when features are reduced precision, so the
            # mean over many patches is not degraded by half-precision rounding.
            if feats.dtype in (torch.float16, torch.bfloat16):
                feats = feats.float()
            mu = feats.mean(dim=0, keepdim=True)                 # (1, D)
            mu = mu.to(device=fc_device, dtype=fc_dtype, non_blocking=True)
            logits = fc_layer(mu)                                # (1, C)
            pred = int(logits.argmax(dim=1).item())
            y_pred.append(pred)
            y_true.append(roi_true[roi])
    return np.array(y_true), np.array(y_pred)

# -----------------------------
# METRICS
# -----------------------------
def confusion_matrix_three(y_true, y_pred, n_classes=3):
    """
    Count (true, predicted) pairs into an n_classes x n_classes int64 matrix.

    Rows index the true class, columns the predicted class.
    """
    counts = np.zeros((n_classes, n_classes), dtype=np.int64)
    for true_idx, pred_idx in zip(y_true, y_pred):
        counts[true_idx, pred_idx] = counts[true_idx, pred_idx] + 1
    return counts

def per_class_precision_recall_f1(cm):
    """
    Per-class precision, recall and F1 from a square confusion matrix
    (rows = true class, columns = predicted class).

    Any ratio whose denominator is zero is reported as 0.0.

    Returns:
      (precision, recall, f1) -- three float arrays of length n_classes.
    """
    tp = np.diag(cm).astype(np.float64)
    predicted_total = cm.sum(axis=0).astype(np.float64)  # tp + fp per class
    actual_total = cm.sum(axis=1).astype(np.float64)     # tp + fn per class

    prec = np.divide(tp, predicted_total, out=np.zeros_like(tp), where=predicted_total > 0)
    rec = np.divide(tp, actual_total, out=np.zeros_like(tp), where=actual_total > 0)

    pr_sum = prec + rec
    f1 = np.divide(2 * prec * rec, pr_sum, out=np.zeros_like(tp), where=pr_sum > 0)
    return prec, rec, f1

def gmean_from_recalls(recalls):
    """
    Geometric mean of the per-class recalls.

    Each recall is clipped to [1e-12, 1.0] first, so a class with zero recall
    yields a tiny positive G-Mean instead of exactly zero.
    """
    clipped = np.clip(recalls, 1e-12, 1.0)
    root = 1.0 / len(clipped)
    return float(np.prod(clipped) ** root)

def make_metrics_table(cm, class_names):
    """
    Summarise a confusion matrix as a metrics table plus G-Mean.

    Returns:
      (table, gmean) where `table` is a DataFrame with one row per class and a
      final "Avg" row holding the unweighted mean of each metric, and `gmean`
      is the geometric mean of the per-class recalls.
    """
    prec, rec, f1 = per_class_precision_recall_f1(cm)

    columns = {"Class": class_names + ["Avg"]}
    for header, values in (("Precision", prec), ("Recall", rec), ("F1", f1)):
        columns[header] = list(values) + [values.mean()]

    return pd.DataFrame(columns), gmean_from_recalls(rec)

# -----------------------------
# CACHE I/O
# -----------------------------
def split_cache_dir(split: str) -> Path:
    """Return ARTIFACTS_ROOT/<split>, creating the directory (and parents) if needed."""
    cache_dir = ARTIFACTS_ROOT / split
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir

def save_artifacts(split: str, roi_probs=None, roi_preds=None, roi_true=None, roi_feats=None):
    """
    Persist the given per-ROI artifacts for `split` under its cache directory.

    Every argument that is not None is written with torch.save to its own
    file; meta.json records which of the four artifacts were written.
    """
    cache_dir = split_cache_dir(split)
    artifacts = (
        ("roi_probs", "roi_probs.pt", roi_probs),
        ("roi_preds", "roi_preds.pt", roi_preds),
        ("roi_true",  "roi_true.pt",  roi_true),
        ("roi_feats", "roi_feats.pt", roi_feats),
    )

    saved_flags = {}
    for key, filename, payload in artifacts:
        present = payload is not None
        if present:
            torch.save(payload, cache_dir / filename)
        saved_flags[key] = present

    # small index
    with open(cache_dir / "meta.json", "w") as f:
        json.dump({"saved": saved_flags}, f, indent=2)

def load_artifacts_if_available(split: str):
    """
    Load whatever cached artifacts exist for `split`.

    Returns:
      None if ARTIFACTS_ROOT/<split> does not exist; otherwise a dict with the
      keys "roi_probs", "roi_preds", "roi_true" and "roi_feats", each holding
      the loaded object or None when its file is absent.
    """
    d = ARTIFACTS_ROOT / split
    if not d.exists():
        return None

    out = {}
    for name in ("roi_probs", "roi_preds", "roi_true", "roi_feats"):
        p = d / f"{name}.pt"
        if p.exists():
            # roi_preds holds numpy arrays, which torch.load rejects when
            # weights_only=True (the default in recent PyTorch releases), so
            # full unpickling is requested explicitly.
            # SECURITY: full unpickling can execute arbitrary code — only point
            # ARTIFACTS_ROOT at caches written by this notebook's save_artifacts.
            out[name] = torch.load(p, map_location="cpu", weights_only=False)
        else:
            out[name] = None
    return out

# -----------------------------
# BASELINE REGISTRY & RUNNER
# -----------------------------
# Registry: baseline name -> (function, set of keyword-argument names it requires).
# run_baselines builds the kwargs for each function from this set of names.
BASELINES = {
    "patch_majority":   (baseline_patch_majority,     {"roi_preds", "roi_true"}),
    "patch_mean_prob":  (baseline_patch_mean_prob,    {"roi_probs", "roi_true"}),
    "roi_penult_mean":  (baseline_roi_penultimate_mean, {"roi_feats", "roi_true", "fc_layer"}),
}

def run_baselines(baseline_names=None):
    """
    Evaluate the requested baselines on every split listed in EVAL_SPLITS.

    For each split:
      - builds the patch DataLoader;
      - obtains the artifacts the selected baselines need (patch
        probabilities/predictions and/or penultimate features), loading them
        from the Drive cache for the train split when REUSE_TRAIN_ARTIFACTS is
        set and collecting anything still missing with the global `model`;
      - optionally saves the artifacts back to the cache;
      - runs each baseline, prints its confusion matrix, G-Mean and metrics table;
      - writes per-baseline CSV/NPY files under ./baseline_outputs/<split>/.
    A JSON summary of all splits is written to ./baseline_outputs/results.json.

    Args:
      baseline_names: iterable of BASELINES keys to run; None runs all of them.

    Raises:
      KeyError: if a requested name is not registered in BASELINES.
    """
    assert 'model' in globals(), "`model` (EfficientNetB0_3Way) must be defined."
    _ = model.eval()

    if baseline_names is None:
        baseline_names = list(BASELINES.keys())
    else:
        # Materialise once: the names are iterated several times below
        baseline_names = list(baseline_names)

    unknown = [n for n in baseline_names if n not in BASELINES]
    if unknown:
        raise KeyError(f"Unknown baseline(s): {unknown}. Available: {sorted(BASELINES)}")

    out_dir = Path("./baseline_outputs")
    out_dir.mkdir(parents=True, exist_ok=True)
    summary = {}

    # Only the configured evaluation splits are processed; add "train" to
    # EVAL_SPLITS if its artifacts should be computed/cached as well.
    splits_to_run = list(EVAL_SPLITS)

    for split in splits_to_run:
        print(f"\n=== Split: {split} ===")
        ds, dl = make_loader(split)
        print(f"{split} patches: {len(ds)}")

        # Decide if cache should be used / written.
        # Each flag governs only its own group of splits.
        is_train = (split == "train")
        use_cache = (is_train and REUSE_TRAIN_ARTIFACTS)
        save_cache = ((is_train and CACHE_TRAIN_ARTIFACTS) or ((not is_train) and CACHE_OTHER_SPLITS))

        roi_probs = roi_preds = roi_true_probs = None
        roi_feats = roi_true_feats = None

        # What do selected baselines need?
        need_probs = any("roi_probs" in BASELINES[n][1] or "roi_preds" in BASELINES[n][1] for n in baseline_names)
        need_feats = any("roi_feats" in BASELINES[n][1] for n in baseline_names)

        # Try to load cached artifacts (train only by default)
        loaded = load_artifacts_if_available(split) if use_cache else None
        if loaded is not None:
            print(f"Cache found for split={split}. Using cached artifacts where available.")
            if need_probs:
                roi_probs = loaded.get("roi_probs", None)
                roi_preds = loaded.get("roi_preds", None)
                roi_true_probs = loaded.get("roi_true", None)
            if need_feats:
                roi_feats = loaded.get("roi_feats", None)
                # truth dict should be same
                roi_true_feats = loaded.get("roi_true", None)

        # Collect missing artifacts
        if need_probs and (roi_probs is None or roi_preds is None or roi_true_probs is None):
            roi_probs, roi_preds, roi_true_probs = collect_patch_outputs_by_roi(dl, model, use_amp=True)
            print(f"Artifacts: roi_probs/roi_preds collected. ROIs={len(roi_true_probs)}")

        if need_feats and (roi_feats is None or roi_true_feats is None):
            roi_feats, roi_true_feats = collect_penultimate_by_roi(dl, model, use_amp=True)
            print(f"Artifacts: roi_feats collected. ROIs={len(roi_true_feats)}")

        # Sanity: unify truth dict.
        # A plain merge (probs entries win on overlap) is safe even when an ROI
        # is present in only one of the two dicts.
        if need_probs and need_feats:
            roi_true = {**roi_true_feats, **roi_true_probs}
        elif need_probs:
            roi_true = roi_true_probs
        elif need_feats:
            roi_true = roi_true_feats
        else:
            roi_true = {}

        # Save cache if asked
        if save_cache:
            print(f"Saving artifacts for split={split} to Drive: {ARTIFACTS_ROOT / split}")
            save_artifacts(split,
                           roi_probs=roi_probs if need_probs else None,
                           roi_preds=roi_preds if need_probs else None,
                           roi_true=roi_true if (need_probs or need_feats) else None,
                           roi_feats=roi_feats if need_feats else None)

        split_res = {}
        for name in baseline_names:
            fn, req = BASELINES[name]
            # Pass each baseline exactly the inputs its registry entry declares
            kwargs = {}
            if "roi_probs" in req: kwargs["roi_probs"] = roi_probs
            if "roi_preds" in req: kwargs["roi_preds"] = roi_preds
            if "roi_feats" in req: kwargs["roi_feats"] = roi_feats
            if "roi_true"  in req: kwargs["roi_true"]  = roi_true
            if "fc_layer"  in req: kwargs["fc_layer"]  = model.cls

            yt, yp = fn(**kwargs)
            cm = confusion_matrix_three(yt, yp, n_classes=NUM_CLASSES)
            table, g = make_metrics_table(cm, CLASS_ORDER)

            print(f"\n[{name}] Confusion Matrix:\n{cm}")
            print(f"[{name}] G-Mean: {g:.6f}")
            # `display` only exists inside IPython/Jupyter; fall back to print elsewhere
            try:
                display(table)
            except NameError:
                print(table)

            # Save CSV + raw preds
            split_out = out_dir / split
            split_out.mkdir(parents=True, exist_ok=True)
            table.to_csv(split_out / f"{split}__{name}__metrics.csv", index=False)
            np.save(split_out / f"{split}__{name}__ytrue.npy", yt)
            np.save(split_out / f"{split}__{name}__ypred.npy", yp)

            split_res[name] = {
                "confusion_matrix": cm.tolist(),
                "gmean": g,
                "metrics_table": table.to_dict(orient="list"),
            }

        summary[split] = split_res

    with open(out_dir / "results.json", "w") as f:
        json.dump(summary, f, indent=2)

    print("\nSaved outputs to:", out_dir.resolve())


# -----------------------------
# EVALUATION + SAVE
# -----------------------------
def run_three_baselines():
    """Run the patch-majority, patch-mean-probability and ROI penultimate-mean baselines."""
    selected = ["patch_majority", "patch_mean_prob", "roi_penult_mean"]
    run_baselines(baseline_names=selected)

# Run the evaluation: prints metrics per baseline and writes results under ./baseline_outputs
run_three_baselines()



=== Split: test ===
test patches: 10744


Collecting patch outputs: 100%|██████████| 84/84 [54:23<00:00, 38.85s/it]


Artifacts: roi_probs/roi_preds collected. ROIs=350


Collecting penultimate features: 100%|██████████| 84/84 [00:34<00:00,  2.43it/s]

Artifacts: roi_feats collected. ROIs=350

[patch_majority] Confusion Matrix:
[[139  10  14]
 [ 24  39   9]
 [  5  16  94]]
[patch_majority] G-Mean: 0.722764





Unnamed: 0,Class,Precision,Recall,F1
0,B,0.827381,0.852761,0.839879
1,A,0.6,0.541667,0.569343
2,M,0.803419,0.817391,0.810345
3,Avg,0.7436,0.737273,0.739856



[patch_mean_prob] Confusion Matrix:
[[138  11  14]
 [ 21  40  11]
 [  5  13  97]]
[patch_mean_prob] G-Mean: 0.734792


Unnamed: 0,Class,Precision,Recall,F1
0,B,0.841463,0.846626,0.844037
1,A,0.625,0.555556,0.588235
2,M,0.795082,0.843478,0.818565
3,Avg,0.753848,0.748553,0.750279



[roi_penult_mean] Confusion Matrix:
[[138  11  14]
 [ 22  39  11]
 [  6  13  96]]
[roi_penult_mean] G-Mean: 0.726104


Unnamed: 0,Class,Precision,Recall,F1
0,B,0.831325,0.846626,0.838906
1,A,0.619048,0.541667,0.577778
2,M,0.793388,0.834783,0.813559
3,Avg,0.74792,0.741025,0.743414



Saved outputs to: /content/baseline_outputs
