In [None]:
from pathlib import Path

import pandas as pd
from openslide import OpenSlide
import numpy as np
import matplotlib.pyplot as plt
import torch
import cv2
from torch.nn.functional import normalize

from histopatseg.fewshot.protonet import ProtoNet, prototype_topk_vote
from histopatseg.data.compute_embeddings_tcga_ut import load_hdf5


In [None]:
protonet = ProtoNet.load("../models/protonet/cptac_enriched_uni2_20x_luad_differentiation_5_patterns_with_normal.pt")

In [None]:
print(protonet.label_map)

In [None]:
# label_map = {'Acinar adenocarcinoma': 0, 'Lepidic adenocarcinoma': 1, 'Micropapillary adenocarcinoma': 2, 'Normal': 3, 'Papillary adenocarcinoma': 4, 'Solid adenocarcinoma': 5}
label_map = protonet.label_map
print(label_map)

In [None]:
project_dir = Path(".").resolve().parent
print(f"Project Directory: {project_dir}")

In [None]:
metadata = pd.read_csv("/mnt/nas6/data/CPTAC/TCIA_CPTAC_LUAD_Pathology_Data_Table.csv").set_index("Slide_ID")

In [None]:
wsi_id = "C3N-02929-22"
tumor_hist_type = metadata.loc[wsi_id, 'Tumor_Histological_Type']

In [None]:
print(f"Specimen Type: {metadata.loc[wsi_id, 'Specimen_Type']}")
print(f"Tumor Histological Type: {tumor_hist_type}")

In [None]:
tumor_histological_counts = metadata["Tumor_Histological_Type"].value_counts()

In [None]:
result = load_hdf5(f"../data/processed/mahmoodlab/UNI2-h_features/CPTAC/CPTAC_LUAD/{wsi_id}.h5")

In [None]:
wsi_path = Path(f"/mnt/nas6/data/CPTAC/CPTAC-LUAD_v12/LUAD/{wsi_id}.svs")
print(str(wsi_path))

In [None]:
wsi = OpenSlide(f"/mnt/nas6/data/CPTAC/CPTAC-LUAD_v12/LUAD/{wsi_id}.svs")

In [None]:
print(f"mpp x : {wsi.properties.get('openslide.mpp-x', 'nan')}")
print(f"mpp y : {wsi.properties.get('openslide.mpp-y', 'nan')}")

In [None]:
embeddings = np.squeeze(result["datasets"]["features"])
coordinates = np.squeeze(result["datasets"]["coords"])

In [None]:
coordinates[1,:] - coordinates[0,:]

In [None]:
def compute_distances(embeddings, prototype_embeddings, mean_embedding):
    """
    Euclidean distance from each query embedding to every class prototype.

    Queries are centred by subtracting ``mean_embedding`` and then L2-normalised along the
    last dimension before distances are taken. ``prototype_embeddings`` is used as given;
    it is not centred or normalised here (presumably ProtoNet already stores prototypes in
    that centred/normalised space -- TODO confirm).

    Args:
        embeddings: float tensor of shape ``(N, D)``, or a single embedding of shape ``(D,)``.
        prototype_embeddings: tensor of shape ``(C, D)``, one row per class prototype.
        mean_embedding: tensor broadcastable against ``embeddings`` (e.g. shape ``(D,)``).

    Returns:
        Tensor of shape ``(N, C)`` (``(1, C)`` for a 1-D input) holding L2 distances;
        smaller means closer to that prototype.
    """
    # Promote a lone (D,) embedding to a batch of one; otherwise the [:, None] below would
    # produce a (D, 1) tensor and broadcast against the prototypes incorrectly.
    queries = embeddings.unsqueeze(0) if embeddings.dim() == 1 else embeddings
    queries = normalize(queries - mean_embedding, dim=-1, p=2)
    # [N, 1, D] - [1, C, D] broadcasts to [N, C, D]; the norm over D yields [N, C].
    return (queries[:, None] - prototype_embeddings[None, :]).norm(dim=-1, p=2)

In [None]:
distances = compute_distances(torch.tensor(embeddings, dtype=torch.float32), protonet.prototype_embeddings, protonet.mean).numpy()

In [None]:
prototype_topk_vote(protonet, torch.tensor(embeddings, dtype=torch.float32), topk=5)

In [None]:
distances.shape

In [None]:
wsi.level_dimensions[0]

In [None]:
wsi.level_dimensions[0][0] / wsi.level_dimensions[-1][0] 

