In [None]:
# patch encoder training

In [18]:
import time
import torch.backends.cuda as cuda_backends

# GPU math speedups (safe on A100/Ampere+)
torch.set_float32_matmul_precision("high")  # enables TF32 matmuls under the hood
cuda_backends.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [14]:
import os, random
from pathlib import Path
from collections import defaultdict, Counter

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from torchvision import transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [32]:
# -----------------------------
# Config
# -----------------------------
DRIVE_ROOT = "/content/drive/MyDrive/BRACS/ROIPatches"

IMAGE_SIZE = 224
BATCH_SIZE_PATCHES = 512
MAX_PATCHES_PER_ROI = 30
NUM_WORKERS = 8

# training schedule
EPOCHS = 40
STEPS_PER_EPOCH = 180
AMP = True
GRAD_CLIP = 5.0


# optimizer and regularization
BASE_LR = 3e-2
MOMENTUM = 0.75
WEIGHT_DECAY = 1e-4
DROPOUT_P = 0.1

# !!! constant learning rate !!!
# discriminative LRs & warm-up
LR_DIV_FACTOR = 2.0          # deeper group -> larger LR by factor^k (we invert by distance to end)
WARMUP_EPOCHS = 2
FINETUNE_LR_DIVISOR = 10.0    # post-warmup: LR := LR / FINETUNE_LR_DIVISOR

# subtype mapping
SEVEN_TO_THREE = {
    "0_N":   0, "1_PB": 0, "2_UDH": 0,  # Benign
    "3_FEA": 1, "4_ADH": 1,            # Atypical
    "5_DCIS":2, "6_IC":  2,            # Malignant
}
THREE_CLASS_NAMES = {0: "Benign", 1: "Atypical", 2: "Malignant"}

# oversampling: boost atypical
CLASS_BOOST = {0: 1.0, 1: 1.0, 2: 1.0}


In [33]:

# -----------------------------
# Data indexing
# -----------------------------
def _is_img(p: Path):
    return p.suffix.lower() in {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}

def scan_split_3way(split_dir: Path):
    recs = []
    if not split_dir.exists():
        return recs
    for seven_cls in sorted([d.name for d in split_dir.iterdir() if d.is_dir()]):
        if seven_cls not in SEVEN_TO_THREE:
            continue
        y3 = SEVEN_TO_THREE[seven_cls]
        cls_dir = split_dir / seven_cls
        for roi_dir in cls_dir.iterdir():
            if not roi_dir.is_dir():
                continue
            roi_id = f"{seven_cls}/{roi_dir.name}"
            for img_path in roi_dir.iterdir():
                if _is_img(img_path):
                    recs.append({"path": str(img_path), "y3": y3, "roi_id": roi_id})
    return recs


# -----------------------------
# Dataset & transforms
# -----------------------------
class PatchDataset(Dataset):
    def __init__(self, records, transform):
        self.records = records
        self.transform = transform

    def __len__(self): return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        img = Image.open(r["path"]).convert("RGB")
        img = self.transform(img)
        return img, r["y3"], r["roi_id"]

weights = EfficientNet_B0_Weights.IMAGENET1K_V1
IMAGENET_MEAN = weights.transforms().mean
IMAGENET_STD  = weights.transforms().std


train_tfms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomApply([transforms.RandomRotation(degrees=(90, 270))], p=0.5),
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

eval_tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [34]:
# -----------------------------
# Sampler (class-balanced + atypical boost)
# -----------------------------
def make_weighted_sampler_3way(records):
    class_counts = Counter([r["y3"] for r in records])
    weights = []
    for r in records:
        y = r["y3"]
        w = CLASS_BOOST[y] * (1.0 / max(1, class_counts[y]))
        weights.append(w)
    return WeightedRandomSampler(weights, num_samples=len(records), replacement=True)


