In [None]:
import re
from pathlib import Path

import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
import torch
import torch.nn.functional as F
from openslide import OpenSlide
from IPython.display import display
from tqdm import tqdm

from histomining.models.foundation_models import load_model
from histomining.models.linear_probing import LinearProbingFromEmbeddings
from histomining.data.torch_datasets import TileDataset
from histomining.utils import get_device 

In [None]:
# Torch device used by both the foundation model and the linear probe (GPU index 0).
# Presumably get_device falls back to CPU when that GPU is unavailable — TODO confirm.
device = get_device(gpu_id=0)

In [None]:
# Data locations. Both are Path objects so every helper receives the same type:
# - wsi_dir is searched recursively for "<wsi_id>.svs" (get_wsi_path);
# - tiles_root_dir is expected to contain "<subdir>/<wsi_id>/tiles/*.png" (get_tiles_dir).
# NOTE(review): absolute, machine-specific paths — adjust when running elsewhere.
wsi_dir = Path("/mnt/nas6/data/CPTAC")
tiles_root_dir = Path("/mnt/nas7/data/Personal/Valentin/histopath/tiles_20x")

In [None]:
def get_wsi_path(wsi_id, wsi_dir):
    """Locate the unique ``<wsi_id>.svs`` file anywhere below ``wsi_dir``.

    Args:
        wsi_id: Slide identifier, used as the file stem (e.g. ``"C3L-00001-21"``).
        wsi_dir: Root directory (``str`` or ``pathlib.Path``), searched recursively.

    Returns:
        ``Path`` of the matching ``.svs`` file.

    Raises:
        FileNotFoundError: If no matching file exists below ``wsi_dir``.
        ValueError: If more than one matching file is found.
    """
    wsi_paths = list(Path(wsi_dir).rglob(f"{wsi_id}.svs"))
    # Explicit error instead of an opaque IndexError from wsi_paths[0].
    if not wsi_paths:
        raise FileNotFoundError(f"No WSI file found for {wsi_id} under {wsi_dir}")
    if len(wsi_paths) > 1:
        raise ValueError(f"Multiple WSI files found for {wsi_id}: {wsi_paths}")
    return wsi_paths[0]

def get_tiles_dir(wsi_id, tiles_root_dir):
    """Return the ``tiles`` sub-directory for ``wsi_id`` below ``tiles_root_dir``.

    The layout searched is ``<tiles_root_dir>/<any>/<wsi_id>/tiles``: exactly one
    entry named ``wsi_id`` must exist one level below the root.

    Args:
        wsi_id: Slide identifier (directory name).
        tiles_root_dir: Root directory, as ``str`` or ``pathlib.Path``.

    Returns:
        ``Path`` to ``<match>/tiles``; its existence is not checked here.

    Raises:
        ValueError: If zero or several matching entries are found.
    """
    matches = list(Path(tiles_root_dir).glob(f"*/{wsi_id}"))
    # Distinguish "nothing found" from "ambiguous" so the message is accurate.
    if not matches:
        raise ValueError(f"No tile directory found for {wsi_id} under {tiles_root_dir}")
    if len(matches) > 1:
        raise ValueError(f"Multiple tile directories found for {wsi_id}: {matches}")
    return matches[0] / "tiles"

In [None]:
# Load the "UNI2" foundation model on `device`, together with the `preprocess`
# callable that is handed to TileDataset below.
# NOTE(review): `embedding_dim` and `autocast_dtype` are never used in this notebook;
# inference runs without torch.autocast — confirm full precision is intended.
model, preprocess, embedding_dim, autocast_dtype = load_model("UNI2",device )

In [None]:
# Classifier applied to the foundation-model embeddings; per the checkpoint filename it
# was trained on UNI2 embeddings. Weights are mapped onto `device` while loading.
# NOTE(review): absolute, user-specific checkpoint path.
linear_probing = LinearProbingFromEmbeddings.load_from_checkpoint("/home/valentin/workspaces/histomining/models/linear_probing_from_embeddings/linear_probing_weights_uni2_mag_key_0.ckpt", map_location=device)

In [None]:
# Slide to analyse (the .svs file stem under wsi_dir). Swap in the commented id to
# analyse another slide.
wsi_id = "C3L-00001-21"
# wsi_id = "C3L-00893-22"

In [None]:
# Open the slide and display a low-resolution overview (at most 800x800 px;
# OpenSlide keeps the aspect ratio). `thumbnail` is rebound later by the
# compute_attention_map cell.
# NOTE(review): this OpenSlide handle is never closed; harmless interactively, but
# consider `wsi.close()` once the overview has been displayed.
wsi_path = get_wsi_path(wsi_id, wsi_dir)
wsi = OpenSlide(wsi_path)
thumbnail = wsi.get_thumbnail((800,800))
display(thumbnail)

