In [None]:
import sys

sys.path.append('..')

import math
import os
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from tqdm import tqdm

from datasets.anorak import ANORAKFewShot
from histo_utils.macenko_torch import (
    normalize_and_unmix,
)
from models.histo_encoder import Uni2Encoder
from models.histo_protonet_decoder import masks_to_token_hard_from_semantic
from training.tiler import GridPadTiler, Tiler

In [None]:
# Dataset root. Set the ANORAK_ROOT environment variable to point elsewhere
# instead of editing this machine-specific absolute path.
root_dir = os.environ.get(
    "ANORAK_ROOT",
    "/home/valentin/workspaces/benchmark-vfm-ss/data/ANORAK_10x",
)

In [None]:
@torch.compiler.disable
def to_per_pixel_targets_semantic(
    targets: list[dict],
    ignore_idx: int,
) -> list[torch.Tensor]:
    """Convert list of instance masks into a single-channel semantic map per image."""
    out: list[torch.Tensor] = []
    for t in targets:
        h, w = t["masks"].shape[-2:]
        y = torch.full((h, w),
                       ignore_idx,
                       dtype=t["labels"].dtype,
                       device=t["labels"].device)
        for i, m in enumerate(t["masks"]):
            y[m] = t["labels"][i]
        out.append(y)  # [H,W] long
    return out

In [None]:
def _device_of(mod: torch.nn.Module) -> torch.device:
    for p in mod.parameters():
        return p.device
    for b in mod.buffers():
        return b.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _reservoir_add(
    buf: List[torch.Tensor],
    seen: int,
    rows: torch.Tensor,
    capacity: int,
    rng: torch.Generator,
) -> int:
    """Stream ``rows`` (one candidate per row) into reservoir ``buf`` (Algorithm R).

    Args:
        buf: the reservoir, mutated in place; it never grows past ``capacity``.
        seen: number of candidates already offered to this reservoir.
        rows: new candidates, offered in row order.
        capacity: maximum reservoir size.
        rng: generator used for the replacement draws.

    Returns:
        The updated ``seen`` count (``seen + rows.shape[0]``).
    """
    # While there is room every candidate is kept and no random draw happens,
    # so the fill phase is one slice (same RNG stream as a per-row loop).
    # Kept rows are cloned so they own their storage instead of being views
    # that pin the whole per-batch tensor in memory.
    space = max(capacity - len(buf), 0)
    n_fill = min(space, rows.shape[0])
    if n_fill > 0:
        buf.extend(rows[:n_fill].clone().unbind(0))
    seen += n_fill
    # Replacement phase: the t-th candidate overall draws j ~ U{0, ..., t-1}
    # (randint's upper bound is exclusive) and overwrites slot j if j < capacity.
    for row in rows[n_fill:]:
        seen += 1
        j = int(torch.randint(0, seen, (1, ), generator=rng).item())
        if j < capacity:
            buf[j] = row.clone()
    return seen


