In [None]:
import h5py
import pandas as pd
from openslide import OpenSlide
import numpy as np
import matplotlib.pyplot as plt

from histopreprocessing.features.utils import load_hdf5

In [None]:
# Slide ID of the whole-slide image inspected in this notebook. It is used below as the
# lookup key in the metadata table and as the file stem of the embedding (.h5) and
# slide (.svs) files.
wsi_id = "C3N-02150-21"

In [None]:
# Pathology data table, indexed by "Slide_ID" so that a slide's row can be looked up by ID.
# NOTE(review): absolute NAS path — consider reading it from a configurable data directory.
metadata = pd.read_csv("/mnt/nas6/data/CPTAC/TCIA_CPTAC_LUAD_Pathology_Data_Table.csv").set_index("Slide_ID")

In [None]:
# Show the specimen type and histological type recorded for the selected slide.
print(f"Specimen Type: {metadata.loc[wsi_id, 'Specimen_Type']}")
print(f"Tumor Histological Type: {metadata.loc[wsi_id, 'Tumor_Histological_Type']}")

In [None]:
# Load the precomputed tile embeddings of the selected slide with the project helper
# `load_hdf5`. From the way `result` is indexed in this notebook, it exposes the keys
# "global_attributes", "embeddings" and "coordinates".
# NOTE(review): absolute home-directory path — consider a configurable data directory.
result = load_hdf5(f"/home/valentin/workspaces/histopatseg/data/processed/embeddings/UNI2/cptac_luad/{wsi_id}.h5")
# Magnifications stored alongside the embeddings.
print(f"Base Magnification: {result['global_attributes']['base_magnification']}")
print(f"Tile Magnification: {result['global_attributes']['tile_magnification']}")

In [None]:
# Open the whole-slide image; the heatmap functions below read its level dimensions,
# level downsamples and thumbnail.
# NOTE(review): the handle is never closed in the visible cells.
wsi = OpenSlide(f"/mnt/nas6/data/CPTAC/CPTAC-LUAD_v12/LUAD/{wsi_id}.svs")

In [None]:
# Per-tile arrays taken from the embedding file. The heatmap code below pairs
# coordinates[i] with the score derived from embeddings[i], so both are presumably
# stored in the same tile order — TODO confirm.
embeddings = result["embeddings"]
coordinates = result["coordinates"]

In [None]:
# Sanity check: display the dimensions of the embedding array.
embeddings.shape

In [None]:
# HDF5 file that stores the "mean_embedding" dataset and one prototype dataset per class;
# it is read below.
# NOTE(review): absolute home-directory path — consider a configurable data directory.
prototype_h5 = "/home/valentin/workspaces/histopatseg/data/processed/prototypes_tcga_ut/uni2_prototypes__n_wsi_32_wo_normal_precentercrop.h5"

In [None]:
# Read the centering vector and every class prototype from the prototype file.
with h5py.File(prototype_h5, "r") as f:
    # The "mean_embedding" dataset is the vector subtracted from embeddings before scoring.
    mean_embedding = f["mean_embedding"][:]
    print("Mean Embedding:", mean_embedding)

    # Every other top-level dataset is treated as one class prototype, keyed by its name.
    cancer_prototypes = {}
    for cancer_name in f:
        if cancer_name == "mean_embedding":
            continue
        cancer_prototypes[cancer_name] = f[cancer_name][:]
        print(f"Cancer Name: {cancer_name}, Prototype: {cancer_prototypes[cancer_name]}")