In [None]:
# Tile images for this slide. Path.glob yields entries in arbitrary filesystem order,
# so sort them: the tile order determines the row order of `embeddings`/`preds`,
# which compute_attention_map pairs with `tile_ids` by index. Sorting makes every
# re-run produce identical arrays.
tile_paths = sorted(get_tiles_dir(wsi_id, tiles_root_dir).glob("*.png"))
tile_ids = [tile_path.stem for tile_path in tile_paths]

In [None]:
# Sanity check: number of tiles that will be embedded for this slide.
print(f"Number of tiles: {len(tile_ids)}")

In [None]:
# Dataset over the tile PNGs, transformed with the `preprocess` returned by load_model.
# Presumably yields one image tensor per tile, in `tile_paths` order — TODO confirm.
dataset = TileDataset(tile_paths=tile_paths, preprocess=preprocess)

In [None]:
# Batched loader for inference. shuffle=False is required: the predictions are
# matched back to `tile_ids` by position when the attention map is built.
dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=False,
    num_workers=8,
    pin_memory=True,  # page-locked host memory speeds up host-to-GPU copies
)

In [None]:
# Embed every tile with the foundation model, score the embeddings with the linear
# probe, and collect per-batch results on the CPU. After concatenation, each array
# has one row per tile, in dataloader (i.e. `tile_paths`) order.
# torch.inference_mode() only disables autograd tracking; it does NOT switch layers
# such as dropout or batch-norm to inference behaviour, so set eval mode explicitly.
if isinstance(model, torch.nn.Module):
    model.eval()
linear_probing.eval()

embeddings = []
preds = []
logits = []
with torch.inference_mode():
    for batch in tqdm(dataloader, desc="Computing embeddings"):
        batch_embeddings = model(batch.to(device))
        embeddings.append(batch_embeddings.cpu().numpy())
        batch_logits = linear_probing(batch_embeddings)
        logits.append(batch_logits.cpu().numpy())
        # Per-class probabilities; column 1 is used below as the per-tile score.
        preds.append(F.softmax(batch_logits, dim=1).cpu().numpy())

embeddings = np.concatenate(embeddings, axis=0)
preds = np.concatenate(preds, axis=0)
logits = np.concatenate(logits, axis=0)

In [None]:
# Highest class-1 probability over all tiles (quick check of the score range).
preds[:,1].max()

In [None]:
def plot_attention_map(
    attention_map,
    thumbnail,
    alpha=0.5,
    cmap="jet",
    title="WSI Thumbnail with Attention Overlay",
):
    """Overlay a score map on a slide thumbnail and show the figure.

    Args:
        attention_map: 2-D array of scores. It is min-max normalized to [0, 1]
            before plotting; a constant map is drawn as all zeros instead of NaN.
        thumbnail: Thumbnail as a numpy array; ``ndim == 2`` is drawn in gray,
            otherwise with its own colors.
        alpha: Opacity of the heatmap layer (0 = transparent, 1 = opaque).
        cmap: Matplotlib colormap for the heatmap.
        title: Figure title.
    """
    attention_map = np.asarray(attention_map)
    value_min = attention_map.min()
    value_range = attention_map.max() - value_min
    # Guard against 0/0 (e.g. an all-zero map), which would give an all-NaN overlay.
    if value_range > 0:
        attention_norm = (attention_map - value_min) / value_range
    else:
        attention_norm = np.zeros_like(attention_map, dtype=float)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(thumbnail, cmap="gray" if thumbnail.ndim == 2 else None)
    # Stretch the heatmap over the thumbnail's pixel extent (imshow's default extent
    # for the thumbnail) so both layers stay aligned even if their sizes differ.
    height, width = thumbnail.shape[:2]
    ax.imshow(
        attention_norm,
        cmap=cmap,
        alpha=alpha,
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
    )
    ax.set_axis_off()
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

In [None]:

