In [5]:
try:
    from google.colab import drive  # works in Colab
    drive.mount('/content/drive', force_remount=False)
    DRIVE_ROOT = Path('/content/drive/MyDrive')
    assert DRIVE_ROOT.exists(), "Drive mount failed (MyDrive not found)."
except Exception as e:
    raise RuntimeError(
        "Google Drive is not mounted. Run this in Google Colab or mount Drive first."
    ) from e

Mounted at /content/drive


In [10]:
# late interaction model
# ============================================================
# Late Interaction (LITE-style) scorer for WSI retrieval
# Binary multi-label supervision (any-overlap positives)
# ============================================================

import os, json, random, math, gc
from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# -------------------------
# Config (edit as needed)
# -------------------------
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Drive roots (ensure Drive is mounted!)
DRIVE_ROOT = Path("/content/drive/MyDrive")

SPLIT_TRAIN = "train"     # change if needed
SPLIT_VAL   = "test"      # quick sanity evaluation after training

EMB_ROOT_REL = "BRACS/Embeddings"     # where *.pt embeddings live
MODEL_ROOT_REL = "BRACS/Models"       # where to save model checkpoints

# Excel with labels
XLSX_PATH = "/content/BRACS_BRACS.xlsx"
SHEET = "WSI_Information"
ID_COL = "WSI Filename"   # will be normalized to stem
LABEL_COL = "WSI label"   # multi-label -> separate with delimiter if needed

# Late interaction sizes
L1_MAX = 64     # max #tiles for query slide
L2_MAX = 64     # max #tiles for document slide
ROW_HIDDEN = 128
COL_HIDDEN = 512

# Training (updated)
BATCH_SIZE = 10            # #queries per step
N_POS_PER_Q = 1            # positives sampled per query
N_NEG_PER_Q = 6            # negatives sampled per query
EPOCHS = 40
LR = 1e-3
LR_MIN = 1e-5
WARMUP_EPOCHS = 1          # linear warmup over ~1 epoch
WEIGHT_DECAY = 1e-4
GRAD_CLIP_NORM = 1.0
LOG_EVERY = 20
EARLY_STOP_PATIENCE = 5

# ------------------------------------------
# Utils
# ------------------------------------------
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def normalize_id(s: str) -> str:
    return Path(str(s)).stem

def load_label_map_from_excel(
    xlsx_path: str, sheet=SHEET, id_col=ID_COL, label_col=LABEL_COL, delimiter=None
) -> Dict[str, List[str]]:
    """
    Returns: dict slide_id -> list of labels
    If label_col contains multi-label strings, provide delimiter (e.g., ';')
    """
    df = pd.read_excel(xlsx_path, sheet_name=sheet, engine="openpyxl")
    ids = df[id_col].astype(str).apply(normalize_id).tolist()
    raw = df[label_col].astype(str).tolist()
    if delimiter is None:
        labels = [[x.strip()] for x in raw]           # single label per slide
    else:
        labels = [[t.strip() for t in x.split(delimiter) if t.strip()] for x in raw]
    return {i: l for i, l in zip(ids, labels)}

def any_overlap(Lq: List[str], Ld: List[str]) -> bool:
    return len(set(Lq) & set(Ld)) > 0

# ------------------------------------------
# Embedding index
# ------------------------------------------
class EmbIndex:
    """Scans Drive for embeddings and pairs them with labels."""
    def __init__(self, drive_root: Path, emb_rel: str, split: str, label_map: Dict[str, List[str]]):
        self.dir = drive_root / emb_rel / split
        if not self.dir.exists():
            raise FileNotFoundError(f"Embeddings dir not found: {self.dir}")
        # gather files and keep only those with labels
        files = sorted(self.dir.glob("*.pt"))
        self.items = []
        for f in files:
            sid = f.stem
            if sid in label_map:
                self.items.append((sid, f))
        if not self.items:
            raise RuntimeError(f"No embeddings with labels in {self.dir}")
        self.label_map = label_map

    def __len__(self):
        return len(self.items)

    def path_of(self, slide_id: str) -> Path:
        return self.dir / f"{slide_id}.pt"

    def labels_of(self, slide_id: str) -> List[str]:
        return self.label_map[slide_id]

    def all_ids(self) -> List[str]:
        return [sid for sid,_ in self.items]

