In [None]:
from pathlib import Path

import h5py
import pandas as pd
from openslide import OpenSlide
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch
import cv2
from torch.nn.functional import normalize
from PIL import Image
from histopreprocessing.features.foundation_models import load_model

from histopatseg.fewshot.protonet import ProtoNet, prototype_topk_vote
from histopatseg.fewshot.extract_patch_features import extract_patch_features_from_dataloader
from histopatseg.data.compute_embeddings_tcga_ut import load_hdf5


In [None]:
# Class name -> integer label, assigned in list order (0..6).
# For the 4-class ACA-vs-normal variant, keep only the "aca_*" classes and "nor" (labels 0-3).
label_map = {
    name: index
    for index, name in enumerate(["aca_bd", "aca_md", "aca_pd", "nor", "scc_bd", "scc_md", "scc_pd"])
}

In [None]:
project_dir = Path(".").resolve().parent
print(f"Project Directory: {project_dir}")

In [None]:
embedding_file = project_dir / "data/processed/embeddings/LungHist700/lunghist700_20x_UNI2_TS_256_embeddings.npz"
metadata  = pd.read_csv(project_dir / "data/processed/LungHist700_tiled/LungHist700_20x_TS_256/metadata.csv").set_index("tile_id")
metadata.head()

In [None]:
# Load the embeddings
data = np.load(embedding_file)
embeddings = data["embeddings"]
tile_ids = data["tile_ids"]
embedding_dim = data["embedding_dim"]

# Print basic information
print(f"Loaded {len(embeddings)} embeddings with dimensionality {embeddings.shape[1]}")
print(f"Embedding dimension from model: {embedding_dim}")

In [None]:
embeddings_df = pd.DataFrame(
    {
        "tile_id": tile_ids,
        "embeddings": list(embeddings),  # Add embeddings as a column
    }
).set_index("tile_id")

In [None]:
df = pd.concat([embeddings_df, metadata], axis=1)

In [None]:
# df_filtered = df[(df["superclass"] == "aca") | (df["superclass"] == "nor")]
df_filtered = df
num_classes = 7

In [None]:
df_filtered.head()

In [None]:
# Extract embeddings and labels
embeddings_train = np.stack(df_filtered["embeddings"].values)
labels_train = df_filtered["class_name"].values
labels_train = np.array([label_map[label] for label in labels_train])


In [None]:
embeddings_train

In [None]:
protonet = ProtoNet()
protonet.fit(
    torch.tensor(embeddings_train, dtype=torch.float32),
    torch.tensor(labels_train, dtype=torch.long),
)

In [None]:
metadata_luad = pd.read_csv("/mnt/nas6/data/CPTAC/TCIA_CPTAC_LUAD_Pathology_Data_Table.csv").set_index("Slide_ID")
metadata_lusc = pd.read_csv("/mnt/nas6/data/CPTAC/TCIA_CPTAC_LSCC_Pathology_Data_Table.csv").set_index("Slide_ID")
metadata = pd.concat([metadata_luad, metadata_lusc], axis=0)

In [None]:
wsi_id = "C3N-02929-22"
tumor_hist_type = metadata.loc[wsi_id, 'Tumor_Histological_Type']
cohort = metadata.loc[wsi_id, "Tumor"]

In [None]:
print(f"Specimen Type: {metadata.loc[wsi_id, 'Specimen_Type']}")
print(f"Tumor Histological Type: {tumor_hist_type}")
print(f"Cohort: {cohort}")

In [None]:
tumor_histological_counts = metadata["Tumor_Histological_Type"].value_counts()

In [None]:
result = load_hdf5(f"../data/processed/mahmoodlab/UNI2-h_features/CPTAC/CPTAC_{cohort}/{wsi_id}.h5")

In [None]:
wsi_matches = list(Path("/mnt/nas6/data/CPTAC/").rglob(f"**/{wsi_id}.svs"))
if len(wsi_matches) == 0:
    raise FileNotFoundError(f"Could not find WSI file for {wsi_id}")
elif len(wsi_matches) > 1:
    raise FileExistsError(f"Multiple WSI files found for {wsi_id}")

In [None]:
wsi = OpenSlide(wsi_matches[0])