In [None]:
def compute_heatmap_optimized(wsi, coordinates, scores, tile_size=224, tile_level=0, rescale=False):
    """
    Paint per-tile score vectors onto a coarse slide grid and upscale it to thumbnail size.

    Each tile's ``(x, y)`` origin is mapped to a grid cell whose side is
    ``tile_size * wsi.level_downsamples[tile_level]`` level-0 pixels; that cell receives the
    tile's score vector. If several tiles fall in one cell, the last one wins. Cells with no
    tile stay at 0. The grid is then resized bilinearly to the lowest-resolution level's size.

    Args:
        wsi: OpenSlide-like object exposing ``level_downsamples``, ``level_dimensions``
            (``(width, height)`` tuples) and ``get_thumbnail``.
        coordinates: array-like of ``(x, y)`` tile origins, one row per tile. Assumed to be
            level-0 pixel coordinates -- TODO confirm against the feature-extraction pipeline.
        scores: ``(N, C)`` array of per-tile scores; a 1-D ``(N,)`` array is treated as C = 1.
        tile_size: tile side length in pixels at ``tile_level``.
        tile_level: pyramid level the tiles were extracted at.
        rescale: if True, map all scores jointly (across classes) linearly onto ``[-1, 1]``.
            A constant score array maps to all zeros instead of dividing by zero.

    Returns:
        ``(heatmap_upscaled, thumbnail)``. ``heatmap_upscaled`` is float32 with shape
        ``(thumb_height, thumb_width, C)``. ``thumbnail`` is ``wsi.get_thumbnail`` called with
        the lowest-resolution level's ``(width, height)``.

    Raises:
        ValueError: if the number of coordinates and score rows differ.
    """
    scores = np.asarray(scores, dtype=np.float64)
    if scores.ndim == 1:
        scores = scores[:, None]
    coords = np.asarray(coordinates).reshape(-1, 2)
    if len(coords) != len(scores):
        raise ValueError(
            f"got {len(coords)} coordinates but {len(scores)} score rows; they must match one-to-one"
        )
    num_classes = scores.shape[1]

    if rescale and scores.size > 0:
        lo, hi = scores.min(), scores.max()
        span = hi - lo
        # Map [lo, hi] linearly onto [-1, 1]. A constant array has no spread, so map it to 0.
        scores = (2 * scores - lo - hi) / span if span > 0 else np.zeros_like(scores)

    # Side length, in level-0 pixels, of one heatmap cell (one tile footprint).
    cell_size = wsi.level_downsamples[tile_level] * tile_size
    # OpenSlide reports dimensions as (width, height).
    width0, height0 = wsi.level_dimensions[0]
    n_rows = int(np.ceil(height0 / cell_size))
    n_cols = int(np.ceil(width0 / cell_size))
    heatmap = np.zeros((n_rows, n_cols, num_classes), dtype=np.float32)

    rows = np.floor(coords[:, 1] / cell_size).astype(int)
    cols = np.floor(coords[:, 0] / cell_size).astype(int)
    # NumPy does not guarantee which value wins for repeated fancy-index assignment, so keep
    # the LAST tile per cell explicitly. This matches a sequential overwrite.
    # np.unique on the reversed array reports each cell's first occurrence from the end.
    flat_cells = rows * n_cols + cols
    _, first_in_reversed = np.unique(flat_cells[::-1], return_index=True)
    last = len(flat_cells) - 1 - first_in_reversed
    heatmap[rows[last], cols[last], :] = scores[last]

    # cv2.resize takes dsize as (width, height), the same order as OpenSlide dimensions.
    thumb_width, thumb_height = wsi.level_dimensions[-1]
    heatmap_upscaled = cv2.resize(heatmap, (thumb_width, thumb_height), interpolation=cv2.INTER_LINEAR)
    # cv2.resize drops a trailing channel axis of size 1; restore it so the result is always (H, W, C).
    heatmap_upscaled = heatmap_upscaled.reshape(thumb_height, thumb_width, num_classes)
    thumbnail = wsi.get_thumbnail((thumb_width, thumb_height))

    return heatmap_upscaled, thumbnail

In [None]:
heatmaps, thumbnail = compute_heatmap_optimized(wsi, coordinates, -distances, tile_size=256, tile_level=0, rescale=True)

In [None]:
heatmaps.shape

In [None]:
# Normalize all heatmaps to the same scale
vmin = np.min(heatmaps)
vmax = np.max(heatmaps)

num_classes = heatmaps.shape[2]

# Create subplots with space for a colorbar
fig, axes = plt.subplots(1, num_classes+1, figsize=(15, 5), gridspec_kw={"width_ratios": [1] * num_classes + [0.05]})
fig.suptitle(f"Heatmaps for WSI {wsi_id} with {tumor_hist_type} Tumor Type", fontsize=16)

titles = [f"{i.replace(' adenocarcinoma', '')}" for i in label_map.keys()]

heatmaps_list = [heatmaps[:, :, i] for i in range(heatmaps.shape[2])]

# Plot heatmaps
for ax, heatmap, title in zip(axes[:-1], heatmaps_list, titles):  # Exclude the last axis for the colorbar
    im = ax.imshow(heatmap.squeeze(), cmap="jet", vmin=vmin, vmax=vmax)  # Use the same vmin and vmax
    ax.set_title(title)
    ax.axis("off")

# Add a single colorbar in the last axis
cbar = fig.colorbar(im, cax=axes[-1], orientation="vertical")
cbar.set_label("Heatmap Intensity")

# Show the plot
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(thumbnail)