# 📊 Experiment: MIL output visualisation and evaluation
**Date:** 2025-04-09  
**Author:** Valentin Oreiller  
**Goal:** Test whether a multiple-instance learning (MIL) model can be used as a filter when mining LUAD (lung adenocarcinoma) tumor tiles.

---

## 1. Setup & Imports
## 2. Data Loading
## 3. Preprocessing
## 4. Experiments / Model Training
## 5. Evaluation
## 6. Observations & Next Steps

In [None]:
from pathlib import Path

import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer

from histomil.data.torch_datasets import HDF5WSIDatasetWithTileID, HDF5WSIDataset
from histomil.models.models import load_model, get_device
from histomil.models.mil_models import AttentionAggregatorPL
from histomil.visualization.heatmap import compute_attention_map

In [None]:
# Root directory searched recursively (via get_wsi_path) for the .svs whole-slide images.
wsi_dir = Path("/mnt/nas6/data/CPTAC")

In [None]:
def get_wsi_path(wsi_id, wsi_dir):
    """Return the unique ``<wsi_id>.svs`` file found anywhere under *wsi_dir*.

    Args:
        wsi_id: Slide identifier; the file searched for is ``f"{wsi_id}.svs"``.
        wsi_dir: Root directory (``str`` or ``Path``), searched recursively.

    Returns:
        Path to the single matching slide file.

    Raises:
        FileNotFoundError: If no matching file exists under *wsi_dir*.
        ValueError: If more than one matching file is found.
    """
    # Path() also accepts an existing Path, so str roots work too.
    # Sorting makes the "multiple files" message deterministic.
    wsi_paths = sorted(Path(wsi_dir).rglob(f"{wsi_id}.svs"))
    if not wsi_paths:
        # Fail with a descriptive error instead of an IndexError on the empty list.
        raise FileNotFoundError(f"No WSI file found for {wsi_id} under {wsi_dir}")
    if len(wsi_paths) > 1:
        raise ValueError(f"Multiple WSI files found for {wsi_id}: {wsi_paths}")
    return wsi_paths[0]

In [None]:
# HDF5 file with pre-computed tile embeddings; the split selection below is
# done via the ``split`` argument (schema assumed from usage -- TODO confirm).
hdf5_path = "/home/valentin/workspaces/histomil/data/processed/embeddings/superpixels_resnet50__alpha_0.5__ablation.h5"

# Loader used for visualisation: this dataset variant also yields tile IDs,
# which compute_attention_map needs further down.
# NOTE(review): despite its name, ``val_dataset`` reads split="test" -- the same
# split as ``test_dataset`` below. Presumably intentional (inspect test slides);
# confirm, or point it at the validation split if one exists.
val_dataset = HDF5WSIDatasetWithTileID(hdf5_path, split="test")
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    num_workers=12,
    # "ragged" collate presumably keeps one variable-length bag per slide.
    collate_fn=HDF5WSIDatasetWithTileID.get_collate_fn_ragged(),
)
# Same split without tile IDs; fed to trainer.test() at the end of the notebook.
test_dataset = HDF5WSIDataset(hdf5_path, split="test")
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    num_workers=12,
    collate_fn=HDF5WSIDataset.get_collate_fn_ragged(),
)

In [None]:
# NOTE(review): not used anywhere in this notebook (load_model is imported but
# never called); embeddings are read pre-computed from hdf5_path instead.
feature_extractor_weights = "/mnt/nas7/data/Personal/Darya/saved_models/superpixels_resnet50__alpha_0.5__ablation_99.pth"
# Lightning checkpoint of the attention-based MIL aggregator.
mil_weights = "/home/valentin/workspaces/histomil/models/mil/superpixels_org_alpha0.5_tutobene.ckpt"
device = get_device(gpu_id=0)
# The model is moved to ``device`` explicitly in a later cell.
# NOTE(review): if the checkpoint was saved on GPU and is loaded on a CPU-only
# host, load_from_checkpoint may need map_location -- confirm for this setup.
mil_aggregator = AttentionAggregatorPL.load_from_checkpoint(mil_weights)

In [None]:
# Inspect only the first batch; the 4-tuple order comes from the ragged collate
# function of HDF5WSIDatasetWithTileID.
wsi_ids, embeddings, labels, tile_ids = next(iter(val_loader))

