In [None]:
from pathlib import Path

import h5py
import pandas as pd
from openslide import OpenSlide
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch
import cv2
from torch.nn.functional import normalize
from PIL import Image
from histopreprocessing.features.foundation_models import load_model

from histopatseg.fewshot.protonet import ProtoNet, prototype_topk_vote
from histopatseg.fewshot.extract_patch_features import extract_patch_features_from_dataloader
from histopatseg.data.compute_embeddings_tcga_ut import load_hdf5


In [None]:
# Load a saved ProtoNet from a checkpoint path given relative to the notebook's working directory.
# The loaded object is used below for its label_map, prototype_embeddings and mean attributes.
protonet = ProtoNet.load("../models/protonet/tcga_ut_nsclc_subtyping_with_normal_20x_nwsi_32.pt")

In [None]:
# Label mapping carried by the loaded ProtoNet; its keys are used later as heatmap panel titles
# and its length as the number of classes.
label_map = protonet.label_map
print(label_map)

In [None]:
# Parent of the current working directory (presumably the project root when the notebook is
# run from a sub-folder — confirm).
# NOTE(review): project_dir is not used by the cells visible here; the paths below are written
# relative to the working directory ("../...") instead.
project_dir = Path(".").resolve().parent
print(f"Project Directory: {project_dir}")

In [None]:
# Identifier of the slide to analyse; used as a filename prefix when searching for the
# matching .h5 feature file and .svs slide file, and in the figure title.
wsi_id = "TCGA-18-3411-01Z-00-DX1"

In [None]:
# Locate the single HDF5 feature file whose name starts with the slide id.
# The search directory is kept in its own variable so the error messages can report it:
# the glob result is a list, which has no `.parent`.
hdf5_dir = Path("../data/processed/mahmoodlab/UNI2-h_features/TCGA/")
hdf5_match = list(hdf5_dir.glob(f"{wsi_id}*.h5"))
if len(hdf5_match) == 0:
    raise FileNotFoundError(f"No HDF5 file found for {wsi_id} in {hdf5_dir}")
if len(hdf5_match) > 1:
    raise FileExistsError(f"Multiple HDF5 files found for {wsi_id} in {hdf5_dir}")
hdf5_path = hdf5_match[0]

In [None]:
# Read the feature file with the project helper load_hdf5.
# From its use below, the return value is indexable by "datasets", which in turn holds
# at least the "features" and "coords" entries.
result = load_hdf5(hdf5_path)

In [None]:
# Locate the single .svs slide whose name starts with the slide id (recursive search).
# The search root is kept in its own variable so the error messages can report it:
# the rglob result is a list, which has no `.parent`.
wsi_dir = Path("/mnt/nas7/data/TCGA_Lung_svs")
wsi_path_match = list(wsi_dir.rglob(f"{wsi_id}*.svs"))
if len(wsi_path_match) == 0:
    raise FileNotFoundError(f"No WSI file found for {wsi_id} in {wsi_dir}")
if len(wsi_path_match) > 1:
    raise FileExistsError(f"Multiple WSI files found for {wsi_id} in {wsi_dir}")
wsi_path = wsi_path_match[0]

In [None]:
# Open the whole-slide image; the handle is used below for its properties, pyramid level
# dimensions/downsamples and for rendering a thumbnail.
wsi = OpenSlide(wsi_path)

In [None]:
# Show the slide's microns-per-pixel metadata along x and y; the string 'nan' is printed
# when the property is absent from the slide.
print(f"mpp x : {wsi.properties.get('openslide.mpp-x', 'nan')}")
print(f"mpp y : {wsi.properties.get('openslide.mpp-y', 'nan')}")

In [None]:
# Inspect which datasets the loaded feature file provides.
result["datasets"].keys()

In [None]:
# Inspect the raw shape of the stored coordinates (before squeezing size-1 axes below).
result["datasets"]["coords"][:].shape

In [None]:
# Pull the per-tile feature vectors and the tile coordinates out of the loaded file.
# np.squeeze removes every size-1 axis from the stored arrays.
# NOTE(review): squeeze would also collapse the tile axis for a slide with exactly one tile,
# which would break the 2-D indexing (coords[:, 0]) done downstream — confirm this cannot occur.
# Alternative dataset names, left commented out (presumably for files with another schema):
# embeddings = result["datasets"]["embeddings"]
# coordinates = result["datasets"]["coordinates"]
embeddings = np.squeeze(result["datasets"]["features"])
coordinates = np.squeeze(result["datasets"]["coords"])