@torch.no_grad()
def sample_tokens_per_class(
    encoder,  # Uni2Encoder-like module: crops [N,3,T,T] -> tokens [N,Q,D]
    dataloader,
    *,
    num_classes: int = 7,
    ignore_idx: int = 255,
    tile_size: int = 448,
    stride: int = 448,
    n_per_class: int = 2000,  # cap per class
    bg_idx: int = 0,
    include_background: bool = True,  # set False to drop class bg_idx
    purity_thresh: Optional[float] = None,  # e.g. 0.9 to keep only ≥90% pure tokens
    renorm_exclude_ignore: bool = True,
    drop_background_only: bool = True,
    progress: bool = True,
    rng: Optional[torch.Generator] = None,  # for reproducibility
) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, int], Dict[int, int]]:
    """Collect a class-balanced sample of encoder patch tokens.

    Each batch is passed through ``normalize_and_unmix``, tiled into
    ``tile_size`` crops, and encoded. Every token gets a hard label from the
    matching crop of the semantic target via
    ``masks_to_token_hard_from_semantic``. Valid tokens are streamed into one
    reservoir per class, so each class keeps a uniform random sample of at
    most ``n_per_class`` of the tokens seen for it.

    Args:
        encoder: module returning ``[N, Q, D]`` tokens; must expose
            ``grid_size`` (and ``embed_dim`` for the empty-result case).
        dataloader: yields ``(imgs, targets)`` batches.
        num_classes / ignore_idx / bg_idx: label-space definition.
        tile_size / stride: tiling geometry for images and targets.
        n_per_class: reservoir capacity per class.
        include_background: if False, tokens labelled ``bg_idx`` are dropped.
        purity_thresh / renorm_exclude_ignore / drop_background_only:
            forwarded to ``masks_to_token_hard_from_semantic``.
        progress: show a tqdm bar over the dataloader.
        rng: CPU generator for reservoir draws; seeded with 0 when omitted.

    Returns:
      X:  [M, D]   tokens on CPU (M ≤ num_classes * n_per_class), grouped by
                   ascending class id
      y:  [M]      class ids (long)
      seen_per_class: dict c-> count of valid tokens seen
      kept_per_class: dict c-> count kept (≤ n_per_class)
    """
    device = _device_of(encoder)
    encoder.eval()

    # Images are padded by edge replication; targets are padded with ignore_idx
    # so padded pixels are labelled as "ignore".
    img_tiler = GridPadTiler(tile=tile_size,
                             stride=stride,
                             weighted_blend=False,
                             pad_mode="replicate",
                             pad_value=0.0)
    tgt_tiler = GridPadTiler(tile=tile_size,
                             stride=stride,
                             weighted_blend=False,
                             pad_mode="constant",
                             pad_value=float(ignore_idx))

    # One reservoir of CPU token rows per class.
    buffers_X: Dict[int, List[torch.Tensor]] = {
        c: []
        for c in range(num_classes)
    }
    seen_per_class: Dict[int, int] = {c: 0 for c in range(num_classes)}

    if rng is None:
        rng = torch.Generator(device="cpu")
        rng.manual_seed(0)

    it = tqdm(dataloader, desc="gather tokens",
              leave=False) if progress else dataloader
    for imgs, targets in it:
        imgs, _, _, _ = normalize_and_unmix(imgs)
        # NOTE(review): permute(0, 2, 3, 1) maps [B,C,H,W] -> [B,H,W,C]
        # (the old comment claimed the result was [B,3,H,W]). Confirm which
        # layout normalize_and_unmix returns and GridPadTiler.window expects.
        imgs = imgs.permute(0, 2, 3, 1)
        crops, _, _ = img_tiler.window(imgs)  # [N,3,T,T]
        # presumably normalize_and_unmix returns 0-255 intensities — confirm.
        crops = crops.to(device) / 255.0

        # Tile the semantic targets as single-channel [1,H,W] maps.
        sem_list = [
            y.unsqueeze(0)
            for y in to_per_pixel_targets_semantic(targets, ignore_idx)
        ]
        tgt_crops, _, _ = tgt_tiler.window(sem_list)  # [N,1,T,T]

        tokens = encoder(crops)  # [N,Q,D]
        N, Q, D = tokens.shape
        X_flat = tokens.reshape(N * Q, D)  # [N*Q, D]

        # Hard label + validity flag per token on the encoder's patch grid.
        y_hard, valid = masks_to_token_hard_from_semantic(
            tgt_crops.to(device),
            num_classes=num_classes,
            grid_size=encoder.grid_size,  # (Ht,Wt)
            ignore_idx=ignore_idx,
            bg_idx=bg_idx,
            renorm_exclude_ignore=renorm_exclude_ignore,
            drop_background_only=drop_background_only,
            purity_thresh=purity_thresh,
        )  # y_hard:[N,Q], valid:[N,Q] (bool)

        y_flat = y_hard.reshape(-1)  # [N*Q]
        m_flat = valid.reshape(-1)  # [N*Q]
        if not include_background:
            m_flat = m_flat & (y_flat != bg_idx)
        if not m_flat.any():
            continue

        # Move the selected tokens to CPU right away to keep VRAM free.
        Xv = X_flat[m_flat].to("cpu")  # [K,D]
        yv = y_flat[m_flat].to("cpu")  # [K]

        # Classes are visited in ascending order (unique() sorts), which fixes
        # the order of RNG draws and keeps sampling reproducible.
        for cls in yv.unique().tolist():
            seen_per_class[cls] = _reservoir_add(buffers_X[cls],
                                                 seen_per_class[cls],
                                                 Xv[yv == cls], n_per_class,
                                                 rng)

    # Stack reservoirs in class order.
    X_out: List[torch.Tensor] = []
    y_out: List[torch.Tensor] = []
    kept_per_class: Dict[int, int] = {}
    for cls in range(num_classes):
        buf = buffers_X[cls]
        kept_per_class[cls] = len(buf)
        if buf:
            X_out.append(torch.stack(buf, dim=0))  # [m_c, D]
            y_out.append(torch.full((len(buf), ), cls,
                                    dtype=torch.long))  # [m_c]

    if not X_out:
        # No valid token at all: empty outputs sized by the encoder's width.
        return (torch.empty(0, encoder.embed_dim),
                torch.empty(0, dtype=torch.long), seen_per_class,
                kept_per_class)

    X = torch.cat(X_out, dim=0)
    y = torch.cat(y_out, dim=0)
    return X, y, seen_per_class, kept_per_class


