# 📊 Experiment: CLAM wrapper output visualisation and evaluation
**Date:** 2025-04-09  
**Author:** Valentin Oreiller  
**Goal:** Evaluate whether a MIL model can be used as a filter for mining LUAD tumor tiles.

---

## 1. Setup & Imports
## 2. Data Loading
## 3. Preprocessing
## 4. Experiments / Model Training
## 5. Evaluation
## 6. Observations & Next Steps

In [None]:
from pathlib import Path

import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
import torch
import torch.nn.functional as F

from histomil.data.torch_datasets import HDF5WSIDatasetCLAM, HDF5WSIDatasetCLAMWithTileID
from histomil.models.models import load_model, get_device
from histomil.models.clam_wrapper import PL_CLAM_SB
from histomil.visualization.heatmap import compute_attention_map

In [None]:
# Root directory that is searched recursively for the ``.svs`` slide files
# (passed to get_wsi_path further down).
wsi_dir = Path("/mnt/nas6/data/CPTAC")

In [None]:
def get_wsi_path(wsi_id, wsi_dir):
    """Locate the single ``.svs`` file of a slide below ``wsi_dir``.

    The directory tree is searched recursively for a file named
    ``<wsi_id>.svs``.

    Args:
        wsi_id: Slide identifier, i.e. the file name without the ``.svs``
            suffix.
        wsi_dir: Root directory to search (``str`` or ``pathlib.Path``).

    Returns:
        The path of the unique matching file.

    Raises:
        FileNotFoundError: If no matching file exists under ``wsi_dir``.
        ValueError: If more than one matching file is found.
    """
    wsi_paths = list(Path(wsi_dir).rglob(wsi_id + ".svs"))
    if not wsi_paths:
        # Fail with an explicit message instead of an opaque IndexError on [0].
        raise FileNotFoundError(f"No WSI file found for {wsi_id} under {wsi_dir}")
    if len(wsi_paths) > 1:
        raise ValueError(f"Multiple WSI files found for {wsi_id}: {wsi_paths}")
    return wsi_paths[0]

In [None]:
# HDF5 file both datasets below read from.
# NOTE(review): the path suggests pre-computed MoCo tile embeddings — confirm.
hdf5_path = "/home/valentin/workspaces/histomil/data/processed/embeddings/superpixels_moco_org.h5"

# Loader whose batches also carry tile IDs; it feeds the attention-map cells
# further down.
# NOTE(review): named "val_*" but it reads split="test" — confirm this is
# intended.
val_dataset = HDF5WSIDatasetCLAMWithTileID(hdf5_path, split="test")
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    num_workers=12,
    # Collate function supplied by the dataset class itself.
    collate_fn=HDF5WSIDatasetCLAMWithTileID.get_collate_fn_ragged(),
)
# Same file and split without tile IDs; this loader is handed to the Lightning
# Trainer at the end of the notebook.
test_dataset = HDF5WSIDatasetCLAM(hdf5_path, split="test")
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=12,
    collate_fn=HDF5WSIDatasetCLAM.get_collate_fn_ragged(),
)

In [None]:
# Checkpoint of the CLAM wrapper to evaluate. The commented-out lines are the
# alternative previously used with a different checkpoint / aggregator class.
# mil_weights = "/home/valentin/workspaces/histomil/models/mil/UNI2_mil_v1.ckpt"
mil_weights = "/home/valentin/workspaces/histomil/models/mil/test/clam/epoch=89-step=148860.ckpt"
# Device for manual inference; index 1 matches Trainer(devices=[1]) below.
device = get_device(gpu_id=1)
# mil_aggregator = AttentionAggregatorPL.load_from_checkpoint(mil_weights)
# map_location loads the checkpoint tensors straight onto the chosen device.
mil_aggregator = PL_CLAM_SB.load_from_checkpoint(mil_weights, map_location=device)

In [None]:
# Move the model to the inference device and switch it to evaluation mode
# before the manual forward pass below.
mil_aggregator.to(device)
mil_aggregator.eval()