In [None]:
def infer_tile_size(coords: np.ndarray):
    """Estimate the tile stride along x and y from tile origin coordinates.

    The stride along an axis is taken as the smallest positive gap between
    the distinct coordinate values found on that axis.

    Args:
        coords: 2-D array whose column 0 holds x origins and column 1 holds
            y origins.

    Returns:
        A ``(tile_width, tile_height)`` tuple of NumPy scalars.

    Raises:
        ValueError: raised by ``np.min`` when an axis has fewer than two
            distinct values, since there is then no gap to measure.
    """

    def _smallest_gap(axis_values):
        # Distinct values, ordered, then the gaps between neighbours.
        gaps = np.diff(np.sort(np.unique(axis_values)))
        return np.min(gaps[gaps > 0])

    return _smallest_gap(coords[:, 0]), _smallest_gap(coords[:, 1])

In [None]:
# Derive the tile edge length from the coordinate spacing; the rest of the notebook works
# with a single tile_size, so non-square spacing is rejected here.
tile_width, tile_height = infer_tile_size(coordinates)
if tile_width == tile_height:
    tile_size = tile_width
else:
    raise ValueError("Tile width and height are not equal. Please check the coordinates.")
print(f"Tile size: {tile_size} at level 0")

In [None]:
# Sanity check: the number of tiles in coordinates and embeddings should agree.
coordinates.shape, embeddings.shape

In [None]:
def compute_distances(embeddings, prototype_embeddings, mean_embedding):
    """Euclidean distance from every query embedding to every prototype.

    Queries are centered with ``mean_embedding`` and L2-normalized along the
    last axis before distances are taken. The prototypes are used as given.
    NOTE(review): presumably the prototypes were centered and normalized the
    same way when they were built — confirm against ProtoNet.

    Args:
        embeddings: tensor of query features, [N x D].
        prototype_embeddings: tensor of class prototypes, [C x D].
        mean_embedding: tensor subtracted from every query (broadcast over N).

    Returns:
        Tensor of pairwise L2 distances, [N x C]; smaller means closer.
    """
    centered_queries = embeddings - mean_embedding
    unit_queries = normalize(centered_queries, p=2, dim=-1)

    # Insert singleton axes so the subtraction broadcasts to [N x C x D].
    query_side = unit_queries[:, None]  # [N x 1 x D]
    proto_side = prototype_embeddings[None, :]  # [1 x C x D]

    return (query_side - proto_side).norm(p=2, dim=-1)  # [N x C]

In [None]:
# Distance of every tile embedding to every prototype, converted to a NumPy array
# (one row per tile, one column per prototype).
# NOTE(review): .numpy() requires a CPU tensor that does not track gradients — confirm the
# ProtoNet attributes satisfy this.
distances = compute_distances(torch.tensor(embeddings, dtype=torch.float32), protonet.prototype_embeddings, protonet.mean).numpy()

In [None]:
# Run the project's prototype_topk_vote helper on all tile embeddings with topk=5; the
# return value is not stored and is only shown as the cell output.
# NOTE(review): presumably a slide-level prediction from the closest tiles — confirm in protonet.py.
prototype_topk_vote(protonet, torch.tensor(embeddings, dtype=torch.float32), topk=5)

In [None]:
# Inspect the distance matrix shape (tiles x prototypes).
distances.shape

In [None]:
# Dimensions of the slide at pyramid level 0.
wsi.level_dimensions[0]

In [None]:
# Ratio between the first dimension at level 0 and at the last pyramid level, i.e. how much
# smaller the last level is than level 0.
wsi.level_dimensions[0][0] / wsi.level_dimensions[-1][0] 

In [None]:
# Number of classes, taken from the label map; used to lay out one heatmap panel per class.
num_classes = len(label_map)