In [None]:
# ANORAK few-shot datamodule, fold 0: one 448x448 image per batch, loaded in
# the main process (num_workers=0). Class layout: 7 classes, 255 = ignore.
dm = ANORAKFewShot(root_dir, devices=1, num_workers=0, fold=0,
                   img_size=(448, 448), batch_size=1,
                   num_classes=7, ignore_idx=255)
dm.setup("fit")
train_loader = dm.train_dataloader()

In [None]:
encoder = Uni2Encoder()
# Fall back to CPU when no GPU is present so the notebook still runs
# (slowly) instead of failing on a hardcoded "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Feature extraction only: move to the device and switch to eval mode once.
encoder = encoder.to(device).eval()

In [None]:
# Balanced token sample from the training split: up to 3000 tokens per class,
# keeping only tokens whose patch is at least 90% a single class.
X, y, seen, kept = sample_tokens_per_class(
    encoder=encoder.eval(),
    dataloader=train_loader,
    num_classes=7,
    ignore_idx=255,
    tile_size=448,
    stride=448,
    n_per_class=3000,
    include_background=True,  # keep background class (bg_idx=0) in the plot
    purity_thresh=0.9,
)

print("per-class seen:", seen)
print("per-class kept:", kept)
print("X", X.shape, "y", y.shape)

In [None]:
from umap import UMAP  # sklearn.manifold.TSNE is an alternative 2-D embedding

umap = UMAP(
    n_neighbors=10,
    min_dist=0.05,
    spread=1.5,
    metric="cosine",
    random_state=0,
    # NOTE(review): recent umap-learn releases run single-threaded when
    # random_state is set (they warn that n_jobs is overridden); drop
    # random_state if the parallel speed-up matters more than reproducibility.
    n_jobs=24,
)
# Cast to float32 first: numpy has no bfloat16, so .numpy() fails if the
# encoder emitted reduced-precision tokens (a no-op for float32 tokens).
Z = umap.fit_transform(X.float().numpy())  # [M,2]


In [None]:
# Sanity check: sorted class ids present in the sampled tokens.
y.unique().numpy()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Number of class ids to colour (0..n-1). tab10 supplies at most 10 colours.
n_plot_classes = 7

# One discrete colour per class id.
cmap = ListedColormap(plt.cm.tab10.colors[:n_plot_classes])
# Bin edges centred on the integer ids: [-0.5, 0.5, ..., n_plot_classes - 0.5].
bounds = np.arange(n_plot_classes + 1) - 0.5
norm = BoundaryNorm(bounds, cmap.N, clip=True)

fig, ax = plt.subplots(figsize=(7, 7))
sc = ax.scatter(Z[:, 0], Z[:, 1], c=y.numpy(), s=3, cmap=cmap, norm=norm)
cbar = fig.colorbar(sc, ax=ax, ticks=np.arange(n_plot_classes))
cbar.set_label("Class")
ax.set(title="UMAP of spatial tokens", xlabel="UMAP-1", ylabel="UMAP-2")
plt.show()