# ------------------------------------------
# Dataset that yields (query, positives, negatives)
# ------------------------------------------
def _sample_tiles(E: torch.Tensor, K: int) -> torch.Tensor:
    """
    E: [N_tiles, D]
    Returns: [K, D] (random sample with replacement if N<K; else without replacement)
    """
    N = E.size(0)
    if N <= 0:
        raise ValueError("Empty embedding set")
    if N >= K:
        idx = torch.randperm(N)[:K]
        return E[idx]
    # pad by sampling with replacement
    add = K - N
    dup_idx = torch.randint(0, N, (add,))
    return torch.cat([E, E[dup_idx]], dim=0)

def _pad_tiles(E: torch.Tensor, K: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    If E has >=K tiles: subsample K; if <K: sample with replacement to K.
    Return:
       E_out: [K,D]
       mask : [K] (True for real positions; here mask is all True—sampling already makes length K)
    """
    E_out = _sample_tiles(E, K)  # [K,D]
    mask = torch.ones(K, dtype=torch.bool)
    return E_out, mask

class PairwiseBatchDataset(Dataset):
    """
    Each item:
      - one query id q
      - 1..N_POS_PER_Q positives (any label overlap)
      - N_NEG_PER_Q negatives (no overlap)
      Returns embeddings tensors for dynamic batching.
    """
    def __init__(self, index: EmbIndex, L1_max=L1_MAX, L2_max=L2_MAX, n_pos=N_POS_PER_Q, n_neg=N_NEG_PER_Q):
        self.index = index
        self.ids = index.all_ids()
        self.L1 = L1_max; self.L2 = L2_max
        self.n_pos = n_pos; self.n_neg = n_neg

        # build positive/negative pools per id
        lbls = index.label_map
        self.pos_pool = {}
        self.neg_pool = {}
        for q in self.ids:
            Lq = lbls[q]
            pos = [d for d in self.ids if d != q and any_overlap(Lq, lbls[d])]
            neg = [d for d in self.ids if d != q and not any_overlap(Lq, lbls[d])]
            self.pos_pool[q] = pos
            self.neg_pool[q] = neg
        # keep only queries with at least one positive & one negative
        self.ids = [q for q in self.ids if len(self.pos_pool[q])>0 and len(self.neg_pool[q])>0]
        if not self.ids:
            raise RuntimeError("No queries with both positives and negatives. Check labels/split.")

    def __len__(self):
        return len(self.ids)

    def _load_emb(self, slide_id: str) -> torch.Tensor:
        pkg = torch.load(self.index.path_of(slide_id), map_location="cpu")
        E = pkg["tile_embeds"]
        if not torch.is_tensor(E): E = torch.tensor(E, dtype=torch.float32)
        return E.float()

    def __getitem__(self, idx):
        qid = self.ids[idx]
        qE = self._load_emb(qid)            # [Nq,D]

        # sample positives and negatives
        pos_ids = random.sample(self.pos_pool[qid], k=min(self.n_pos, len(self.pos_pool[qid])))
        if len(pos_ids) < self.n_pos:       # pad by reusing
            pos_ids += random.choices(self.pos_pool[qid], k=self.n_pos - len(pos_ids))

        neg_ids = random.sample(self.neg_pool[qid], k=min(self.n_neg, len(self.neg_pool[qid])))
        if len(neg_ids) < self.n_neg:
            neg_ids += random.choices(self.neg_pool[qid], k=self.n_neg - len(neg_ids))

        posE = [self._load_emb(pid) for pid in pos_ids]
        negE = [self._load_emb(nid) for nid in neg_ids]

        # sample/pad to fixed tile counts
        qE, qmask = _pad_tiles(qE, self.L1)
        posE = [ _pad_tiles(E, self.L2)[0] for E in posE ]
        negE = [ _pad_tiles(E, self.L2)[0] for E in negE ]

        return {
            "qid": qid,
            "qE": qE,                   # [L1,D]
            "posE": torch.stack(posE),  # [P,L2,D]
            "negE": torch.stack(negE),  # [N,L2,D]
        }

def collate_pairwise(batch):
    # Stack queries; positives/negatives stack along batch
    qE  = torch.stack([b["qE"] for b in batch], 0)                           # [B,L1,D]
    pos = torch.stack([b["posE"] for b in batch], 0)                          # [B,P,L2,D]
    neg = torch.stack([b["negE"] for b in batch], 0)                          # [B,N,L2,D]
    qids = [b["qid"] for b in batch]
    return {"qE": qE, "posE": pos, "negE": neg, "qids": qids}

# ------------------------------------------
# Separable LITE scorer (row-MLP then col-MLP)
# Works on fixed (L1_max, L2_max)
# ------------------------------------------
class SeparableLITE(nn.Module):
    def __init__(self, L1_max=L1_MAX, L2_max=L2_MAX, row_hidden=ROW_HIDDEN, col_hidden=COL_HIDDEN):
        super().__init__()
        self.L1 = L1_max; self.L2 = L2_max

        self.row_mlp = nn.Sequential(
            nn.LayerNorm(L2_max),
            nn.Linear(L2_max, row_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(row_hidden, L2_max)
        )
        self.col_mlp = nn.Sequential(
            nn.LayerNorm(L1_max),
            nn.Linear(L1_max, col_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(col_hidden, L1_max)
        )
        self.final = nn.Linear(L1_max * L2_max, 1)

    @staticmethod
    def _cosine_sim(Q, D):
        # Q: [B,L1,D], D: [B,L2,D] -> S: [B,L1,L2]
        Qn = F.normalize(Q, dim=-1)
        Dn = F.normalize(D, dim=-1)
        return torch.matmul(Qn, Dn.transpose(-1,-2))

    def forward(self, Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
        """
        Q: [B, L1, D], D: [B, L2, D]
        returns scores: [B]
        """
        S = self._cosine_sim(Q, D)                        # [B,L1,L2]
        S1 = self.row_mlp(S)                              # row-wise transform (along L2)
        S2 = self.col_mlp(S1.transpose(-2,-1)).transpose(-2,-1)  # column-wise (along L1)
        score = self.final(S2.flatten(1)).squeeze(-1)     # [B]
        return score

# ------------------------------------------
# Loss: Pairwise logistic ranking (multi-positive)
# ------------------------------------------
def pairwise_logistic_loss(scores_pos: torch.Tensor, scores_neg: torch.Tensor, margin: float = 0.0):
    """
    scores_pos: [B, P]  (one or more positives per query)
    scores_neg: [B, N]
    Returns mean over all B * P * N pairs
    """
    diffs = scores_pos.unsqueeze(-1) - scores_neg.unsqueeze(-2) - margin  # [B,P,N]
    return F.softplus(-diffs).mean()

# ------------------------------------------
# Scheduler: linear warmup -> cosine decay
# ------------------------------------------
def make_warmup_cosine_scheduler(optimizer, total_steps: int, warmup_steps: int, lr_start: float = 0.0, lr_max: float = LR, lr_min: float = LR_MIN):
    def lr_lambda(step):
        if step < warmup_steps:
            # linear warmup from lr_start to lr_max
            return (lr_start + (lr_max - lr_start) * (step / max(1, warmup_steps))) / lr_max
        # cosine from lr_max to lr_min
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        target = lr_min + (lr_max - lr_min) * cosine
        return target / lr_max
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ------------------------------------------
# Trainer
# ------------------------------------------
class LateInteractionTrainer:
    def __init__(self, model: nn.Module, lr=LR, wd=WEIGHT_DECAY):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    def _score_batch(self, qE, dE):
        """
        TRAIN-TIME scorer: keeps autograd graph.
        qE: [B,L1,D], dE: [B,K,L2,D]  ->  scores: [B,K]
        """
        B, K = dE.size(0), dE.size(1)
        q_rep = qE.unsqueeze(1).expand(-1, K, -1, -1).reshape(B*K, qE.size(1), qE.size(2))
        d_flat = dE.reshape(B*K, dE.size(2), dE.size(3))
        s_flat = self.model(q_rep, d_flat)          # [B*K], requires_grad=True
        return s_flat.view(B, K)

    def train_epoch(self, loader, epoch=0, log_every=LOG_EVERY, grad_clip=GRAD_CLIP_NORM, scheduler=None):
        self.model.train()
        running = 0.0
        for it, batch in enumerate(loader, 1):
            qE  = batch["qE"].to(DEVICE)                    # [B,L1,D]
            pos = batch["posE"].to(DEVICE)                  # [B,P,L2,D]
            neg = batch["negE"].to(DEVICE)                  # [B,N,L2,D]

            s_pos = self._score_batch(qE, pos)              # [B,P], requires_grad
            s_neg = self._score_batch(qE, neg)              # [B,N], requires_grad

            loss = pairwise_logistic_loss(s_pos, s_neg, margin=0.0)

            self.opt.zero_grad(set_to_none=True)
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
            self.opt.step()
            if scheduler is not None:
                scheduler.step()

            running += loss.item()
            if it % log_every == 0:
                lr_cur = self.opt.param_groups[0]["lr"]
                print(f"[epoch {epoch} iter {it}/{len(loader)}] loss={running/log_every:.4f}  lr={lr_cur:.2e}")
                running = 0.0

            del qE, pos, neg, s_pos, s_neg, loss
            if DEVICE.type == "cuda":
                torch.cuda.empty_cache()

    @torch.no_grad()
    def quick_eval(self, val_loader, topk=(1,3,5)):
        self.model.eval()
        hits = {k: 0 for k in topk}; total = 0
        for batch in val_loader:
            qE  = batch["qE"].to(DEVICE)
            pos = batch["posE"].to(DEVICE)
            neg = batch["negE"].to(DEVICE)
            s_pos = self._score_batch(qE, pos)
            s_neg = self._score_batch(qE, neg)
            for b in range(qE.size(0)):
                cand_scores = torch.cat([s_pos[b], s_neg[b]], dim=0)
                order = torch.argsort(cand_scores, descending=True)
                total += 1
                for k in topk:
                    hits[k] += (order[:k] < s_pos.size(1)).any().item()
        return {k: hits[k]/max(1,total) for k in topk}

# ------------------------------------------
# Build loaders and train
# ------------------------------------------
def make_loader(drive_root: Path, split: str, label_map: Dict[str, List[str]], batch_size=BATCH_SIZE, shuffle=True, drop_last=True):
    index = EmbIndex(drive_root, EMB_ROOT_REL, split, label_map)
    ds = PairwiseBatchDataset(index, L1_max=L1_MAX, L2_max=L2_MAX, n_pos=N_POS_PER_Q, n_neg=N_NEG_PER_Q)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=(DEVICE.type=="cuda"),
                    collate_fn=collate_pairwise, drop_last=drop_last)
    return dl

def train_lite():
    set_seed(SEED)

    # --- labels ---
    label_map = load_label_map_from_excel(XLSX_PATH, SHEET, ID_COL, LABEL_COL, delimiter=None)

    # --- data loaders ---
    train_loader = make_loader(DRIVE_ROOT, SPLIT_TRAIN, label_map, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader   = make_loader(DRIVE_ROOT, SPLIT_VAL,   label_map, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

    # --- model & trainer ---
    model = SeparableLITE(L1_max=L1_MAX, L2_max=L2_MAX, row_hidden=ROW_HIDDEN, col_hidden=COL_HIDDEN)
    trainer = LateInteractionTrainer(model, lr=LR, wd=WEIGHT_DECAY)

    # --- scheduler (warmup -> cosine) ---
    steps_per_epoch = len(train_loader)
    total_steps = EPOCHS * steps_per_epoch
    warmup_steps = max(1, int(WARMUP_EPOCHS * steps_per_epoch))
    scheduler = make_warmup_cosine_scheduler(trainer.opt, total_steps, warmup_steps, lr_start=0.0, lr_max=LR, lr_min=LR_MIN)
    print(f"Training: {EPOCHS} epochs, {steps_per_epoch} steps/epoch, total_steps={total_steps}, warmup_steps={warmup_steps}")

    # --- training loop with early stopping ---
    save_dir = DRIVE_ROOT / MODEL_ROOT_REL / SPLIT_TRAIN
    save_dir.mkdir(parents=True, exist_ok=True)
    best_top1 = -1.0
    bad = 0

    for ep in range(1, EPOCHS+1):
        trainer.train_epoch(train_loader, epoch=ep, scheduler=scheduler)

        metrics = trainer.quick_eval(val_loader, topk=(1,3,5))
        print(f"[val @ epoch {ep}] Top-1={metrics[1]*100:.2f}  Top-3={metrics[3]*100:.2f}  Top-5={metrics[5]*100:.2f}")

        # early stopping on Top-1
        if metrics[1] > best_top1 + 1e-6:
            best_top1 = metrics[1]
            bad = 0
            # save best
            best_path = save_dir / f"lite_separable_best.pt"
            torch.save({"model_state": model.state_dict(),
                        "config": {
                            "L1_MAX": L1_MAX, "L2_MAX": L2_MAX,
                            "ROW_HIDDEN": ROW_HIDDEN, "COL_HIDDEN": COL_HIDDEN,
                            "EMB_ROOT_REL": EMB_ROOT_REL, "SPLIT_TRAIN": SPLIT_TRAIN,
                            "SPLIT_VAL": SPLIT_VAL
                        }}, best_path)
            print(f"[best] Saved checkpoint: {best_path}")
        else:
            bad += 1
            if bad >= EARLY_STOP_PATIENCE:
                print("Early stop triggered.")
                break

    # --- save last model + config to Drive ---
    last_path = save_dir / f"lite_separable_last.pt"
    torch.save({"model_state": model.state_dict(),
                "config": {
                    "L1_MAX": L1_MAX, "L2_MAX": L2_MAX,
                    "ROW_HIDDEN": ROW_HIDDEN, "COL_HIDDEN": COL_HIDDEN,
                    "EMB_ROOT_REL": EMB_ROOT_REL, "SPLIT_TRAIN": SPLIT_TRAIN,
                    "SPLIT_VAL": SPLIT_VAL
                }}, last_path)
    print(f"Saved last model: {last_path}")

# -------------------------
# Run training
# -------------------------
# Make sure Drive is mounted and DRIVE_ROOT is correct before running:
# from google.colab import drive; drive.mount('/content/drive')



In [11]:
train_lite()

Training: 40 epochs, 37 steps/epoch, total_steps=1480, warmup_steps=37
[epoch 1 iter 20/37] loss=0.6997  lr=5.41e-04
[val @ epoch 1] Top-1=18.07  Top-3=43.37  Top-5=71.08
[best] Saved checkpoint: /content/drive/MyDrive/BRACS/Models/train/lite_separable_best.pt
[epoch 2 iter 20/37] loss=0.7019  lr=1.00e-03
[val @ epoch 2] Top-1=14.46  Top-3=34.94  Top-5=59.04
[epoch 3 iter 20/37] loss=0.7076  lr=9.96e-04
[val @ epoch 3] Top-1=9.64  Top-3=36.14  Top-5=75.90
[epoch 4 iter 20/37] loss=0.7347  lr=9.90e-04
[val @ epoch 4] Top-1=8.43  Top-3=32.53  Top-5=65.06
[epoch 5 iter 20/37] loss=0.7324  lr=9.80e-04
[val @ epoch 5] Top-1=19.28  Top-3=49.40  Top-5=67.47
[best] Saved checkpoint: /content/drive/MyDrive/BRACS/Models/train/lite_separable_best.pt
[epoch 6 iter 20/37] loss=0.7254  lr=9.67e-04
[val @ epoch 6] Top-1=7.23  Top-3=42.17  Top-5=66.27
[epoch 7 iter 20/37] loss=0.7411  lr=9.52e-04
[val @ epoch 7] Top-1=13.25  Top-3=37.35  Top-5=62.65
[epoch 8 iter 20/37] loss=0.7188  lr=9.33e-04
[val @