In [None]:
def compute_heatmap_optimized(wsi, coordinates, scores, tile_size=224, tile_level=0, rescale=False):
    """Rasterize per-tile scores into per-class heatmaps sized like the slide thumbnail.

    Each tile is mapped to one cell of a coarse grid (one cell per tile stride),
    the grid is filled with that tile's scores, and the grid is then bilinearly
    resized to the dimensions of the slide's last pyramid level.

    Args:
        wsi: slide object exposing ``level_downsamples``, ``level_dimensions``
            (``(width, height)`` tuples) and ``get_thumbnail``.
        coordinates: iterable/array of ``(x, y)`` tile origins in level-0 pixels.
        scores: array of shape ``(num_tiles, num_classes)``.
        tile_size: tile edge length in pixels at ``tile_level``.
        tile_level: pyramid level at which ``tile_size`` is expressed.
        rescale: if True, linearly map the scores so that the global minimum
            becomes -1 and the global maximum becomes +1.

    Returns:
        ``(heatmap_upscaled, thumbnail)`` where ``heatmap_upscaled`` has shape
        ``(thumb_height, thumb_width, num_classes)`` and ``thumbnail`` is the
        value returned by ``wsi.get_thumbnail``.

    Raises:
        IndexError: if a tile origin falls outside the slide (including
            negative coordinates, which would otherwise wrap around silently).

    Note:
        Grid cells without a tile keep the value 0, which is the mid-point of
        the [-1, 1] range when ``rescale`` is enabled.
    """
    scores = np.asarray(scores)

    if rescale:
        score_min = np.min(scores)
        score_max = np.max(scores)
        score_range = score_max - score_min
        if score_range == 0:
            # All scores identical: avoid 0/0 and map everything to the mid-point.
            scores = np.zeros_like(scores, dtype=np.float32)
        else:
            scores = (2 * scores - score_min - score_max) / score_range

    num_classes = scores.shape[1]

    # Size of one grid cell expressed in level-0 pixels.
    downsample = wsi.level_downsamples[tile_level] * tile_size

    # OpenSlide reports dimensions as (width, height).
    slide_width, slide_height = wsi.level_dimensions[0]
    # Ceil (not round) so that every tile starting inside the slide has a cell.
    grid_cols = int(np.ceil(slide_width / downsample))
    grid_rows = int(np.ceil(slide_height / downsample))
    heatmap = np.zeros((grid_rows, grid_cols, num_classes), dtype=np.float32)

    coords = np.asarray(coordinates)
    if coords.size:
        cells = np.floor(coords / downsample).astype(int)
        col_idx, row_idx = cells[:, 0], cells[:, 1]
        outside = (col_idx < 0) | (col_idx >= grid_cols) | (row_idx < 0) | (row_idx >= grid_rows)
        if outside.any():
            raise IndexError(
                f"{int(outside.sum())} tile coordinate(s) fall outside the slide "
                f"({slide_width} x {slide_height} px at level 0)"
            )
        # Explicit loop keeps the "last tile wins" rule if two tiles share a cell.
        for i, (col, row) in enumerate(zip(col_idx, row_idx)):
            heatmap[row, col, :] = scores[i, :]

    # cv2.resize takes dsize as (width, height), matching OpenSlide's ordering.
    thumbnail_size = wsi.level_dimensions[-1]
    heatmap_upscaled = cv2.resize(heatmap, thumbnail_size, interpolation=cv2.INTER_LINEAR)
    if heatmap_upscaled.ndim == 2:
        # cv2.resize drops the channel axis for single-channel input; restore it.
        heatmap_upscaled = heatmap_upscaled[:, :, None]
    thumbnail = wsi.get_thumbnail(thumbnail_size)

    return heatmap_upscaled, thumbnail

In [None]:
# Build the per-class heatmaps. Distances are negated so that a higher score means a tile is
# closer to the class prototype; rescale=True then maps the scores linearly onto [-1, 1].
heatmaps, thumbnail = compute_heatmap_optimized(wsi, coordinates, -distances, tile_size=tile_size, tile_level=0, rescale=True)

In [None]:
# Inspect the upscaled heatmap array shape; the last axis indexes the classes.
heatmaps.shape

In [None]:
# Share a single color scale across every class so the panels are directly comparable.
vmin, vmax = heatmaps.min(), heatmaps.max()

# One equally wide panel per class, followed by a narrow axis reserved for the colorbar.
fig, axes = plt.subplots(
    nrows=1,
    ncols=num_classes + 1,
    figsize=(15, 5),
    gridspec_kw={"width_ratios": [1] * num_classes + [0.05]},
)
fig.suptitle(f"Heatmaps for WSI {wsi_id}", fontsize=16)

# Panel titles come from the label map keys.
# NOTE(review): this assumes the key order matches the class axis of `heatmaps` — confirm.
titles = [f"{key}" for key in label_map.keys()]

heatmaps_list = [heatmaps[:, :, channel] for channel in range(heatmaps.shape[2])]

# Draw each class heatmap; the trailing colorbar axis is skipped.
for ax, heatmap, title in zip(axes[:-1], heatmaps_list, titles):
    im = ax.imshow(np.squeeze(heatmap), cmap="jet", vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.set_axis_off()

# All panels share vmin/vmax, so the last image is representative for the colorbar.
cbar = fig.colorbar(im, cax=axes[-1], orientation="vertical")
cbar.set_label("Heatmap Intensity")

# Render the figure.
plt.tight_layout()
plt.show()

In [None]:
# Display the slide thumbnail on a large canvas for visual comparison with the heatmaps.
plt.figure(figsize=(20, 20))
plt.imshow(thumbnail)