# -----------------------------
# Collate: cap ≤30 patches/ROI
# -----------------------------
def _cap_per_roi(images, targets, rois, max_per_roi=30):
    idxs_by_roi = defaultdict(list)
    for i, roi in enumerate(rois):
        idxs_by_roi[roi].append(i)
    keep = []
    for _, idxs in idxs_by_roi.items():
        if len(idxs) <= max_per_roi:
            keep.extend(idxs)
        else:
            keep.extend(random.sample(idxs, max_per_roi))
    keep.sort()
    images = images[keep]
    targets = targets[keep]
    rois = [rois[i] for i in keep]
    return images, targets, rois

def collate_fn_max30(batch):
    imgs, ys, rois = [], [], []
    for img, y, roi in batch:
        imgs.append(img.unsqueeze(0))
        ys.append(y)
        rois.append(roi)
    images = torch.cat(imgs, dim=0)
    targets = torch.tensor(ys, dtype=torch.long)
    images, targets, rois = _cap_per_roi(images, targets, rois, MAX_PATCHES_PER_ROI)
    return images, targets, rois

In [35]:
# -----------------------------
# Model (EfficientNet-B0)
# -----------------------------
class EfficientNetB0_3Way(nn.Module):
    def __init__(self, num_classes=3, dropout_p=0.1):
        super().__init__()
        self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        in_feat = self.backbone.classifier[1].in_features  # 1280
        self.backbone.classifier = nn.Identity()
        self.pen_dropout = nn.Dropout(p=dropout_p)
        self.classifier = nn.Linear(in_feat, num_classes)

    @torch.no_grad()
    def extract_penultimate(self, x):
        x = self.backbone.features(x)
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        pen = self.extract_penultimate(x)
        pen = self.pen_dropout(pen)
        logits = self.classifier(pen)
        return logits, pen


# -----------------------------
# Optimizer param groups:
#   - split EfficientNet features into depth-ordered groups
#   - discriminative LRs
#   - weight decay only on weights (no BN/bias)
# -----------------------------
def _params_weight_bias(mod):
    weights, biases = [], []
    for name, p in mod.named_parameters(recurse=False):
        if not p.requires_grad:
            continue
        if p.ndimension() == 1 or name.endswith(".bias"):
            biases.append(p)
        else:
            weights.append(p)
    # include children recursively
    for child in mod.children():
        w, b = _params_weight_bias(child)
        weights.extend(w); biases.extend(b)
    return weights, biases

def build_discriminative_param_groups(model, base_lr, wd, mom, lr_div_factor):
    """
    Split torchvision EfficientNet:
      groups = [features[0], features[1], ..., features[N-1], classifier_head]
    Earlier idx -> smaller LR; later idx -> larger; classifier largest.
    """
    groups = []

    # break features sequential into groups
    feats = model.backbone.features  # nn.Sequential([...])
    for i in range(len(feats)):
        weights, biases = _params_weight_bias(feats[i])
        if weights or biases:
            groups.append({"weights": weights, "biases": biases})

    # add the classification head
    head_w, head_b = _params_weight_bias(model.classifier)
    groups.append({"weights": head_w, "biases": head_b})

    # assign LRs (discriminative): earlier groups get lower LR
    num_groups = len(groups)
    optim_groups = []
    for g_idx, g in enumerate(groups):
        # distance from end: deeper groups -> larger LR
        depth = num_groups - (g_idx + 1)  # 0 for last, increases as we go earlier
        # divide by (lr_div_factor ** depth) to shrink earlier layers
        cur_lr = base_lr / (lr_div_factor ** depth)
        if g["weights"]:
            optim_groups.append({"params": g["weights"], "lr": cur_lr, "momentum": mom, "weight_decay": wd})
        if g["biases"]:
            optim_groups.append({"params": g["biases"], "lr": cur_lr, "momentum": mom, "weight_decay": 0.0})
    init_lrs = [pg["lr"] for pg in optim_groups]
    return optim_groups, init_lrs


