In [None]:
# patchrel training (without ROI augmentation)


In [1]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [3]:
"""
PatchRel (two-branch patch classifier with ROI-level relevance aggregation)
First steps — NO ROIAugment yet (flag is wired but disabled by default).
Directory layout (expected):
  MyDrive/BRACS/ROIPatches/{split}/{lesion}/{ROI_name}/*.png
Where {lesion} ∈ {N, PB, UDH, FEA, ADH, DCIS, IC} (case-insensitive tolerated).
We map these 7 to 3 superclasses: B (Benign), A (Atypical), M (Malignant).

This file provides:
  - Config dataclass
  - Lesion→Superclass mapping utilities
  - ROI Patch Dataset with clean sanity checks
  - FixedPatchBatchSampler (caps patches/ROI and total patches/batch)
  - Collate fn with optional UNION augmentation hook (disabled)
  - PatchRel model (EfficientNet-B0 backbone via timm)
  - One quick dry-run sanity function

Requirements: timm, torch, torchvision, pillow
"""
from __future__ import annotations
import os
import re
import math
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Sequence

import time
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler
from PIL import Image

try:
    import timm  # type: ignore
except Exception as e:
    raise RuntimeError("Please install timm: pip install timm") from e

# -------------------------------
# Config
# -------------------------------
@dataclass
class Config:
    drive_root: Path = Path("/content/drive/MyDrive/BRACS/ROIPatches")
    split: str = "train"  # {train,val,test}
    image_size: int = 224
    normalize_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    normalize_std: Tuple[float, float, float] = (0.229, 0.224, 0.225)

    # I/O speed knobs
    fast_io_cap: Optional[int] = None  # if set, __getitem__ will load at most this many patches per ROI

    # batching
    batch_total_patches: int = 512
    max_patches_per_roi: int = 30

    # dataloader
    num_workers: int = 2
    pin_memory: bool = True
    persistent_workers: bool = False

    # oversampling of superclasses at ROI level
    oversample_multiplier: Dict[str, int] = None  # set in __post_init__

    # augmentation flags (ROIAugment not used here yet)
    enable_roi_augment: bool = False
    roi_augment_prob: float = 0.25

    def __post_init__(self):
        if self.oversample_multiplier is None:

            self.oversample_multiplier = {"B": 1, "A": 1, "M": 1}


# -------------------------------
# Label mapping utilities
# -------------------------------
SUPERCLASSES = ["B", "A", "M"]
SUPERCLASS_TO_IDX = {c: i for i, c in enumerate(SUPERCLASSES)}

# Map 7 lesion folders into 3 superclasses
LESION_TO_SUPER = {
    "N": "B", "PB": "B", "UDH": "B",
    "FEA": "A", "ADH": "A",
    "DCIS": "M", "IC": "M",
}


def infer_superclass_from_lesion_folder(name: str) -> str:
    """Robustly infer superclass from a lesion folder name.
    Accepts names like 'N', 'PB', 'UDH', 'FEA', 'ADH', 'DCIS', 'IC', and tolerant variants
    such as '0_N', '1-PB', 'Type_UDH', etc. Strategy: split by non-letters and look for a
    token that matches one of the 7 canonical lesion codes.
    """
    tokens = [t for t in re.split(r"[^A-Za-z]+", name.upper()) if t]
    for t in tokens:
        if t in LESION_TO_SUPER:
            return LESION_TO_SUPER[t]
    raise ValueError(
        f"Unrecognized lesion folder '{name}'. Expected one of {list(LESION_TO_SUPER)} or tolerant variants like '0_N'."
    )


# -------------------------------
# Light transforms (TorchVision-free to keep deps small)
# -------------------------------
class SimpleResizeNormalize:
    def __init__(self, size: int, mean: Tuple[float, float, float], std: Tuple[float, float, float]):
        self.size = size
        self.mean = torch.tensor(mean).view(3, 1, 1)
        self.std = torch.tensor(std).view(3, 1, 1)

    def __call__(self, pil_img: Image.Image) -> torch.Tensor:
        # Fast, warning-free: use numpy instead of Torch ByteStorage path
        img = pil_img.convert("RGB").resize((self.size, self.size), resample=Image.BILINEAR)
        import numpy as np
        arr = np.asarray(img, dtype=np.float32) / 255.0  # [H,W,3]
        t = torch.from_numpy(arr).permute(2, 0, 1)       # [3,H,W]
        return (t - self.mean) / self.std