In [None]:
print(f"mpp x : {wsi.properties.get('openslide.mpp-x', 'nan')}")
print(f"mpp y : {wsi.properties.get('openslide.mpp-y', 'nan')}")

In [None]:
result["datasets"].keys()

In [None]:
result["datasets"]["coords"][:].shape

In [None]:
# embeddings = result["datasets"]["embeddings"]
# coordinates = result["datasets"]["coordinates"]
embeddings = np.squeeze(result["datasets"]["features"])
coordinates = np.squeeze(result["datasets"]["coords"])

In [None]:
coordinates[1,:] - coordinates[0,:]

In [None]:
coordinates.shape, embeddings.shape

In [None]:
def compute_distances(embeddings, prototype_embeddings, mean_embedding):
    """
    Euclidean distance from every query embedding to every class prototype.

    Each query is centred by subtracting ``mean_embedding`` and then
    L2-normalised along its last dimension. ``prototype_embeddings`` are
    used exactly as given; they are not centred or normalised here.

    Args:
        embeddings: query features, indexed as ``[N, D]``.
        prototype_embeddings: prototypes, indexed as ``[C, D]``.
        mean_embedding: tensor broadcastable against ``embeddings``.

    Returns:
        Tensor of shape ``[N, C]`` holding the distances (smaller = closer).
    """
    centred_unit = normalize(embeddings - mean_embedding, dim=-1, p=2)
    # [N, 1, D] - [1, C, D] broadcasts to [N, C, D]; the norm then reduces D.
    differences = centred_unit.unsqueeze(1) - prototype_embeddings.unsqueeze(0)
    return differences.norm(dim=-1, p=2)

In [None]:
distances = compute_distances(torch.tensor(embeddings, dtype=torch.float32), protonet.prototype_embeddings, protonet.mean).numpy()

In [None]:
prototype_topk_vote(protonet, torch.tensor(embeddings, dtype=torch.float32), topk=5)

In [None]:
distances.shape

In [None]:
wsi.level_dimensions[0]

In [None]:
wsi.level_dimensions[0][0] / wsi.level_dimensions[-1][0] 

In [None]:
def compute_heatmap_optimized(wsi, coordinates, scores, tile_size=224, tile_level=0, rescale=False):
    """
    Rasterise per-tile class scores onto a coarse grid, then upscale the grid to thumbnail size.

    Each tile's score vector is written to the grid cell that contains its (x, y)
    coordinate. Cells with no tile stay 0. Note that with ``rescale=True``, 0 is
    also the midpoint of the score range.

    Args:
        wsi: OpenSlide-like object. Reads ``level_downsamples``, ``level_dimensions``
            and ``get_thumbnail``.
        coordinates: iterable of (x, y) tile origins.
            NOTE(review): assumed to be level-0 pixel coordinates — confirm against
            the feature extractor.
        scores: array indexed as ``[num_tiles, num_classes]``; row i belongs to
            ``coordinates[i]``.
        tile_size: tile edge length in pixels at ``tile_level``.
        tile_level: pyramid level the tiles were extracted at.
        rescale: if True, min-max rescale all scores jointly to [-1, 1] first.

    Returns:
        ``(heatmap_upscaled, thumbnail)``:
        - ``heatmap_upscaled`` is a float32 array of shape
          ``[thumb_height, thumb_width, num_classes]``.
        - ``thumbnail`` is ``wsi.get_thumbnail`` at the lowest-resolution level's size.
    """
    scores = np.asarray(scores)

    if rescale:
        lowest = np.min(scores)
        highest = np.max(scores)
        span = highest - lowest
        if span > 0:
            scores = (2 * scores - lowest - highest) / span
        else:
            # All scores equal: map them to the midpoint of [-1, 1] instead of dividing by zero.
            scores = np.zeros_like(scores, dtype=np.float32)

    num_classes = scores.shape[1]

    # Edge length of one grid cell in level-0 pixels.
    downsample = wsi.level_downsamples[tile_level] * tile_size

    # OpenSlide level_dimensions are (width, height).
    base_width, base_height = wsi.level_dimensions[0]
    # ceil (not round): a tile whose origin lies in a trailing partial cell still gets a slot.
    # Rounding down here made such tiles index past the grid (IndexError).
    grid_width = int(np.ceil(base_width / downsample))
    grid_height = int(np.ceil(base_height / downsample))
    heatmap = np.zeros((grid_height, grid_width, num_classes), dtype=np.float32)

    # Populate the heatmap (rows index y, columns index x).
    for i, (x, y) in enumerate(coordinates):
        grid_col = int(np.floor(x / downsample))
        grid_row = int(np.floor(y / downsample))
        heatmap[grid_row, grid_col, :] = scores[i, :]

    # cv2.resize takes dsize as (width, height), the same order as OpenSlide's level_dimensions.
    thumbnail_size = wsi.level_dimensions[-1]
    heatmap_upscaled = cv2.resize(heatmap, thumbnail_size, interpolation=cv2.INTER_LINEAR)
    if heatmap_upscaled.ndim == 2:
        # cv2 drops a trailing channel axis of length 1; restore it so the output is always [H, W, C].
        heatmap_upscaled = heatmap_upscaled[:, :, None]
    thumbnail = wsi.get_thumbnail(thumbnail_size)

    return heatmap_upscaled, thumbnail