def compute_attention_map(
    attention_scores: np.ndarray,
    tile_ids: list,
    tile_size: int = 224,
    tile_mpp: float = 1.0,
    wsi_path: str = None,
    output_mpp: float = 2.0,
    return_thumbnail: bool = True,
) -> "np.ndarray | tuple":
    """Paint per-tile scores onto a low-resolution raster covering the whole slide.

    Each tile's score fills a square at the tile's position in a map whose pixel
    size is ``output_mpp`` microns. Pixels not covered by any tile stay 0; where
    tiles overlap, the later tile in ``tile_ids`` wins.

    Args:
        attention_scores: One score per tile, aligned by index with ``tile_ids``.
            Singleton dimensions are squeezed (e.g. shape (N, 1) -> (N,)).
        tile_ids: Tile identifiers ending in ``"__x<X>_y<Y>"``. X/Y are scaled by
            ``mpp / output_mpp``, i.e. treated as level-0 pixel coordinates.
        tile_size: Tile edge length in pixels at ``tile_mpp``.
        tile_mpp: Microns per pixel at which the tiles were extracted.
        wsi_path: Slide file opened with OpenSlide (required).
        output_mpp: Microns per pixel of the returned map.
        return_thumbnail: Also return a slide thumbnail sized to fit the map.

    Returns:
        The float32 map of shape (height, width), or ``(map, thumbnail)`` with a
        PIL image when ``return_thumbnail`` is true.

    Raises:
        ValueError: If ``wsi_path`` is missing, the slide lacks mpp metadata or has
            anisotropic pixels, or the number of scores differs from ``len(tile_ids)``.
    """
    if wsi_path is None:
        raise ValueError("wsi_path is required.")
    # atleast_1d keeps a single-tile input indexable after squeezing to 0-d.
    attention_scores = np.atleast_1d(np.squeeze(attention_scores))
    if len(attention_scores) != len(tile_ids):
        raise ValueError(
            f"Got {len(attention_scores)} attention scores for {len(tile_ids)} tiles."
        )

    # The context manager closes the slide handle, including on errors.
    with OpenSlide(wsi_path) as wsi:
        mpp_x, mpp_y = (
            wsi.properties.get("openslide.mpp-x"),
            wsi.properties.get("openslide.mpp-y"),
        )
        if mpp_x is None or mpp_y is None:
            raise ValueError("Microns per pixel not found in WSI properties.")
        # Properties are strings; compare numerically so "0.25" == "0.2500".
        mpp_x, mpp_y = float(mpp_x), float(mpp_y)
        if mpp_x != mpp_y:
            raise ValueError("Microns per pixel values are not equal.")

        # Scale factor from level-0 pixels to output-map pixels.
        resizing_factor = mpp_x / output_mpp
        wsi_width, wsi_height = wsi.level_dimensions[0]
        width = int(wsi_width * resizing_factor)
        height = int(wsi_height * resizing_factor)

        attention_map = np.zeros((height, width), dtype=np.float32)
        # Tile edge in output pixels (algebraically tile_size * tile_mpp / output_mpp).
        resized_tile_size = int(tile_size * resizing_factor * tile_mpp / mpp_x)
        for score, tile_id in zip(attention_scores, tile_ids):
            x, y = get_position_from_tile_id(tile_id)
            resized_x = int(x * resizing_factor)
            resized_y = int(y * resizing_factor)
            # numpy clips slices that run past the map edge, so border tiles are safe.
            attention_map[
                resized_y : resized_y + resized_tile_size,
                resized_x : resized_x + resized_tile_size,
            ] = score

        if return_thumbnail:
            thumbnail = wsi.get_thumbnail((width, height))
            return attention_map, thumbnail
    return attention_map


def get_position_from_tile_id(tile_id):
    """Parse the tile position encoded as a trailing ``"__x<X>_y<Y>"`` in a tile id.

    Args:
        tile_id: Tile identifier such as ``"C3L-00001-21__x1024_y2048"``;
            ``bytes`` are decoded as UTF-8 first.

    Returns:
        Tuple of ints ``(x, y)``.

    Raises:
        ValueError: If ``tile_id`` does not end with ``__x<int>_y<int>``.
    """
    if isinstance(tile_id, bytes):
        tile_id = tile_id.decode("utf-8")
    # Anchored at the end so a slide id that itself contains "__x" cannot be
    # mistaken for the coordinate part.
    match = re.search(r"__x(-?\d+)_y(-?\d+)$", tile_id)
    if match is None:
        raise ValueError(f"Cannot parse tile position from tile id {tile_id!r}")
    return int(match.group(1)), int(match.group(2))

In [None]:
# Paint each tile's class-1 probability onto a 2.0 microns-per-pixel map of the slide.
# tile_size/tile_mpp describe the tiles as extracted (224 px at 0.5 microns per pixel);
# presumably these must match the tiling run behind tiles_root_dir — TODO confirm.
# This rebinds `thumbnail` to a thumbnail sized like the attention map.
attention_map, thumbnail = compute_attention_map(preds[:,1], tile_ids, tile_size=224, tile_mpp=0.5, wsi_path=wsi_path, output_mpp=2.0)

In [None]:
# Keep only pixels whose score exceeds 0.5; all others are set to 0.
# NOTE(review): the final plot below uses the unthresholded `attention_map`.
attention_map_thresholded = np.where(attention_map > 0.5, attention_map, 0)

In [None]:
# Inspect the thresholded map (numpy abbreviates large arrays in its repr).
attention_map_thresholded

In [None]:
# Largest value in the attention map (pixels not covered by any tile are 0).
attention_map.max()

In [None]:
# Final figure: min-max normalized attention map over the slide thumbnail.
# The PIL thumbnail is converted to a numpy array because plot_attention_map reads `.ndim`.
plot_attention_map(attention_map, np.array(thumbnail))