# -----------------------------
# Metrics (per-class recall & g-mean)
# -----------------------------
def per_class_recall(preds, targets, num_classes=3):
    preds = preds.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    recalls = []
    for c in range(num_classes):
        idx = (targets == c)
        if idx.sum() == 0:
            recalls.append(1.0)  # neutral if a class is absent in this val fold/batch
        else:
            correct = (preds[idx] == targets[idx]).sum()
            recalls.append(correct / idx.sum())
    return np.array(recalls, dtype=np.float64)

def gmean(recalls, eps=1e-8):
    return float(np.exp(np.mean(np.log(recalls + eps))))


In [36]:
# -----------------------------
# Data loaders
# -----------------------------
def build_loaders():
    train_dir = Path(DRIVE_ROOT) / "train"
    val_dir   = Path(DRIVE_ROOT) / "val"

    train_recs = scan_split_3way(train_dir)
    val_recs   = scan_split_3way(val_dir)

    train_ds = PatchDataset(train_recs, transform=train_tfms)
    val_ds   = PatchDataset(val_recs,   transform=eval_tfms)

    train_sampler = make_weighted_sampler_3way(train_recs)

    train_loader = DataLoader(
      train_ds,
      batch_size=BATCH_SIZE_PATCHES,
      sampler=train_sampler,
      num_workers=NUM_WORKERS,
      pin_memory=True,
      collate_fn=collate_fn_max30,
      drop_last=True,
      persistent_workers=True,   # <— add
      prefetch_factor=4          # <— add
  )
    val_loader = DataLoader(
      val_ds,
      batch_size=BATCH_SIZE_PATCHES,
      shuffle=False,
      num_workers=NUM_WORKERS,
      pin_memory=True,
      collate_fn=collate_fn_max30,
      drop_last=False,
      persistent_workers=True,   # <— add
      prefetch_factor=4          # <— add
  )

    return train_loader, val_loader