In [None]:
def compute_similarity(embedding, prototype, mean_embedding, distance_metric="cosine"):
    """
    Score every embedding against a single prototype.

    The embeddings are centered with ``mean_embedding`` and L2-normalized row by row
    before scoring. The prototype is used exactly as given: it is neither centered nor
    normalized here, so the "cosine" score is a true cosine similarity only when the
    caller passes a unit-norm prototype (presumably the stored prototypes were already
    centered and normalized when they were built — TODO confirm).

    :param embedding: np.ndarray, shape (n_tiles, embedding_dim)
        The tile embeddings (must be 2-D; the normalization runs along axis 1).
    :param prototype: np.ndarray
        The prototype vector; it must be compatible with the embedding dimension.
    :param mean_embedding: np.ndarray
        The vector subtracted from every embedding before normalization.
    :param distance_metric: str
        "cosine" returns the dot product with the prototype; "euclidean" returns the
        negative Euclidean distance, so that a higher score always means "closer".
    :return: np.ndarray
        One score per row of ``embedding``.
    :raises ValueError: if ``distance_metric`` is neither "cosine" nor "euclidean".
    """
    # Fail early with a clear message instead of reaching the return with no result bound.
    if distance_metric not in ("cosine", "euclidean"):
        raise ValueError(
            f"Unknown distance_metric {distance_metric!r}; expected 'cosine' or 'euclidean'"
        )

    # Center, then scale every row to unit length.
    # An embedding exactly equal to mean_embedding has zero norm and yields NaN scores.
    emb = embedding - mean_embedding
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

    if distance_metric == "cosine":
        return np.dot(emb, prototype)

    # Negated so that, like the cosine score, larger values mean more similar.
    return -np.linalg.norm(emb - prototype, axis=1)

In [None]:
def simple_shot_classification(embeddings, prototypes, mean_embedding, top_k=5):
    """
    Predict a slide-level class by voting over the closest (tile, prototype) pairs.

    The embeddings are centered with ``mean_embedding`` and L2-normalized row by row.
    The Euclidean distance of every tile to every prototype is computed, all
    (tile, class) pairs are ranked by distance, and the ``top_k`` closest pairs each
    cast one vote for their class. Note that the ranking is over pairs, so the same
    tile may vote more than once (once per class it is close to).

    The prototypes are used as given; they are not centered or normalized here.

    Ties in the vote count are broken deterministically in favour of the class that
    appears first in the nearest-first ranking (i.e. the class owning the closest pair
    among the tied classes).

    :param embeddings: np.ndarray, shape (n_tiles, embedding_dim)
        The embeddings for the tiles.
    :param prototypes: dict
        A dictionary where keys are cancer names and values are prototype vectors.
    :param mean_embedding: np.ndarray, shape (embedding_dim,)
        The mean embedding subtracted from the embeddings before normalization.
    :param top_k: int
        The number of closest (tile, class) pairs that take part in the vote.
    :return: str
        The predicted cancer type (one of the keys of ``prototypes``).
    :raises ValueError: if ``prototypes`` is empty, ``embeddings`` has no rows,
        or ``top_k`` is smaller than 1.
    """
    if not prototypes:
        raise ValueError("prototypes must contain at least one class")
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Center, then scale every row to unit length.
    emb = embeddings - mean_embedding
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

    n_tiles = emb.shape[0]
    if n_tiles == 0:
        raise ValueError("embeddings must contain at least one tile")

    # distances[c, t] = Euclidean distance between tile t and the prototype of class c.
    class_names = list(prototypes)
    distances = np.stack(
        [np.linalg.norm(emb - prototypes[name], axis=1) for name in class_names]
    )

    # Rank all (class, tile) pairs. The stable sort keeps class-major order among equal
    # distances, matching a stable sort of the class-by-class flattened list.
    nearest_pairs = np.argsort(distances.ravel(), kind="stable")[:top_k]
    voted_classes = nearest_pairs // n_tiles  # row index of the flattened pair = class

    # Count the votes and resolve ties by nearest-first order rather than by the
    # iteration order of a set, which is not reproducible across interpreter runs.
    vote_counts = np.bincount(voted_classes, minlength=len(class_names))
    top_count = vote_counts.max()
    winner = next(int(c) for c in voted_classes if vote_counts[c] == top_count)

    return class_names[winner]

In [None]:
# Score every tile against the LUAD and LUSC prototypes. With "euclidean",
# compute_similarity returns the negative distance, so a higher score still means closer.
distance_metric = "euclidean"
# Mean of this slide's own embeddings.
# NOTE(review): it is not used for the scores below (those are centered with
# `mean_embedding` read from the prototype file) but it is what the next cell passes to
# simple_shot_classification — confirm that the two different centerings are intended.
mean_embedding_test = np.mean(embeddings, axis=0)
scores_luad = compute_similarity(embeddings, cancer_prototypes["Lung_adenocarcinoma"], mean_embedding, distance_metric=distance_metric)
scores_lusc = compute_similarity(embeddings, cancer_prototypes["Lung_squamous_cell_carcinoma"], mean_embedding, distance_metric=distance_metric)
# scores_normal = compute_similarity(embeddings, cancer_prototypes["Lung_normal"], mean_embedding, distance_metric=distance_metric)