In [None]:
# Display the labels of the slides in the first batch.
labels

In [None]:
# Move the aggregator to the selected device and switch it to eval mode.
# NOTE: eval() does not disable autograd; gradients are only skipped if
# predict_one_embedding uses torch.no_grad() internally.
mil_aggregator.to(device)
mil_aggregator.eval()   

In [None]:
def plot_attention_map(attention_map, thumbnail, *, alpha=0.5, cmap="jet"):
    """Show *thumbnail* with a min-max normalised attention heatmap overlaid.

    Args:
        attention_map: 2-D array of attention values. NaN entries, if any, are
            ignored when normalising and are drawn transparent.
        thumbnail: WSI thumbnail as an array; 2-D arrays are shown in grayscale.
        alpha: Opacity of the heatmap overlay (0 = invisible, 1 = opaque).
        cmap: Matplotlib colormap used for the heatmap.
    """
    attention = np.asarray(attention_map, dtype=float)

    # NaN-aware extrema: if compute_attention_map leaves any cells as NaN, plain
    # min()/max() would return NaN and blank the entire overlay.
    vmin, vmax = np.nanmin(attention), np.nanmax(attention)
    span = vmax - vmin
    if span > 0:
        attention_norm = (attention - vmin) / span
    else:
        # Constant map (e.g. a single tile): avoid 0/0, which gives all-NaN.
        attention_norm = np.where(np.isnan(attention), np.nan, 0.0)

    plt.figure(figsize=(10, 10))

    # Show WSI thumbnail
    plt.imshow(thumbnail, cmap="gray" if thumbnail.ndim == 2 else None)

    # Stretch the heatmap over the thumbnail's pixel extent so both layers stay
    # aligned when their resolutions differ; this equals imshow's default
    # extent when the shapes already match.
    height, width = thumbnail.shape[:2]
    plt.imshow(
        attention_norm,
        cmap=cmap,
        alpha=alpha,
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
    )

    plt.axis("off")
    plt.title("WSI Thumbnail with Attention Overlay")
    plt.tight_layout()
    plt.show()

In [None]:
# Index of the slide to inspect *within* the first batch (a sample index, not a
# batch number despite the printed label); must be smaller than the batch size.
batch_idx = 5
print(f"Batch {batch_idx} with labels {labels[batch_idx]}")

In [None]:
# Embedding bag of the selected slide, moved to the aggregator's device.
embedding = embeddings[batch_idx].to(device)
# Returns (prediction, probability, per-tile attention scores) -- meaning taken
# from the variable names; shapes not visible here (TODO confirm).
pred, proba, attention_scores = mil_aggregator.predict_one_embedding(embedding)

In [None]:
# Inspect the prediction returned by predict_one_embedding.
pred

In [None]:
# Inspect the probability returned by predict_one_embedding.
proba

In [None]:
# detach() first: if predict_one_embedding does not run under torch.no_grad(),
# the scores still require grad and .numpy() raises a RuntimeError. It is a
# no-op for tensors that do not require grad.
attention_scores = attention_scores.detach().cpu().numpy()

In [None]:
# Locate the selected slide on disk and build an attention map plus a thumbnail
# to draw it on (compute_attention_map presumably places each tile's score at
# its tile_id position -- TODO confirm).
# NOTE(review): tile_size=224 and tile_mpp=1.0 are hard-coded; presumably they
# must match the tiling used to extract the embeddings in hdf5_path.
wsi_id = wsi_ids[batch_idx]
wsi_path = get_wsi_path(wsi_id, wsi_dir)
attention_map, thumbnail = compute_attention_map(
    attention_scores,
    tile_ids[batch_idx],
    tile_size=224,
    tile_mpp=1.0,
    wsi_path=wsi_path,
)

In [None]:
# np.array() converts the thumbnail (type not visible here; TODO confirm) into an
# array so plot_attention_map can read its .ndim.
plot_attention_map(attention_map, np.array(thumbnail))

In [None]:
# Test on a single device. With several GPUs, Lightning's defaults would pick a
# distributed strategy whose DistributedSampler can replicate samples to even out
# batches, biasing the test metrics; Lightning recommends devices=1 for testing.
trainer = Trainer(devices=1)

In [None]:
# Run Lightning's test loop for the MIL aggregator over the test split.
trainer.test(model=mil_aggregator, dataloaders=test_loader)