In [37]:
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    print("--DATA LOADING--")

    train_loader, val_loader = build_loaders()

    print("--DATA LOADING DONE--")

    model = EfficientNetB0_3Way(num_classes=3, dropout_p=DROPOUT_P).to(device)
    # Channels-last speeds up memory access on NVIDIA GPUs
    model = model.to(memory_format=torch.channels_last)

    # Build discriminative optimizer groups (layer-wise LR)
    optim_groups, init_lrs = build_discriminative_param_groups(
        model, base_lr=BASE_LR, wd=WEIGHT_DECAY, mom=MOMENTUM, lr_div_factor=LR_DIV_FACTOR
    )
    optimizer = torch.optim.SGD(optim_groups)  # momentum/weight_decay are in groups

    scaler = torch.cuda.amp.GradScaler(enabled=AMP)
    ce = nn.CrossEntropyLoss()

    save_dir = "./checkpoints"
    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, "efficientnet_b0_3way_best.pt")

    best_val_g = -1.0
    warmup_steps = WARMUP_EPOCHS * STEPS_PER_EPOCH
    global_step = 0

    # Will hold counts from the *previous* epoch (printed at the start of the next)
    prev_epoch_class_counts = None

    for epoch in range(1, EPOCHS + 1):
        # ---------- Epoch header ----------
        if prev_epoch_class_counts is None:
            print(f"\nEpoch {epoch:02d} — first epoch: class counts will be shown from next epoch.")
        else:
            b, a, m = prev_epoch_class_counts.tolist()
            print(f"\nEpoch {epoch:02d} — last epoch class counts (patches seen): "
                  f"Benign={b}, Atypical={a}, Malignant={m}")

        # ---------- Train ----------
        model.train()
        train_iter = iter(train_loader)
        running_loss = 0.0
        epoch_class_counts = torch.zeros(3, dtype=torch.long)  # counts of targets actually processed
        epoch_patches = 0
        t0 = time.time()

        for step in range(STEPS_PER_EPOCH):
            try:
                images, targets, rois = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                images, targets, rois = next(train_iter)

            # Send to GPU using channels_last
            images  = images.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            targets = targets.to(device, non_blocking=True)

            # Track class counts & throughput
            # (targets can be smaller than batch size due to per-ROI cap)
            batch_counts = torch.bincount(targets, minlength=3).cpu()
            epoch_class_counts += batch_counts
            epoch_patches += int(targets.numel())

            # Warm-up & post-warmup LR handling (per-group)
            if global_step <= warmup_steps:
                for pg, lr in zip(optimizer.param_groups, init_lrs):
                    pg["lr"] = lr
            else:
                for pg, lr in zip(optimizer.param_groups, init_lrs):
                    pg["lr"] = lr / FINETUNE_LR_DIVISOR

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=AMP):
                logits, _ = model(images)
                loss = ce(logits, targets)

            scaler.scale(loss).backward()
            nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            global_step += 1

        # End-of-epoch timing
        epoch_time = time.time() - t0
        patches_per_sec = epoch_patches / max(1e-6, epoch_time)

        # ---------- Validate ----------
        model.eval()
        all_recalls = []
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=AMP):
            for images, targets, rois in val_loader:
                images  = images.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                targets = targets.to(device, non_blocking=True)
                logits, _ = model(images)
                preds = logits.argmax(1)
                rec = per_class_recall(preds, targets, num_classes=3)
                all_recalls.append(rec)

        if all_recalls:
            recalls = np.stack(all_recalls, axis=0).mean(axis=0)
            val_g = gmean(recalls)
        else:
            recalls = np.zeros(3, dtype=np.float64)
            val_g = 0.0

        # Min/max LR snapshot across groups for visibility
        cur_lrs = [pg["lr"] for pg in optimizer.param_groups]
        lr_min, lr_max = (min(cur_lrs), max(cur_lrs)) if len(cur_lrs) else (None, None)

        # ---------- Log line (richer) ----------
        print(f"Epoch {epoch:02d} | train_loss={running_loss/STEPS_PER_EPOCH:.4f} "
              f"| val_gmean={val_g:.4f} | recall(B,A,M)={np.round(recalls,3)}")
        print(f"           | epoch_time={epoch_time:.1f}s | patches={epoch_patches} | throughput={patches_per_sec:.1f} patches/s "
              f"| lr[min,max]=({lr_min:.5g}, {lr_max:.5g})")
        print(f"           | train_class_counts (B,A,M) = {epoch_class_counts.tolist()}")

        # Save counts to show at *start* of next epoch
        prev_epoch_class_counts = epoch_class_counts.clone()

        # ---------- Save best ----------
        if val_g > best_val_g:
            best_val_g = val_g
            torch.save({
                "model": model.state_dict(),
                "class_names": THREE_CLASS_NAMES,
                "val_gmean": best_val_g,
                "image_size": IMAGE_SIZE
            }, ckpt_path)
            print(f"  ↳ saved best: {ckpt_path} (g-mean={best_val_g:.4f})")

    print("Done.")
    print(f"Best val g-mean: {best_val_g:.4f}")
    print(f"Checkpoint: {ckpt_path}")


if __name__ == "__main__":
    train()

--DATA LOADING--
--DATA LOADING DONE--

Epoch 01 — first epoch: class counts will be shown from next epoch.


  scaler = torch.cuda.amp.GradScaler(enabled=AMP)
  with torch.cuda.amp.autocast(enabled=AMP):
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=AMP):


Epoch 01 | train_loss=1.1046 | val_gmean=0.7682 | recall(B,A,M)=[0.831 0.895 0.609]
           | epoch_time=209.0s | patches=92160 | throughput=440.9 patches/s | lr[min,max]=(5.8594e-05, 0.03)
           | train_class_counts (B,A,M) = [30938, 30453, 30769]
  ↳ saved best: ./checkpoints/efficientnet_b0_3way_best.pt (g-mean=0.7682)

Epoch 02 — last epoch class counts (patches seen): Benign=30938, Atypical=30453, Malignant=30769
Epoch 02 | train_loss=1.1041 | val_gmean=0.7729 | recall(B,A,M)=[0.835 0.894 0.619]
           | epoch_time=207.4s | patches=92160 | throughput=444.4 patches/s | lr[min,max]=(5.8594e-05, 0.03)
           | train_class_counts (B,A,M) = [30979, 30451, 30730]
  ↳ saved best: ./checkpoints/efficientnet_b0_3way_best.pt (g-mean=0.7729)