In [None]:
# Slide-level prediction over all loaded prototypes, using the 5 closest pairs.
# The embeddings are centered with the slide's own mean (`mean_embedding_test`) here.
simple_shot_classification(
    embeddings,
    cancer_prototypes,
    mean_embedding_test,
    top_k=5
)

In [None]:
def compute_heatmap(wsi, coordinates, scores, scores_level=0, tile_size=224, heatmap_level=-1, return_thumbnail=False, rescale=False):
    """
    Paint per-tile scores onto a pixel grid that matches one pyramid level of the slide.

    Every tile becomes a filled square in the heatmap; pixels not covered by any tile
    stay at 0. Where tiles overlap, the tile that comes later in ``coordinates`` wins.

    :param wsi: OpenSlide object
        The whole slide image object (its ``level_dimensions``, ``level_downsamples``
        and, optionally, ``get_thumbnail`` are used).
    :param coordinates: np.ndarray, shape (n_tiles, 2)
        The (x, y) coordinates of the tiles at level 0.
    :param scores: np.ndarray, shape (n_tiles,)
        The scores for each tile.
    :param scores_level: int
        The level at which the tile size was defined.
    :param tile_size: int
        The size of the tiles in pixels at the scores_level.
    :param heatmap_level: int
        The level at which to generate the heatmap (-1 means the lowest resolution).
    :param return_thumbnail: bool
        Whether to return the thumbnail of the WSI along with the heatmap.
    :param rescale: bool
        If True, linearly map the scores so that the minimum becomes -1 and the
        maximum becomes +1. When all scores are equal there is no range to stretch,
        and the rescaled scores are all 0 (instead of NaN from a division by zero).
    :return: np.ndarray or (np.ndarray, thumbnail)
        The float32 heatmap of shape (height, width) at ``heatmap_level``; when
        ``return_thumbnail`` is True, a tuple of the heatmap and the result of
        ``wsi.get_thumbnail`` for the same dimensions.
    :raises ValueError: if a tile does not fit entirely inside the heatmap.
    """
    # OpenSlide reports (width, height); NumPy arrays are (rows, columns) = (height, width).
    level_dimensions = wsi.level_dimensions[heatmap_level]
    heatmap = np.zeros(level_dimensions[::-1], dtype=np.float32)

    # Downsample factors relative to level 0.
    downsample_to_heatmap = wsi.level_downsamples[heatmap_level]
    downsample_to_base = wsi.level_downsamples[scores_level]

    # Tile edge length: scores_level pixels -> level-0 pixels -> heatmap pixels.
    tile_size_base = tile_size * downsample_to_base
    tile_size_heatmap = int(np.round(tile_size_base / downsample_to_heatmap))

    # Level-0 tile origins expressed in heatmap pixels.
    coordinates_heatmap = np.round(np.asarray(coordinates) / downsample_to_heatmap).astype(int)

    if rescale and np.size(scores):
        scores = np.asarray(scores)
        lowest, highest = np.min(scores), np.max(scores)
        if highest > lowest:
            scores = (2 * scores - lowest - highest) / (highest - lowest)
        else:
            # Constant scores: avoid 0/0 and keep the heatmap finite.
            scores = np.zeros(scores.shape, dtype=np.float32)

    for i, (x, y) in enumerate(coordinates_heatmap):  # OpenSlide order is (x, y)
        score = scores[i]

        # Refuse tiles that would be silently cropped (or wrapped, for negative indices).
        if (
            x < 0 or y < 0 or
            x + tile_size_heatmap > heatmap.shape[1] or
            y + tile_size_heatmap > heatmap.shape[0]
        ):
            raise ValueError(f"Tile at ({x}, {y}) with size {tile_size_heatmap} is out of bounds for heatmap of shape {heatmap.shape}")

        # NumPy indexing is (row, column), i.e. (y, x).
        heatmap[y:y + tile_size_heatmap, x:x + tile_size_heatmap] = score

    if return_thumbnail:
        return heatmap, wsi.get_thumbnail(level_dimensions)

    return heatmap