# -------------------------------
# ROI Patch Dataset
# -------------------------------
COORD_PAT = re.compile(r"(?:x|X)(-?\d+)[^\d]+(?:y|Y)(-?\d+)")


class ROIPatchDataset(Dataset):
    """Walks the Drive directory and builds an index of (roi_dir, class_idx, list_of_patch_paths).
       Each __getitem__ returns the patches for a *single ROI* lazily (the sampler will pick ROI ids).
    """
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.root = cfg.drive_root / cfg.split
        if not self.root.exists():
            raise FileNotFoundError(f"Split folder not found: {self.root}")
        self.transform = SimpleResizeNormalize(cfg.image_size, cfg.normalize_mean, cfg.normalize_std)

        self.roi_items: List[Tuple[Path, int, List[Path]]] = []  # (roi_dir, superclass_idx, patch_list)
        self._build_index()

        # Class index list for oversampling decisions
        self.roi_super: List[int] = [it[1] for it in self.roi_items]

    def _build_index(self):
        lesion_folders = [p for p in self.root.iterdir() if p.is_dir()]
        total_patches = 0
        for lf in sorted(lesion_folders):
            super_c = infer_superclass_from_lesion_folder(lf.name)
            c_idx = SUPERCLASS_TO_IDX[super_c]
            for roi_dir in sorted([d for d in lf.iterdir() if d.is_dir()]):
                patches = sorted([p for p in roi_dir.glob("*.png")])
                if len(patches) < 5:
                    continue  # drop tiny ROIs as per paper
                self.roi_items.append((roi_dir, c_idx, patches))
                total_patches += len(patches)

        # --- Sanity checks ---
        assert len(self.roi_items) > 0, f"No ROIs found under {self.root}"
        counts = {"B": 0, "A": 0, "M": 0}
        for _, c_idx, _ in self.roi_items:
            counts[SUPERCLASSES[c_idx]] += 1
        print(f"[Index] Split={self.cfg.split} ROIs={len(self.roi_items)} per-class={counts} (patches~{total_patches})")

    def __len__(self) -> int:
        return len(self.roi_items)

    def __getitem__(self, idx: int):
        roi_dir, c_idx, patch_paths = self.roi_items[idx]
        # Optional early cap at dataset level to avoid loading all files (helps with slow I/O)
        load_paths = patch_paths
        if self.cfg.fast_io_cap is not None and len(load_paths) > self.cfg.fast_io_cap:
            load_paths = load_paths[: self.cfg.fast_io_cap]
        images: List[torch.Tensor] = []
        coords: List[Tuple[int, int]] = []
        for p in load_paths:
            try:
                img = Image.open(p)
            except Exception as e:
                raise RuntimeError(f"Failed to open patch: {p}") from e
            images.append(self.transform(img))
            m = COORD_PAT.search(p.name)
            if m:
                coords.append((int(m.group(1)), int(m.group(2))))
            else:
                coords.append((0, 0))
        images_t = torch.stack(images, dim=0)  # [L, 3, H, W]
        coords_t = torch.tensor(coords, dtype=torch.long)  # [L, 2]
        target = torch.tensor(c_idx, dtype=torch.long)
        return {
            "roi_id": idx,
            "roi_dir": str(roi_dir),
            "images": images_t,
            "coords": coords_t,
            "target": target,
        }


