In [None]:
# patch encoder training

In [None]:
# Mirror the Drive patch folder into /content/patches (the ROOT that the dataset code reads from).
# NOTE(review): this reads from /content/drive, so Google Drive must already be mounted;
# the drive.mount(...) call lives in a later cell — on a fresh runtime run that cell first.
!mkdir -p /content/patches && rsync -ah --info=progress2 "/content/drive/MyDrive/BRACS/ROIPatches/" /content/patches/


        619.26M  80%  180.68kB/s    0:55:46 (xfr#5464, ir-chk=1264/7089)

In [3]:
# Colab: GPU + libs
!nvidia-smi
!pip -q install timm==0.9.16 torchmetrics==1.4.0

from google.colab import drive
drive.mount("/content/drive")  # authenticate in the pop-up

Sat Oct 18 12:30:33 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   36C    P0             47W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [14]:
import os, math, random, json, itertools, time
from pathlib import Path
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms as T
from PIL import Image
import timm
import numpy as np

# --- Reproducibility: seed every RNG used in this notebook (python, numpy, torch CPU + CUDA) ---
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# --- Data locations (edit to your layout) ---
# Drive-hosted alternative for ROOT: Path("/content/drive/MyDrive/BRACS/ROIPatches")
ROOT = Path("/content/patches")
# CSV columns: patch_path,roi_id,split,label
SPLITS_CSV = Path("/content/drive/MyDrive/BRACS/splits.csv")

# --- Label space (3-way): Benign, Atypical, Malignant ---
CLASSES = ["B", "A", "M"]
class_to_idx = dict(zip(CLASSES, range(len(CLASSES))))

# --- Training hyper-parameters (chosen to follow the reference paper) ---
IMAGE_SIZE = 224
BATCH_PATCHES = 512        # fixed number of patches per batch (not number of ROIs)
MAX_PATCHES_PER_ROI = 30   # per-ROI cap, to stabilize and regularize
EPOCHS = 40
# NOTE(review): the original note says the paper uses 40 epochs x 500 batches, yet only
# 20 steps per epoch are configured here — confirm the shortened epoch is intentional.
STEPS_PER_EPOCH = 20
LR = 3e-2                  # SGD learning rate (0.03)
MOMENTUM = 0.75            # SGD momentum
WEIGHT_DECAY = 1e-4        # meant to skip BN/bias parameters (see the optimizer setup)
DROP_RATE = 0.1            # dropout on the penultimate features
DROP_PATH_RATE = 0.25      # stochastic depth
USE_AMP = True             # mixed precision

# --- Class balancing by oversampling ---
# The paper oversamples minority classes (Atypia x3); in the 3-class setup (B, A, M)
# only class A is repeated three times.
OVERSAMPLE_MULT = {"B": 1, "A": 3, "M": 1}

DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"


In [15]:
# Training-time augmentations kept H&E-safe: axis flips, 90° rotations, light color jitter.
def he_safe_train_transforms():
    """Build the training augmentation pipeline.

    Returns:
        A ``T.Compose`` that applies a random resized crop to ``IMAGE_SIZE``,
        horizontal/vertical flips, a random rotation by a multiple of 90 degrees,
        mild color jitter, tensor conversion and ImageNet normalization.
    """
    def _rotate_quarter_turns(img):
        # Only discrete 90° steps are used (no free-angle rotation; helps tissue realism).
        angle = random.choice([0, 90, 180, 270])
        return T.functional.rotate(img, angle)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    stages = [
        T.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.Lambda(_rotate_quarter_turns),
        T.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.02),
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(stages)

def he_safe_eval_transforms():
    """Build the deterministic evaluation pipeline.

    Returns:
        A ``T.Compose`` that resizes to 256, center-crops to ``IMAGE_SIZE``,
        converts to a tensor and applies ImageNet normalization.
    """
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    stages = [
        T.Resize(256),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor(),
        normalize,
    ]
    return T.Compose(stages)

class PatchCSV(Dataset):
    """Patch-level dataset driven by the splits CSV.

    The CSV must provide the columns ``patch_path,roi_id,split,label`` where
    ``label`` is one of 'B', 'A', 'M'. Rows belonging to another split, or whose
    label is not a key of ``class_to_idx``, are skipped. A ``patch_path`` that does
    not start with "/" is resolved against ``ROOT``.

    Args:
        csv_path: location of the splits CSV.
        split: value of the ``split`` column to keep.
        transform: optional callable applied to each loaded RGB image.
        max_per_roi: if given, every ROI keeps at most this many patches,
            chosen with ``random.sample`` (so the module-level RNG state matters).
        oversample_mult: if given, maps a class name to the number of times each
            of its samples is repeated (classes missing from the mapping count once).

    Attributes:
        samples: list of dicts with keys "path", "roi", "y".
        roi_to_indices: ROI id -> positions of its patches inside ``samples``.
    """
    def __init__(self, csv_path: Path, split: str, transform=None,
                 max_per_roi: int = None, oversample_mult: dict = None):
        import csv
        self.transform = transform

        # Read the CSV and keep only usable rows of the requested split.
        with open(csv_path, newline='') as handle:
            wanted = (
                rec for rec in csv.DictReader(handle)
                if rec["split"] == split and rec["label"] in class_to_idx
            )
            samples = [
                {
                    "path": rec["patch_path"] if rec["patch_path"].startswith("/")
                            else str(ROOT / rec["patch_path"]),
                    "roi": rec["roi_id"],
                    "y": class_to_idx[rec["label"]],
                }
                for rec in wanted
            ]

        # Optional per-ROI cap: sub-sample over-represented ROIs, keep the CSV order.
        if max_per_roi is not None:
            positions_by_roi = defaultdict(list)
            for pos, sample in enumerate(samples):
                positions_by_roi[sample["roi"]].append(pos)
            survivors = set()
            for positions in positions_by_roi.values():
                if len(positions) > max_per_roi:
                    positions = random.sample(positions, max_per_roi)
                survivors.update(positions)
            samples = [sample for pos, sample in enumerate(samples) if pos in survivors]

        # Optional class oversampling (intended for the training split only):
        # each sample is repeated back-to-back according to its class multiplier.
        if oversample_mult is not None:
            samples = [
                sample
                for sample in samples
                for _ in range(oversample_mult.get(CLASSES[sample["y"]], 1))
            ]

        # Index the final sample list by ROI.
        self.samples = samples
        self.roi_to_indices = defaultdict(list)
        for pos, sample in enumerate(self.samples):
            self.roi_to_indices[sample["roi"]].append(pos)

        self.split = split
        self.class_to_idx = class_to_idx

    def __len__(self):
        """Number of samples after capping and oversampling."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one patch and return ``(image, label as 0-dim long tensor, roi id)``."""
        record = self.samples[idx]
        image = Image.open(record["path"]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(record["y"], dtype=torch.long)
        return image, label, record["roi"]


In [16]:
class FixedPatchBatchSampler(Sampler):
    """Batch sampler yielding a fixed number of randomly drawn patch indices per batch.

    Each batch holds exactly ``batch_patches`` indices drawn uniformly from
    ``[0, len(dataset))`` with ``np.random.randint`` (i.e. with replacement), and an
    epoch consists of ``steps_per_epoch`` such batches. The number of ROIs per batch
    is not constrained, consistent with the paper's fixed patch count.
    """
    def __init__(self, dataset: Dataset, steps_per_epoch: int, batch_patches: int):
        self.N = len(dataset)
        self.steps = steps_per_epoch
        self.bs = batch_patches

    def __iter__(self):
        # One uniform draw over all patches per step.
        for _ in itertools.repeat(None, self.steps):
            drawn = np.random.randint(0, self.N, size=self.bs)
            yield drawn.tolist()

    def __len__(self):
        # Number of batches per epoch.
        return self.steps


In [17]:
class EfficientNetB0_3Way(nn.Module):
    """EfficientNet-B0 feature extractor (via timm) followed by a linear classifier.

    Args:
        drop_rate: dropout rate forwarded to ``timm.create_model``.
        drop_path_rate: stochastic-depth rate forwarded to ``timm.create_model``.
        pretrained: whether timm should load pretrained weights.
    """
    def __init__(self, drop_rate=0.1, drop_path_rate=0.25, pretrained=True):
        super().__init__()
        backbone_options = dict(
            pretrained=pretrained,
            num_classes=0,  # no timm head: the backbone returns penultimate features
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
        )
        self.backbone = timm.create_model("efficientnet_b0", **backbone_options)
        # Feature width reported by the backbone (1280 for b0).
        feature_dim = self.backbone.num_features
        self.cls = nn.Linear(feature_dim, len(CLASSES))

    def forward(self, x):
        """Return class logits of shape (batch, len(CLASSES)) for an image batch ``x``."""
        return self.cls(self.backbone(x))


In [18]:
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall, MulticlassF1Score

def make_metrics(num_classes=3):
    """Create per-class and macro-averaged precision/recall/F1 metrics on ``DEVICE``.

    Args:
        num_classes: number of classes handed to every torchmetrics metric.

    Returns:
        Dict with keys ``prec``, ``rec``, ``f1`` (``average=None``, i.e. per class)
        followed by ``macro_prec``, ``macro_rec``, ``macro_f1`` (``average="macro"``).
    """
    families = (
        ("prec", MulticlassPrecision),
        ("rec", MulticlassRecall),
        ("f1", MulticlassF1Score),
    )
    bundle = {}
    # Per-class variants first, macro variants second, to keep the key order stable.
    for key, metric_cls in families:
        bundle[key] = metric_cls(num_classes=num_classes, average=None).to(DEVICE)
    for key, metric_cls in families:
        bundle["macro_" + key] = metric_cls(num_classes=num_classes, average="macro").to(DEVICE)
    return bundle

def gmean_from_recalls(rec_np):
    """Geometric mean of per-class recalls (the paper's main metric).

    Args:
        rec_np: array of shape (C,) holding class recalls in [0, 1].

    Returns:
        The geometric mean as a Python float. Values are clipped to [1e-8, 1.0]
        first, so a zero recall yields a very small result instead of log(0).
    """
    floor, ceiling = 1e-8, 1.0
    bounded = np.clip(rec_np, floor, ceiling)
    # exp(mean(log r)) is the geometric mean of the clipped recalls.
    mean_log = np.mean(np.log(bounded))
    return float(np.exp(mean_log))


In [19]:
def sanity_check_batch(dl):
    """Pull the first batch from ``dl`` and verify basic shape/dtype expectations.

    Asserts that images are float32 with spatial size ``IMAGE_SIZE`` x ``IMAGE_SIZE``
    and that labels are a 1-D long tensor, then prints shapes, the label range and
    the mean/std of the input tensor.
    """
    images, labels, _rois = next(iter(dl))

    spatial = images.shape[-2:]
    assert spatial == (IMAGE_SIZE, IMAGE_SIZE), "Bad spatial size"
    assert images.dtype == torch.float32
    assert labels.dtype == torch.long and labels.ndim == 1

    lo, hi = labels.min().item(), labels.max().item()
    print("Batch ok:", images.shape, labels.shape, "labels in", lo, hi)
    print("Input mean/std:", images.mean().item(), images.std().item())

def tiny_overfit_test(model, dl, steps=200):
    """Debug helper: train ``model`` for a few Adam steps to see whether it can fit.

    The model is switched to train mode and its weights ARE updated, so use a
    throwaway model (or reload weights afterwards). The loader is restarted
    whenever it is exhausted. Loss and batch accuracy are printed every 50 steps.

    Args:
        model: module mapping an image batch to class logits; must own parameters.
        dl: iterable yielding ``(images, labels, rois)`` batches.
        steps: number of optimization steps to run.
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Send batches to wherever the model lives instead of relying on a notebook
    # global; Adam above already guarantees there is at least one parameter.
    device = next(model.parameters()).device
    it = iter(dl)
    for i in range(steps):
        try:
            xb, yb, _ = next(it)
        except StopIteration:
            # Loader exhausted: start another pass over it.
            it = iter(dl)
            xb, yb, _ = next(it)
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = F.cross_entropy(logits, yb)
        loss.backward()
        opt.step()
        if (i + 1) % 50 == 0:
            with torch.no_grad():
                acc = (logits.argmax(1) == yb).float().mean().item()
            print(f"[tiny overfit] step {i+1}: loss={loss.item():.4f}, acc={acc:.3f}")


In [20]:
# Train with oversampling + cap per ROI; Val/Test without
train_ds = PatchCSV(SPLITS_CSV, "train", transform=he_safe_train_transforms(),
                    max_per_roi=MAX_PATCHES_PER_ROI, oversample_mult=OVERSAMPLE_MULT)
val_ds   = PatchCSV(SPLITS_CSV, "val",   transform=he_safe_eval_transforms(),
                    max_per_roi=None, oversample_mult=None)
test_ds  = PatchCSV(SPLITS_CSV, "test",  transform=he_safe_eval_transforms(),
                    max_per_roi=None, oversample_mult=None)

print("train/val/test sizes:", len(train_ds), len(val_ds), len(test_ds))
print("class_to_idx:", train_ds.class_to_idx)

# Training: fixed number of randomly drawn patches per batch.
train_bsamp = FixedPatchBatchSampler(train_ds, STEPS_PER_EPOCH, BATCH_PATCHES)

def _sequential_batches(ds):
    """Batch sampler that visits every index of `ds` exactly once, in order."""
    return torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(ds),
        batch_size=BATCH_PATCHES,
        drop_last=False,
    )

# Evaluation must cover each patch exactly once and be repeatable, so val/test use
# sequential batches instead of the random (with-replacement) training sampler, which
# would duplicate some patches, skip others, and change the metrics from run to run.
val_bsamp  = _sequential_batches(val_ds)
test_bsamp = _sequential_batches(test_ds)

def collate(items):
    """Stack (image, label, roi) samples into (images, labels, list of roi ids)."""
    xs, ys, rs = zip(*items)
    # Labels arrive as 0-dim long tensors; stacking yields a (B,) long tensor.
    return torch.stack(xs, 0), torch.stack(ys, 0), list(rs)

train_loader = DataLoader(train_ds, batch_sampler=train_bsamp, num_workers=4, pin_memory=True, collate_fn=collate)
val_loader   = DataLoader(val_ds,   batch_sampler=val_bsamp,   num_workers=4, pin_memory=True, collate_fn=collate)
test_loader  = DataLoader(test_ds,  batch_sampler=test_bsamp,  num_workers=4, pin_memory=True, collate_fn=collate)

sanity_check_batch(train_loader)


train/val/test sizes: 46365 9347 10744
class_to_idx: {'B': 0, 'A': 1, 'M': 2}
Batch ok: torch.Size([512, 3, 224, 224]) torch.Size([512]) labels in 0 2
Input mean/std: 0.8646245002746582 0.9363856315612793


In [21]:
model = EfficientNetB0_3Way(DROP_RATE, DROP_PATH_RATE, pretrained=True).to(DEVICE)

# Parameter groups: weight decay is applied to everything except 1-D tensors,
# biases and parameters whose name contains "bn"/"norm".
decay, no_decay = [], []
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue
    name = n.lower()
    if p.ndim == 1 or n.endswith(".bias") or "bn" in name or "norm" in name:
        no_decay.append(p)
    else:
        decay.append(p)

# NOTE: the name `optim` is read as a global by run_epoch, so it is kept as-is.
optim = torch.optim.SGD(
    [{"params": decay, "weight_decay": WEIGHT_DECAY},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=LR, momentum=MOMENTUM, nesterov=False
)
# torch.cuda.amp.GradScaler(...) is deprecated (it emits a FutureWarning); the
# device-generic torch.amp.GradScaler is the replacement. Loss scaling is only
# meaningful on CUDA, so it is disabled explicitly when running on CPU.
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP and DEVICE == "cuda")

# quick optimizer sanity
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("params:", num_params, "trainable:", num_trainable)
for i, pg in enumerate(optim.param_groups):
    print(f"PG{i}: lr={pg['lr']} wd={pg['weight_decay']} #tensors={len(pg['params'])}")




  scaler = GradScaler(enabled=USE_AMP)


params: 4011391 trainable: 4011391
PG0: lr=0.03 wd=0.0001 #tensors=82
PG1: lr=0.03 wd=0.0 #tensors=131


In [23]:
from tqdm import tqdm

def run_epoch(model, loader, train=True, epoch_idx=None):
    """Run one pass over ``loader`` and report loss and recall-based metrics.

    Reads the notebook-level globals ``optim``, ``scaler``, ``USE_AMP``, ``DEVICE``
    and ``CLASSES``. In training mode every batch performs one optimizer step
    (with autocast and gradient scaling as configured); otherwise the model runs
    in eval mode under ``torch.no_grad()``.

    Args:
        model: module mapping an image batch to class logits.
        loader: iterable with ``len()`` yielding ``(images, labels, rois)`` batches.
        train: whether to update the model.
        epoch_idx: optional epoch number, used only in the progress-bar label.

    Returns:
        Dict with ``loss`` (sample-weighted mean loss), ``gmean`` (geometric mean of
        the per-class recalls), ``rec`` (per-class recall array), ``macro_rec``,
        ``time`` (seconds) and ``seen_per_class`` (label counts, in CLASSES order).
    """
    mode = "Train" if train else "Val"
    if train:
        model.train()
    else:
        model.eval()

    metrics = make_metrics(num_classes=len(CLASSES))
    total_loss, total = 0.0, 0
    cls_counter = Counter()
    n_steps = len(loader)
    start = time.time()

    # ---- progress bar ----
    pbar = tqdm(enumerate(loader), total=n_steps,
                desc=f"[{mode}] Epoch {epoch_idx if epoch_idx is not None else ''}",
                ncols=120)

    for step, (xb, yb, _) in pbar:
        xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)
        bs = xb.size(0)

        # forward + loss
        if train:
            optim.zero_grad(set_to_none=True)
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast(...).
            with torch.autocast(device_type="cuda", enabled=USE_AMP):
                logits = model(xb)
                loss = F.cross_entropy(logits, yb)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            with torch.no_grad():
                logits = model(xb)
                loss = F.cross_entropy(logits, yb)

        # metric + counters
        with torch.no_grad():
            preds = logits.argmax(1)
            metrics["rec"].update(preds, yb)
            metrics["macro_rec"].update(preds, yb)
            for c, n in zip(*torch.unique(yb, return_counts=True)):
                cls_counter[int(c.item())] += int(n.item())

        # Read the loss back from the device once per step and reuse the value.
        loss_value = loss.item()
        total_loss += loss_value * bs
        total += bs

        # ---- live logging ----
        elapsed = time.time() - start
        patches_per_s = total / max(elapsed, 1e-6)
        remaining = (n_steps - (step + 1)) * (elapsed / max(step + 1, 1))
        msg = (f"Step [{step+1:3d}/{n_steps}] "
               f"loss={loss_value:.4f} | patches/s={patches_per_s:.1f} "
               f"| elapsed={elapsed/60:.1f}m | ETA={remaining/60:.1f}m")
        pbar.set_postfix_str(msg)

    # ---- aggregate metrics ----
    rec = metrics["rec"].compute().detach().cpu().numpy()
    macro_rec = float(metrics["macro_rec"].compute().item())
    gmean = gmean_from_recalls(rec)
    avg_loss = total_loss / max(1, total)
    elapsed = time.time() - start
    counts = [cls_counter.get(i, 0) for i in range(len(CLASSES))]

    print(f"\n[{mode}] Epoch done in {elapsed/60:.2f} min | "
          f"avg_loss={avg_loss:.4f} | gmean={gmean:.4f} | "
          f"recall(B,A,M)={np.round(rec,3)} | counts={counts}")

    return dict(loss=avg_loss, gmean=gmean, rec=rec, macro_rec=macro_rec,
                time=elapsed, seen_per_class=counts)


# === Main training loop ===
# Tracks the best validation g-mean seen so far and checkpoints the model on improvement.
best = {"gmean": -1, "state": None, "epoch": -1}

ckpt_path = "/content/drive/MyDrive/BRACS/checkpoints/efficientnet_b0_3way_best.pt"
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

for epoch in range(1, EPOCHS + 1):
    tr = run_epoch(model, train_loader, train=True, epoch_idx=epoch)
    va = run_epoch(model, val_loader, train=False, epoch_idx=epoch)

    print(f"\nEpoch {epoch:02d} | train_loss={tr['loss']:.4f} | "
          f"val_gmean={va['gmean']:.4f} | "
          f"recall(B,A,M)={np.round(va['rec'],3)} | "
          f"time={tr['time']:.1f}s")

    if va["gmean"] > best["gmean"]:
        # model.state_dict() hands back references to the live tensors, which later
        # optimizer steps overwrite in place. Snapshot detached CPU copies so that
        # best["state"] really holds the weights of the best epoch.
        snapshot = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        best.update({"gmean": va["gmean"], "state": snapshot, "epoch": epoch})
        torch.save({
            "model": best["state"],
            "epoch": epoch,
            "meta": {
                "class_to_idx": class_to_idx,
                "normalize_mean": [0.485,0.456,0.406],
                "normalize_std":  [0.229,0.224,0.225],
                "image_size": IMAGE_SIZE
            }
        }, ckpt_path)
        print(f"  ↳ saved best: {ckpt_path} (g-mean={best['gmean']:.4f})")


  with autocast(enabled=USE_AMP):
[Train] Epoch 1: 100%|█| 20/20 [16:55<00:00, 50.76s/it, Step [ 20/20] loss=0.4286 | patches/s=10.1 | elapsed=16.9m | ETA



[Train] Epoch done in 16.93 min | avg_loss=0.4980 | gmean=0.4994 | recall(B,A,M)=[0.849 0.164 0.894] | counts=[4297, 945, 4998]


[Val] Epoch 1:   0%|                                                                             | 0/19 [03:14<?, ?it/s]


KeyboardInterrupt: 