In [None]:
heatmaps, thumbnail = compute_heatmap_optimized(wsi, coordinates, -distances, tile_size=256, tile_level=0, rescale=True)

In [None]:
heatmaps.shape

In [None]:
# Normalize all heatmaps to the same scale
vmin = np.min(heatmaps)
vmax = np.max(heatmaps)

# Create subplots with space for a colorbar
fig, axes = plt.subplots(1, num_classes+1, figsize=(15, 5), gridspec_kw={"width_ratios":[1,]*num_classes + [0.05]})
fig.suptitle(f"Heatmaps for WSI {wsi_id} with {tumor_hist_type} Tumor Type", fontsize=16)

titles = [f"{i}" for i in label_map.keys()]

heatmaps_list = [heatmaps[:, :, i] for i in range(heatmaps.shape[2])]

# Plot heatmaps
for ax, heatmap, title in zip(axes[:-1], heatmaps_list, titles):  # Exclude the last axis for the colorbar
    im = ax.imshow(heatmap.squeeze(), cmap="jet", vmin=vmin, vmax=vmax)  # Use the same vmin and vmax
    ax.set_title(title)
    ax.axis("off")

# Add a single colorbar in the last axis
cbar = fig.colorbar(im, cax=axes[-1], orientation="vertical")
cbar.set_label("Heatmap Intensity")

# Show the plot
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(thumbnail)

In [None]:
def plot_heatmaps_with_thumbnail(heatmaps, thumbnail, wsi_id, tumor_hist_type, label_map):
    # Normalize all heatmaps to the same scale
    vmin = np.min(heatmaps)
    vmax = np.max(heatmaps)
    num_classes = heatmaps.shape[2]

    # Create a figure with a grid layout
    fig = plt.figure(figsize=(15, 15))  # Adjusted height to accommodate the large thumbnail
    grid = plt.GridSpec(2, num_classes + 1, height_ratios=[1, 2], hspace=0.3, wspace=0.3)

    # Plot heatmaps in the first row
    heatmaps_list = [heatmaps[:, :, i] for i in range(heatmaps.shape[2])]
    titles = list(label_map.keys())

    for i, (heatmap, title) in enumerate(zip(heatmaps_list, titles)):
        ax = fig.add_subplot(grid[0, i])
        im = ax.imshow(heatmap.squeeze(), cmap="jet", vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=10)
        ax.axis("off")

    # Add a single colorbar in the last column of the first row
    cbar_ax = fig.add_subplot(grid[0, -1])
    cbar = fig.colorbar(im, cax=cbar_ax, orientation="vertical")
    cbar.set_label("Heatmap Intensity", fontsize=10)

    # Plot the thumbnail in the second row spanning all columns
    thumbnail_ax = fig.add_subplot(grid[1, :])
    thumbnail_ax.imshow(thumbnail)
    thumbnail_ax.set_title("Thumbnail", fontsize=12)
    thumbnail_ax.axis("off")

    # Add a main title
    plt.suptitle(f"Heatmaps for WSI {wsi_id} with {tumor_hist_type} Tumor Type", fontsize=16)

    # Show the plot
    plt.show()

In [None]:
plot_heatmaps_with_thumbnail(heatmaps, thumbnail, wsi_id, tumor_hist_type, label_map)