# -------------------------------
# Sampler + Collate
# -------------------------------
class FixedPatchBatchSampler(Sampler[List[int]]):
    """Yields lists of ROI indices so that the *total* patches in the subsequent collate
    will be ≈ cfg.batch_total_patches, with per-ROI cap cfg.max_patches_per_roi.
    Implements simple oversampling by repeating ROI indices based on superclass multipliers.
    """
    def __init__(self, dataset: ROIPatchDataset, cfg: Config, steps_per_epoch: int = 500):
        self.ds = dataset
        self.cfg = cfg
        self.steps_per_epoch = steps_per_epoch
        self.roi_indices_by_super = {0: [], 1: [], 2: []}
        for i, c in enumerate(self.ds.roi_super):
            self.roi_indices_by_super[c].append(i)
        # Simple safety
        for k, v in self.roi_indices_by_super.items():
            if len(v) == 0:
                print(f"[WARN] No ROIs for class {SUPERCLASSES[k]} in split {cfg.split}")

    def __len__(self):
        return self.steps_per_epoch

    def __iter__(self):
        cfg = self.cfg
        while True:
            # Build a pool with oversampling
            pool: List[int] = []
            for super_c, idxs in self.roi_indices_by_super.items():
                mult = cfg.oversample_multiplier[SUPERCLASSES[super_c]]
                if len(idxs) > 0 and mult > 0:
                    pool.extend(random.choices(idxs, k=max(mult * 4, mult)))
            if not pool:  # fallback
                pool = list(range(len(self.ds)))
            random.shuffle(pool)

            batch_roi_idxs: List[int] = []
            patch_budget = cfg.batch_total_patches
            for roi_idx in pool:
                L = self._roi_length_capped(roi_idx)
                if L <= 0:
                    continue
                if L > patch_budget:
                    if len(batch_roi_idxs) == 0:  # ensure we always yield something
                        batch_roi_idxs.append(roi_idx)
                    break
                batch_roi_idxs.append(roi_idx)
                patch_budget -= L
                if patch_budget <= 0:
                    break

            yield batch_roi_idxs

    def _roi_length_capped(self, roi_idx: int) -> int:
        L = self.ds.roi_items[roi_idx][2]
        return min(len(L), self.cfg.max_patches_per_roi)


def collate_roi_batch(samples: List[dict], cfg: Config) -> dict:
    """Convert a list of ROI samples into a flat patch batch with grouping information.
       ROIAugment hook is wired but disabled by default.
    """
    # Optionally, ROIAugment would go here (disabled now)

    images_list: List[torch.Tensor] = []
    coords_list: List[torch.Tensor] = []
    roi_ids: List[int] = []  # per-patch roi group id within the batch
    targets: List[int] = []  # per-ROI class index (single-label for now)

    group_id = 0
    for s in samples:
        img = s["images"]
        coords = s["coords"]
        # Cap per ROI
        if img.size(0) > cfg.max_patches_per_roi:
            # deterministic crop for reproducibility (can switch to random later)
            img = img[: cfg.max_patches_per_roi]
            coords = coords[: cfg.max_patches_per_roi]
        L = img.size(0)
        images_list.append(img)
        coords_list.append(coords)
        roi_ids.extend([group_id] * L)
        targets.append(int(s["target"]))
        group_id += 1

    images = torch.cat(images_list, dim=0)  # [P,3,H,W]
    coords = torch.cat(coords_list, dim=0)  # [P,2]
    roi_ids_t = torch.tensor(roi_ids, dtype=torch.long)  # [P]
    roi_targets = torch.tensor(targets, dtype=torch.long)  # [R]

    # --- Sanity checks ---
    assert images.ndim == 4 and images.size(1) == 3
    assert coords.size(0) == images.size(0)
    assert roi_ids_t.size(0) == images.size(0)
    assert roi_targets.size(0) == group_id

    return {
        "images": images,
        "coords": coords,
        "roi_ids": roi_ids_t,
        "roi_targets": roi_targets,  # int labels 0..2 (B,A,M)
    }