In [None]:
import cv2

def compute_heatmap_optimized(wsi, coordinates, scores, tile_size=224, tile_level=0, rescale=False):
    """
    Build a one-cell-per-tile score grid and upscale it to the slide's smallest level.

    Instead of painting every tile at full heatmap resolution, each tile is written to
    a single cell of a coarse grid (one cell spans one tile), and the grid is then
    resized with bilinear interpolation to the dimensions of the last pyramid level.
    Cells without a tile stay at 0, so the interpolation blends scores towards 0 at
    tissue borders.

    :param wsi: OpenSlide object
        The whole slide image object.
    :param coordinates: np.ndarray, shape (n_tiles, 2)
        The (x, y) coordinates of the tiles at level 0.
    :param scores: np.ndarray, shape (n_tiles,)
        The scores for each tile.
    :param tile_size: int
        The size of the tiles in pixels at ``tile_level``.
    :param tile_level: int
        The pyramid level at which ``tile_size`` is expressed.
    :param rescale: bool
        If True, linearly map the scores so that the minimum becomes -1 and the
        maximum becomes +1; constant scores are mapped to 0.
    :return: (np.ndarray, thumbnail)
        The upscaled heatmap and the result of ``wsi.get_thumbnail`` for the
        dimensions of the last pyramid level.
    """
    # Edge length of one tile (= one grid cell) in level-0 pixels.
    downsample = wsi.level_downsamples[tile_level] * tile_size

    # OpenSlide reports (width, height); the NumPy grid is (rows, columns).
    slide_width, slide_height = wsi.level_dimensions[0]
    n_cols = max(1, int(np.round(slide_width / downsample)))
    n_rows = max(1, int(np.round(slide_height / downsample)))
    heatmap = np.zeros((n_rows, n_cols), dtype=np.float32)

    scores = np.asarray(scores)
    if rescale and scores.size:
        lowest, highest = np.min(scores), np.max(scores)
        if highest > lowest:
            scores = (2 * scores - lowest - highest) / (highest - lowest)
        else:
            # Constant scores: avoid 0/0 and keep the heatmap finite.
            scores = np.zeros(scores.shape, dtype=np.float32)

    # Map every tile origin to its grid cell in one vectorized step.
    grid = np.round(np.asarray(coordinates).reshape(-1, 2) / downsample).astype(int)
    if grid.size:
        # Rounding both the grid size and the tile index can put a border tile one cell
        # past the last row/column; clamp it to the edge cell instead of raising.
        grid_x = np.clip(grid[:, 0], 0, n_cols - 1)
        grid_y = np.clip(grid[:, 1], 0, n_rows - 1)
        # With repeated cells the later tile wins, as in a sequential loop.
        heatmap[grid_y, grid_x] = scores[: len(grid)]

    # level_dimensions entries are (width, height), which is also the order
    # cv2.resize expects for its target size.
    thumbnail_size = wsi.level_dimensions[-1]
    heatmap_upscaled = cv2.resize(heatmap, thumbnail_size, interpolation=cv2.INTER_LINEAR)
    thumbnail = wsi.get_thumbnail(thumbnail_size)

    return heatmap_upscaled, thumbnail

In [None]:
# Heatmap of the per-tile LUSC scores, rescaled to [-1, 1], plus the slide thumbnail.
# The commented-out line is the per-pixel variant applied to the LUAD scores.
# heat_map, thumbnail = compute_heatmap(wsi, coordinates, scores_luad, scores_level=0, tile_size=224, heatmap_level=-1, return_thumbnail=True, rescale=True)
heatmap, thumbnail = compute_heatmap_optimized(wsi, coordinates, scores_lusc, tile_size=224, tile_level=0, rescale=True)

In [None]:

# Draw the score heatmap on an explicit figure/axes pair with a blue-white-red colormap,
# and attach a colorbar for the score scale.
fig, ax = plt.subplots(figsize=(20, 20))
heatmap_image = ax.imshow(heatmap, cmap="bwr")
fig.colorbar(heatmap_image, ax=ax)

In [None]:
# Show the slide thumbnail for visual comparison with the heatmap above.
plt.imshow(thumbnail)