In [None]:
def plot_attention_map(attention_map, thumbnail):
    """Show a WSI thumbnail with a min-max normalised attention overlay.

    Args:
        attention_map: Array of attention values drawn as a semi-transparent
            "jet" heatmap on top of the thumbnail.
            NOTE(review): presumably has the same spatial size as
            ``thumbnail`` — confirm against compute_attention_map.
        thumbnail: Image array of the slide; a 2-D array is rendered in
            grayscale, anything else with matplotlib's default handling.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Normalize the attention map between 0 and 1
    lowest = attention_map.min()
    value_range = attention_map.max() - lowest
    if value_range == 0:
        # A constant map would give 0/0 = NaN everywhere; draw it as all-zero.
        attention_norm = np.zeros_like(attention_map, dtype=float)
    else:
        attention_norm = (attention_map - lowest) / value_range

    # Plotting
    plt.figure(figsize=(10, 10))

    # Show WSI thumbnail
    plt.imshow(thumbnail, cmap="gray" if thumbnail.ndim == 2 else None)

    # Overlay attention heatmap with transparency
    plt.imshow(attention_norm, cmap="jet", alpha=0.5)  # alpha adjusts transparency

    plt.axis("off")
    plt.title("WSI Thumbnail with Attention Overlay")
    plt.tight_layout()
    plt.show()

In [None]:
def to_percentiles(scores):
    """Convert raw scores to percentile ranks in the range (0, 100].

    Ties receive the average of the ranks they span.

    Args:
        scores: Array-like of scores. Multi-dimensional input is flattened by
            ``rankdata`` before ranking.

    Returns:
        1-D float array with one percentile per input element, in flattened
        input order.
    """
    from scipy.stats import rankdata

    ranks = rankdata(scores, method="average")
    # Divide by the number of ranked elements, not len(scores): len() only
    # counts the first dimension, which is wrong for e.g. a (1, N) array.
    return ranks / ranks.size * 100

In [None]:
# Materialise every batch of val_loader in a list so that individual batches
# can be picked by index below.
# NOTE(review): this iterates the whole loader and keeps all batches in memory.
val_loader_list = list(val_loader)

In [None]:
# Pick one batch to inspect and unpack its three parts.
# NOTE(review): index 12 is hard-coded — presumably an arbitrary example.
embeddings, labels, tile_ids = val_loader_list[12]

In [None]:
# Position inside the selected batch whose label is printed.
# NOTE(review): the message says "Batch 0", but the batch itself was taken from
# val_loader_list[12]; batch_idx only indexes into its labels.
batch_idx = 0
print(f"Batch {batch_idx} with labels {labels[batch_idx]}")

In [None]:
# Put the inputs on the same device as the model.
embeddings = embeddings.to(device)

In [None]:
# Forward pass without gradient tracking (inference only).
with torch.no_grad():
    output = mil_aggregator(embeddings)

In [None]:
# Bare expression: displays the raw model output as the notebook cell result.
output

In [None]:
# Turn element 3 of the model output into percentile ranks for display.
# NOTE(review): output[3] is presumably the raw per-tile attention — confirm
# against PL_CLAM_SB.forward.
# Alternative kept for reference: softmax-normalised scores instead of ranks.
# attention_scores = F.softmax(output[3], dim=1).cpu().numpy()
attention_scores = to_percentiles(output[3].cpu().numpy())

In [None]:
# Tile IDs are bytes; the slide ID is the part of the first tile ID before the
# first "__".
wsi_id = tile_ids[0].decode("utf-8").split("__")[0]

In [None]:
# Find the slide file on disk, then build the attention map and a thumbnail of
# the slide from the per-tile scores.
wsi_path = get_wsi_path(wsi_id, wsi_dir)
attention_map, thumbnail = compute_attention_map(
    attention_scores,
    tile_ids,
    # NOTE(review): tile_size / tile_mpp presumably have to match the tiling
    # used when the embeddings were extracted — confirm.
    tile_size=224,
    tile_mpp=1.0,
    wsi_path=wsi_path,
)

In [None]:
# Convert the thumbnail to a NumPy array (plot_attention_map reads its .ndim)
# and draw the overlay.
plot_attention_map(attention_map, np.array(thumbnail))

In [None]:
# Lightning trainer restricted to device index 1, the same index passed to
# get_device above.
trainer = Trainer(devices=[1])

In [None]:
# Run the model's test loop over the test split.
trainer.test(mil_aggregator, test_loader)

In [None]:

# Run the model's validation loop, here fed with the *test* loader.
# NOTE(review): confirm that validating on the test split is intended.
trainer.validate(mil_aggregator, test_loader)