# -------------------------------
# PatchRel Model (no ROI augmentation required here)
# -------------------------------
class PatchRel(nn.Module):
    def __init__(self, num_classes: int = 3, drop_rate: float = 0.1, drop_path_rate: float = 0.25, pretrained: bool = True):
        super().__init__()
        self.backbone = timm.create_model(
            "efficientnet_b0", pretrained=pretrained, num_classes=0, drop_rate=drop_rate, drop_path_rate=drop_path_rate
        )
        in_feats = getattr(self.backbone, "num_features", 1280)
        self.classifier = nn.Linear(in_feats, num_classes)  # ω branch: per-patch class logits
        self.relevance = nn.Linear(in_feats, num_classes)   # β branch: per-class, per-patch relevance logits

    @staticmethod
    def _group_softmax(logits: torch.Tensor, roi_ids: torch.Tensor, num_rois: int) -> torch.Tensor:
        """
        Softmax over patches within each ROI, per class.
        logits: [P, C], roi_ids: [P] in {0..R-1} -> alpha: [P, C] with sum_j alpha_{ij}^c = 1.
        """
        P, C = logits.shape

        # max per (ROI, class) in original dtype (fp16 under autocast)
        max_per_roi_class = torch.full((num_rois, C), -float("inf"),
                                      device=logits.device, dtype=logits.dtype)
        max_per_roi_class.index_reduce_(0, roi_ids, logits, reduce="amax")  # [R, C]
        centered = logits - max_per_roi_class[roi_ids]  # [P, C]

        # --- do exp + accumulation in float32 to avoid dtype clash & improve stability ---
        centered32 = centered.float()
        exp32 = centered32.exp()  # [P, C] float32

        sum_per_roi_class32 = torch.zeros((num_rois, C), device=logits.device, dtype=torch.float32)
        sum_per_roi_class32.index_add_(0, roi_ids, exp32)  # [R, C] float32

        alpha32 = exp32 / (sum_per_roi_class32[roi_ids] + 1e-12)  # [P, C] float32
        # -------------------------------------------------------------------------------

        return alpha32.to(logits.dtype)  # cast α back to fp16 under autocast


    def forward(self, images: torch.Tensor, roi_ids: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        images: [P,3,H,W], roi_ids: [P] (0..R-1)
        Returns:
          roi_scores: [R, C] (pre-sigmoid) — relevance-weighted sum of per-patch class logits→prob
          aux dict: with patch-level outputs and α
        """
        feats = self.backbone(images)  # [P, D]
        patch_logits = self.classifier(feats)  # [P, C]
        patch_probs = patch_logits.softmax(dim=1)  # per-patch class likelihoods p(ŷ=c|x)

        rel_logits = self.relevance(feats)  # [P, C]
        num_rois = int(roi_ids.max().item()) + 1 if roi_ids.numel() > 0 else 0
        alpha = self._group_softmax(rel_logits, roi_ids, num_rois)  # [P, C]

        # Aggregate to ROI level: p(y^c|R) = sum_j α_jc * p(ŷ_j=c|x_j)
        weighted = alpha * patch_probs  # [P, C]
        roi_scores = torch.zeros((num_rois, patch_logits.size(1)), device=images.device, dtype=weighted.dtype)
        roi_scores.index_add_(0, roi_ids, weighted)  # [R, C]

        # For loss we prefer logits; convert ROI probs→logits with clamp
        roi_probs = roi_scores.clamp(1e-6, 1 - 1e-6)
        roi_logits = torch.log(roi_probs) - torch.log1p(-roi_probs)

        aux = {
            "patch_logits": patch_logits,
            "patch_probs": patch_probs,
            "rel_logits": rel_logits,
            "alpha": alpha,
            "roi_probs": roi_probs,
        }
        return roi_logits, aux


# -------------------------------
# Loss & simple metrics
# -------------------------------
class ROIBCELoss(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.num_classes = num_classes

    def forward(self, roi_logits: torch.Tensor, roi_targets: torch.Tensor) -> torch.Tensor:
        # roi_targets are single-label ints 0..C-1 → convert to one-hot
        one_hot = torch.zeros((roi_targets.size(0), self.num_classes), device=roi_logits.device, dtype=roi_logits.dtype)
        one_hot.scatter_(1, roi_targets.view(-1, 1), 1.0)
        return self.bce(roi_logits, one_hot)


def gmean_from_confusion(conf: torch.Tensor) -> float:
    """conf: [C,C] where conf[t,p] counts ROIs of true t predicted p.
       Returns geometric mean of per-class recalls.
    """
    with torch.no_grad():
        recalls = []
        for c in range(conf.size(0)):
            denom = conf[c].sum().item()
            rec = (conf[c, c].item() / denom) if denom > 0 else 0.0
            recalls.append(max(rec, 1e-12))
        prod = 1.0
        for r in recalls:
            prod *= r
        return float(prod ** (1.0 / len(recalls)))


# -------------------------------
# Dry-run sanity
# -------------------------------
@torch.no_grad()
def dry_run_sanity(cfg: Config):
    print("[Sanity] Building dataset and one batch...")
    ds = ROIPatchDataset(cfg)
    # Choose steps so that DataLoader yields 1 batch quickly
    sampler = FixedPatchBatchSampler(ds, cfg, steps_per_epoch=1)
    def _collate(samples):
        return collate_roi_batch(samples, cfg)
    loader = DataLoader(
        ds,
        batch_sampler=sampler,
        collate_fn=_collate,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        persistent_workers=cfg.persistent_workers,
        prefetch_factor=2 if cfg.num_workers > 0 else None,
    )

    batch = next(iter(loader))
    P = batch["images"].size(0)
    R = batch["roi_targets"].size(0)
    print(f"[Sanity] Batch patches={P}, ROIs={R}")

    # Model forward on CPU for shape checks (switch to CUDA in training script)
    model = PatchRel(num_classes=3)
    model.eval()
    roi_logits, aux = model(batch["images"], batch["roi_ids"])  # [R,C]

    # Check α sums ≈ 1 per (ROI,class)
    alpha = aux["alpha"]  # [P,C]
    sums = torch.zeros((R, alpha.size(1)))
    sums.index_add_(0, batch["roi_ids"], alpha.cpu())
    max_dev = (sums - 1.0).abs().max().item()
    print(f"[Sanity] max |∑_j α_jc − 1| across ROI×class = {max_dev:.3e}")

    # Loss shape check
    crit = ROIBCELoss(num_classes=3)
    loss = crit(roi_logits, batch["roi_targets"])
    print(f"[Sanity] BCE loss (random init) = {float(loss):.4f}")


# -------------------------------
# Validation sampler (covers entire split deterministically)
# -------------------------------
class ValPatchBatchSampler(Sampler[List[int]]):
    """Sequentially iterate over all ROIs, greedily packing them into batches
    under the same patch budget and per-ROI cap. Deterministic; no oversampling."""
    def __init__(self, dataset: ROIPatchDataset, cfg: Config):
        self.ds = dataset
        self.cfg = cfg
        self.order = list(range(len(self.ds)))  # fixed order

    def __len__(self) -> int:
        # Not strictly needed by PyTorch when using batch_sampler, but provide an estimate
        # by simulating packs (cheap upper bound: number of ROIs)
        return max(1, math.ceil(len(self.ds) / 4))

    def __iter__(self):
        i = 0
        N = len(self.order)
        while i < N:
            patch_budget = self.cfg.batch_total_patches
            this_batch: List[int] = []
            while i < N:
                roi_idx = self.order[i]
                # compute capped length
                L = min(len(self.ds.roi_items[roi_idx][2]), self.cfg.max_patches_per_roi)
                if L <= 0:
                    i += 1
                    continue
                if L > patch_budget and len(this_batch) > 0:
                    break
                this_batch.append(roi_idx)
                patch_budget -= L
                i += 1
                if patch_budget <= 0:
                    break
            yield this_batch


# -------------------------------
# Metrics helpers
# -------------------------------
@torch.no_grad()
def build_confusion(num_classes: int = 3) -> torch.Tensor:
    return torch.zeros((num_classes, num_classes), dtype=torch.long)

@torch.no_grad()
def update_confusion(conf: torch.Tensor, true_labels: torch.Tensor, pred_labels: torch.Tensor) -> None:
    for t, p in zip(true_labels.view(-1), pred_labels.view(-1)):
        conf[int(t), int(p)] += 1





In [4]:
# Training Loop
# -------------------------------
class Trainer:
    def __init__(self, cfg_train: Config, cfg_val: Config, steps_per_epoch: int = 500, out_dir: Path = Path("/content/drive/MyDrive/BRACS/checkpoints/PatchRel_noaug2")):
        self.cfg_train = cfg_train
        self.cfg_val = cfg_val
        self.steps_per_epoch = steps_per_epoch
        self.out_dir = out_dir
        self.out_dir.mkdir(parents=True, exist_ok=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Datasets
        self.ds_train = ROIPatchDataset(cfg_train)
        self.ds_val = ROIPatchDataset(cfg_val)

        # Samplers
        self.sampler_train = FixedPatchBatchSampler(self.ds_train, cfg_train, steps_per_epoch=self.steps_per_epoch)
        self.sampler_val = ValPatchBatchSampler(self.ds_val, cfg_val)

        # Loaders
        self.loader_train = DataLoader(
            self.ds_train,
            batch_sampler=self.sampler_train,
            collate_fn=lambda s: collate_roi_batch(s, cfg_train),
            num_workers=cfg_train.num_workers,
            pin_memory=cfg_train.pin_memory,
            persistent_workers=cfg_train.persistent_workers,
            prefetch_factor=2 if cfg_train.num_workers > 0 else None,
        )
        self.loader_val = DataLoader(
            self.ds_val,
            batch_sampler=self.sampler_val,
            collate_fn=lambda s: collate_roi_batch(s, cfg_val),
            num_workers=cfg_val.num_workers,
            pin_memory=cfg_val.pin_memory,
            persistent_workers=cfg_val.persistent_workers,
            prefetch_factor=2 if cfg_val.num_workers > 0 else None,
        )

        # Model & optim
        self.model = PatchRel(num_classes=3).to(self.device)
        self.crit = ROIBCELoss(num_classes=3)
        self.opt = torch.optim.SGD(self.model.parameters(), lr=3e-2, momentum=0.75, weight_decay=1e-4)
        # use new AMP API
        self.scaler = torch.amp.GradScaler('cuda', enabled=(self.device.type == 'cuda'))

        self.best = {"gmean": -1.0, "epoch": -1}

    def train_one_epoch(self, epoch: int) -> Dict[str, float]:
        self.model.train()
        running_loss = 0.0
        epoch_class_counts = torch.zeros(3, dtype=torch.long)

        pbar = tqdm(total=self.steps_per_epoch, desc=f"Epoch {epoch:02d} [train]", leave=False)
        for step, batch in enumerate(self.loader_train, start=1):
            t0 = time.time()

            imgs = batch["images"].to(self.device, non_blocking=True)
            roi_ids = batch["roi_ids"].to(self.device)
            roi_targets = batch["roi_targets"].to(self.device)

            with torch.no_grad():
                epoch_class_counts.index_add_(0, roi_targets.cpu(), torch.ones_like(roi_targets.cpu()))

            self.opt.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', enabled=(self.device.type == 'cuda')):
                roi_logits, _ = self.model(imgs, roi_ids)
                loss = self.crit(roi_logits, roi_targets)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.opt)
            self.scaler.update()

            running_loss += loss.item()

            # progress bar update
            patches = int(imgs.size(0))
            dt = max(time.time() - t0, 1e-6)
            pbar.set_postfix({"loss": f"{loss.item():.4f}", "patches": patches, "pps": f"{patches/dt:.1f}"})
            pbar.update(1)

            if step >= self.steps_per_epoch:
                break
        pbar.close()

        avg_loss = running_loss / max(1, self.steps_per_epoch)
        counts = {SUPERCLASSES[i]: int(epoch_class_counts[i].item()) for i in range(3)}
        return {"loss": avg_loss, "counts_B": counts["B"], "counts_A": counts["A"], "counts_M": counts["M"]}

    @torch.no_grad()
    def validate(self) -> Tuple[float, torch.Tensor]:
        self.model.eval()
        conf = build_confusion(num_classes=3).to(self.device)
        for batch in self.loader_val:
            imgs = batch["images"].to(self.device, non_blocking=True)
            roi_ids = batch["roi_ids"].to(self.device)
            roi_targets = batch["roi_targets"].to(self.device)
            with torch.amp.autocast('cuda', enabled=(self.device.type == 'cuda')):
                roi_logits, _ = self.model(imgs, roi_ids)
                # For 3-way decision, use argmax over sigmoid probabilities
                probs = roi_logits.sigmoid()
                preds = probs.argmax(dim=1)
            update_confusion(conf, roi_targets, preds)
        g = gmean_from_confusion(conf.cpu())
        return g, conf.cpu()

    def save_ckpt(self, epoch: int, is_best: bool = False) -> None:
        path = self.out_dir / ("epoch_%03d.pt" % epoch)
        torch.save({
            "epoch": epoch,
            "model": self.model.state_dict(),
            "optimizer": self.opt.state_dict(),
            "scaler": self.scaler.state_dict(),
            "meta": {
                "normalize_mean": self.cfg_train.normalize_mean,
                "normalize_std": self.cfg_train.normalize_std,
                "image_size": self.cfg_train.image_size,
                "superclasses": SUPERCLASSES,
            },
        }, path)
        if is_best:
            best_path = self.out_dir / "best_by_gmean.pt"
            torch.save({
                "epoch": epoch,
                "model": self.model.state_dict(),
                "meta": {
                    "normalize_mean": self.cfg_train.normalize_mean,
                    "normalize_std": self.cfg_train.normalize_std,
                    "image_size": self.cfg_train.image_size,
                    "superclasses": SUPERCLASSES,
                    "best_gmean": self.best["gmean"],
                },
            }, best_path)

    def fit(self, num_epochs: int = 40, save_every: int = 2):
        for epoch in range(1, num_epochs + 1):
            print(f"Epoch {epoch:02d} — training...")
            tr = self.train_one_epoch(epoch)

            print("  running validation...")
            val_g, conf = self.validate()

            print(
                f"Epoch {epoch:02d} | train_loss={tr['loss']:.4f} | val_gmean={val_g:.4f} | "
                f"batched_ROIs_per_class B/A/M=({tr['counts_B']},{tr['counts_A']},{tr['counts_M']})"
            )

            if (epoch % save_every) == 0:
                self.save_ckpt(epoch, is_best=False)

            if val_g > self.best["gmean"]:
                self.best.update({"gmean": val_g, "epoch": epoch})
                self.save_ckpt(epoch, is_best=True)
                print(f"  ↳ saved BEST checkpoint @ epoch {epoch} (g-mean={val_g:.4f})")


if __name__ == "__main__":
    # Speed hint for first run
    torch.backends.cudnn.benchmark = True

    train_cfg = Config(
        drive_root=Path("/content/drive/MyDrive/BRACS/ROIPatches"),
        split="train",
        batch_total_patches=512,
        max_patches_per_roi=20,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        fast_io_cap=20,   # align dataset reads to cap to reduce Drive I/O
    )
    val_cfg = Config(
        drive_root=Path("/content/drive/MyDrive/BRACS/ROIPatches"),
        split="val",
        batch_total_patches=512,
        max_patches_per_roi=20,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        fast_io_cap=20,
    )

    # Optionally sanity check
    # dry_run_sanity(train_cfg)

    trainer = Trainer(train_cfg, val_cfg, steps_per_epoch=180)
    trainer.fit(num_epochs=40, save_every=2)

[Index] Split=train ROIs=2084 per-class={'B': 911, 'A': 250, 'M': 923} (patches~93938)
[Index] Split=val ROIs=178 per-class={'B': 64, 'A': 31, 'M': 83} (patches~9197)


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Epoch 01 — training...


  max_per_roi_class.index_reduce_(0, roi_ids, logits, reduce="amax")  # [R, C]
                                                                                                           

  running validation...




Epoch 01 | train_loss=0.2194 | val_gmean=0.8747 | batched_ROIs_per_class B/A/M=(720,2160,720)
  ↳ saved BEST checkpoint @ epoch 1 (g-mean=0.8747)
Epoch 02 — training...




  running validation...
Epoch 02 | train_loss=0.0853 | val_gmean=0.8754 | batched_ROIs_per_class B/A/M=(720,2160,720)
  ↳ saved BEST checkpoint @ epoch 2 (g-mean=0.8754)
Epoch 03 — training...




  running validation...
Epoch 03 | train_loss=0.0559 | val_gmean=0.8375 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 04 — training...




  running validation...
Epoch 04 | train_loss=0.0392 | val_gmean=0.8066 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 05 — training...




  running validation...
Epoch 05 | train_loss=0.0315 | val_gmean=0.7889 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 06 — training...




  running validation...
Epoch 06 | train_loss=0.0333 | val_gmean=0.8144 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 07 — training...




  running validation...
Epoch 07 | train_loss=0.0189 | val_gmean=0.8054 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 08 — training...




  running validation...
Epoch 08 | train_loss=0.0236 | val_gmean=0.6125 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 09 — training...




  running validation...
Epoch 09 | train_loss=0.0281 | val_gmean=0.7882 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 10 — training...




  running validation...
Epoch 10 | train_loss=0.0219 | val_gmean=0.6603 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 11 — training...




  running validation...
Epoch 11 | train_loss=0.0093 | val_gmean=0.6486 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 12 — training...




  running validation...
Epoch 12 | train_loss=0.0074 | val_gmean=0.7161 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 13 — training...




  running validation...
Epoch 13 | train_loss=0.0133 | val_gmean=0.7610 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 14 — training...




  running validation...
Epoch 14 | train_loss=0.0065 | val_gmean=0.6534 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 15 — training...




  running validation...
Epoch 15 | train_loss=0.0046 | val_gmean=0.7199 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 16 — training...




  running validation...
Epoch 16 | train_loss=0.0028 | val_gmean=0.7169 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 17 — training...




  running validation...
Epoch 17 | train_loss=0.0013 | val_gmean=0.7022 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 18 — training...




  running validation...
Epoch 18 | train_loss=0.0010 | val_gmean=0.6964 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 19 — training...




  running validation...
Epoch 19 | train_loss=0.0013 | val_gmean=0.7071 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 20 — training...




  running validation...
Epoch 20 | train_loss=0.0006 | val_gmean=0.7221 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 21 — training...




  running validation...
Epoch 21 | train_loss=0.0006 | val_gmean=0.6964 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 22 — training...




  running validation...
Epoch 22 | train_loss=0.0009 | val_gmean=0.6711 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 23 — training...




  running validation...
Epoch 23 | train_loss=0.0004 | val_gmean=0.6803 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 24 — training...




  running validation...
Epoch 24 | train_loss=0.0006 | val_gmean=0.6595 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 25 — training...




  running validation...
Epoch 25 | train_loss=0.0007 | val_gmean=0.7386 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 26 — training...




  running validation...
Epoch 26 | train_loss=0.0031 | val_gmean=0.7561 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 27 — training...




  running validation...
Epoch 27 | train_loss=0.0040 | val_gmean=0.7981 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 28 — training...




  running validation...
Epoch 28 | train_loss=0.0021 | val_gmean=0.7849 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 29 — training...




  running validation...
Epoch 29 | train_loss=0.0030 | val_gmean=0.7153 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 30 — training...




  running validation...
Epoch 30 | train_loss=0.0080 | val_gmean=0.7254 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 31 — training...




  running validation...
Epoch 31 | train_loss=0.0156 | val_gmean=0.7285 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 32 — training...




  running validation...
Epoch 32 | train_loss=0.0048 | val_gmean=0.7424 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 33 — training...




  running validation...
Epoch 33 | train_loss=0.0030 | val_gmean=0.7714 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 34 — training...




  running validation...
Epoch 34 | train_loss=0.0085 | val_gmean=0.7515 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 35 — training...




  running validation...
Epoch 35 | train_loss=0.0041 | val_gmean=0.7682 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 36 — training...




  running validation...
Epoch 36 | train_loss=0.0017 | val_gmean=0.7014 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 37 — training...




  running validation...
Epoch 37 | train_loss=0.0019 | val_gmean=0.7816 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 38 — training...




  running validation...
Epoch 38 | train_loss=0.0020 | val_gmean=0.7579 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 39 — training...




  running validation...
Epoch 39 | train_loss=0.0015 | val_gmean=0.7199 | batched_ROIs_per_class B/A/M=(720,2160,720)
Epoch 40 — training...




  running validation...
Epoch 40 | train_loss=0.0016 | val_gmean=0.7436 | batched_ROIs_per_class B/A/M=(720,2160,720)