Epoch 03 — last epoch class counts (patches seen): Benign=30979, Atypical=30451, Malignant=30730
Epoch 03 | train_loss=1.1029 | val_gmean=0.7736 | recall(B,A,M)=[0.832 0.893 0.623]
           | epoch_time=212.0s | patches=92160 | through

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9b97fe5300>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9b97fe5300>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 09 | train_loss=1.1015 | val_gmean=0.7743 | recall(B,A,M)=[0.836 0.896 0.62 ]
           | epoch_time=211.0s | patches=92160 | throughput=436.7 patches/s | lr[min,max]=(5.8594e-06, 0.003)
           | train_class_counts (B,A,M) = [30774, 30521, 30865]

Epoch 10 — last epoch class counts (patches seen): Benign=30774, Atypical=30521, Malignant=30865
Epoch 10 | train_loss=1.1027 | val_gmean=0.7715 | recall(B,A,M)=[0.834 0.897 0.614]
           | epoch_time=208.4s | patches=92160 | throughput=442.2 patches/s | lr[min,max]=(5.8594e-06, 0.003)
           | train_class_counts (B,A,M) = [30879, 30521, 30760]

Epoch 11 — last epoch class counts (patches seen): Benign=30879, Atypical=30521, Malignant=30760
Epoch 11 | train_loss=1.1014 | val_gmean=0.7682 | recall(B,A,M)=[0.834 0.895 0.607]
           | epoch_time=202.4s | patches=92160 | throughput=455.2 patches/s | lr[min,max]=(5.8594e-06, 0.003)
           | train_class_counts (B,A,M) = [30748, 30609, 30803]

Epoch 12 — last epoch class c

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9b97fe5300>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9b97fe5300>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 17 | train_loss=1.1022 | val_gmean=0.7750 | recall(B,A,M)=[0.83  0.896 0.626]
           | epoch_time=201.4s | patches=92160 | throughput=457.6 patches/s | lr[min,max]=(5.8594e-06, 0.003)
           | train_class_counts (B,A,M) = [30754, 30917, 30489]
  ↳ saved best: ./checkpoints/efficientnet_b0_3way_best.pt (g-mean=0.7750)

Epoch 18 — last epoch class counts (patches seen): Benign=30754, Atypical=30917, Malignant=30489
Epoch 18 | train_loss=1.1015 | val_gmean=0.7722 | recall(B,A,M)=[0.834 0.896 0.616]
           | epoch_time=203.5s | patches=92160 | throughput=452.8 patches/s | lr[min,max]=(5.8594e-06, 0.003)
           | train_class_counts (B,A,M) = [30799, 30583, 30778]

Epoch 19 — last epoch class counts (patches seen): Benign=30799, Atypical=30583, Malignant=30778
Epoch 19 | train_loss=1.1023 | val_gmean=0.7716 | recall(B,A,M)=[0.833 0.894 0.617]
           | epoch_time=204.2s | patches=92160 | throughput=451.3 patches/s | lr[min,max]=(5.8594e-06, 0.003)
           | train_

In [31]:
SEVEN_TO_THREE = {
    "0_N":0, "1_PB":0, "2_UDH":0, "3_FEA":1, "4_ADH":1, "5_DCIS":2, "6_IC":2,
}
# Sanity:
root = Path(DRIVE_ROOT) / "train"
actual = sorted([d.name for d in root.iterdir() if d.is_dir()])
print("[FOLDERS]:", actual)
missing = [d for d in actual if d not in SEVEN_TO_THREE]
print("[UNMAPPED]:", missing)   # must be []

[FOLDERS]: ['0_N', '1_PB', '2_UDH', '3_FEA', '4_ADH', '5_DCIS', '6_IC']
[UNMAPPED]: []
