In [None]:
from pathlib import Path

import h5py
import pandas as pd
from openslide import OpenSlide
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch
import cv2
from torch.nn.functional import normalize
from PIL import Image
from histopreprocessing.features.foundation_models import load_model

from histopatseg.fewshot.protonet import ProtoNet, prototype_topk_vote
from histopatseg.fewshot.extract_patch_features import extract_patch_features_from_dataloader
from histopatseg.data.compute_embeddings_tcga_ut import load_hdf5


In [None]:
# Class name -> integer target used to fit the ProtoNet (0 = LUAD, 1 = LUSC).
# NOTE(review): an earlier variant also mapped a truncated label "Lung_squamous_cell_" to 1;
# add it back here if such labels show up in the HDF5 file, otherwise the lookup below raises KeyError.
label_map = dict(
    Lung_adenocarcinoma=0,
    Lung_squamous_cell_carcinoma=1,
)

In [None]:
hdf5_attributes = load_hdf5("../data/processed/embeddings/tcga_ut/UNI2_precentercrop.h5")

In [None]:
hdf5_attributes.keys()

In [None]:
print(hdf5_attributes["global_attributes"]["trainsform"])

In [None]:
hdf5_attributes["datasets"].keys()

In [None]:
# String datasets come back as UTF-8 bytes and are decoded here; embeddings are kept as loaded.
embeddings = hdf5_attributes["datasets"]["embeddings"]
labels = [raw_label.decode("utf-8") for raw_label in hdf5_attributes["datasets"]["labels"]]
# Only the final path component of each WSI id is kept.
wsi_ids = [Path(raw_path.decode("utf-8")).name for raw_path in hdf5_attributes["datasets"]["wsi_ids"]]
image_ids = [raw_id.decode("utf-8") for raw_id in hdf5_attributes["datasets"]["image_ids"]]

In [None]:
# One row per tile, indexed by image id.
# The resolution is the last "_"-separated token of the image file name divided by 1000
# (presumably stored as mpp * 1000 — TODO confirm against the embedding writer).
metadata_df = pd.DataFrame(
    {
        "label": labels,
        "wsi_id": wsi_ids,
        "image_ids": image_ids,
        "mpp": [
            float(image_name.rpartition("/")[2].rpartition("_")[2]) / 1000
            for image_name in image_ids
        ],
        "embeddings": list(embeddings),  # one embedding vector per row
        "numeric_label": list(map(label_map.__getitem__, labels)),
    },
).set_index("image_ids")

In [None]:
metadata_df["mpp"].unique()

In [None]:
filtered_df = metadata_df[(metadata_df["mpp"] >= 0.45) & (metadata_df["mpp"] <= 0.55)]
filtered_df.shape

In [None]:
n_wsi = 32

In [None]:
# Initialize an empty list to store the selected rows
selected_rows = []

# Group by class labels
for label, group in filtered_df.groupby("label"):
    # Get unique WSI IDs for the current class
    unique_wsi_ids = group["wsi_id"].unique()
    
    # Randomly shuffle the WSI IDs
    np.random.shuffle(unique_wsi_ids)
    print(f"label: {label}, n_wsi: {len(unique_wsi_ids)}")
    
    # Select up to n_wsi WSI IDs
    selected_wsi_ids = unique_wsi_ids[:n_wsi]
    
    # Filter rows corresponding to the selected WSI IDs
    selected_rows.append(group[group["wsi_id"].isin(selected_wsi_ids)])

# Concatenate the selected rows into a single DataFrame
result_df = pd.concat(selected_rows)

In [None]:
result_df["wsi_id"].nunique(), result_df["label"].nunique(), result_df.shape[0]

In [None]:
# Extract embeddings and labels
embeddings_train = np.stack(result_df["embeddings"].values)
labels_train = result_df["numeric_label"].values


In [None]:
protonet = ProtoNet()
protonet.fit(
    torch.tensor(embeddings_train, dtype=torch.float32),
    torch.tensor(labels_train, dtype=torch.long),
)

In [None]:
wsi_id = "C3N-00167-21"

In [None]:
metadata = pd.read_csv("/mnt/nas6/data/CPTAC/TCIA_CPTAC_LUAD_Pathology_Data_Table.csv").set_index("Slide_ID")

In [None]:
print(f"Specimen Type: {metadata.loc[wsi_id, 'Specimen_Type']}")
print(f"Tumor Histological Type: {metadata.loc[wsi_id, 'Tumor_Histological_Type']}")

In [None]:
# result = load_hdf5(f"/home/valentin/workspaces/histopatseg/data/processed/embeddings/UNI2/cptac_luad/{wsi_id}.h5")
# print(f"Base Magnification: {result['global_attributes']['base_magnification']}")
# print(f"Tile Magnification: {result['global_attributes']['tile_magnification']}")

In [None]:
result = load_hdf5(f"../data/processed/mahmoodlab/UNI2-h_features/CPTAC/{wsi_id}.h5")

In [None]:
wsi = OpenSlide(f"/mnt/nas6/data/CPTAC/CPTAC-LUAD_v12/LUAD/{wsi_id}.svs")

In [None]:
print(f"mpp x : {wsi.properties.get('openslide.mpp-x', 'nan')}")
print(f"mpp y : {wsi.properties.get('openslide.mpp-y', 'nan')}")

In [None]:
result["datasets"].keys()

In [None]:
result["datasets"]["coords"][:].shape

In [None]:
# CPTAC slide tile features and tile coordinates.
# reshape(-1, last_dim) instead of np.squeeze: squeeze also drops the tile axis when a slide has a
# single tile ((1, D) -> (D,)), which breaks the per-tile indexing in the cells below.
# NOTE: this rebinds `embeddings`; from here on it refers to the CPTAC slide, not TCGA-UT.
embeddings = np.asarray(result["datasets"]["features"])
embeddings = embeddings.reshape(-1, embeddings.shape[-1])
coordinates = np.asarray(result["datasets"]["coords"])
coordinates = coordinates.reshape(-1, coordinates.shape[-1])
assert len(embeddings) == len(coordinates), "features and coords must describe the same tiles"

In [None]:
coordinates[1,:] - coordinates[0,:]

In [None]:
coordinates.shape, embeddings.shape

In [None]:
def compute_distances(embeddings: torch.Tensor, prototype_embeddings: torch.Tensor, mean_embedding: torch.Tensor) -> torch.Tensor:
    """Euclidean distance from each query embedding to each class prototype.

    Queries are centred with ``mean_embedding`` and L2-normalised before the distance is
    taken. ``prototype_embeddings`` are used as given (presumably already centred and
    normalised by ``ProtoNet.fit`` — TODO confirm).

    Args:
        embeddings: query features, shape (N, D), or a single query of shape (D,).
        prototype_embeddings: one row per class, shape (C, D).
        mean_embedding: vector subtracted from every query; must broadcast against (N, D).

    Returns:
        Tensor of shape (N, C) (N = 1 for a 1-D query). Smaller means closer.
    """
    # Promote a single query to a batch of one; otherwise ``[:, None]`` below would turn (D,)
    # into (D, 1) and broadcast D against C, raising or silently returning wrong distances.
    feats_query = embeddings.unsqueeze(0) if embeddings.dim() == 1 else embeddings
    feats_query = normalize(feats_query - mean_embedding, dim=-1, p=2)

    # [N, 1, D] - [1, C, D] -> [N, C, D]; norm over D -> [N, C]
    return (feats_query[:, None] - prototype_embeddings[None, :]).norm(dim=-1, p=2)

In [None]:
distances = compute_distances(torch.tensor(embeddings, dtype=torch.float32), protonet.prototype_embeddings, protonet.mean).numpy()

In [None]:
prototype_topk_vote(protonet, torch.tensor(embeddings, dtype=torch.float32), topk=5)

In [None]:
distances.shape

In [None]:
wsi.level_dimensions[0]

In [None]:
wsi.level_dimensions[0][0] / wsi.level_dimensions[-1][0] 

In [None]:
def compute_heatmap_optimized(wsi, coordinates, scores, tile_size=224, tile_level=0, rescale=False):
    """Rasterise per-tile scores onto a grid and upsample it to the slide thumbnail size.

    Each grid cell covers ``tile_size`` pixels at ``tile_level``, i.e.
    ``tile_size * wsi.level_downsamples[tile_level]`` level-0 pixels. Cells without a tile stay 0.

    Args:
        wsi: OpenSlide handle of the slide.
        coordinates: (N, 2) array of (x, y) tile positions in level-0 pixels
            (presumably each tile's top-left corner — TODO confirm with the feature extractor).
        scores: (N,) per-tile scores aligned with ``coordinates``.
        tile_size: tile edge length in pixels at ``tile_level``.
        tile_level: pyramid level at which ``tile_size`` is expressed.
        rescale: if True, map scores linearly from [min, max] to [-1, 1]; constant scores map to 0.

    Returns:
        (heatmap_upscaled, thumbnail): a float32 array of shape (thumb_height, thumb_width) and
        the thumbnail returned by ``wsi.get_thumbnail`` for the lowest-resolution level size.

    Raises:
        ValueError: if ``coordinates`` and ``scores`` have different lengths.
    """
    coordinates = np.asarray(coordinates)
    scores = np.asarray(scores)
    if len(coordinates) != len(scores):
        raise ValueError(
            f"got {len(coordinates)} coordinates but {len(scores)} scores; they must be aligned per tile"
        )

    # Edge length of one grid cell in level-0 pixels.
    cell_size = wsi.level_downsamples[tile_level] * tile_size
    level0_width, level0_height = wsi.level_dimensions[0]  # OpenSlide order: (width, height)
    # ceil (not round) so tiles in a partially covered last row/column still get a cell.
    grid_width = int(np.ceil(level0_width / cell_size))
    grid_height = int(np.ceil(level0_height / cell_size))
    heatmap = np.zeros((grid_height, grid_width), dtype=np.float32)  # rows = y, columns = x

    if rescale and scores.size > 0:
        score_min, score_max = scores.min(), scores.max()
        if score_max > score_min:
            scores = (2 * scores - score_min - score_max) / (score_max - score_min)
        else:
            # All scores equal: the linear map would divide by zero and fill the heatmap with NaN.
            scores = np.zeros(scores.shape, dtype=np.float32)

    # floor (not round): a tile's corner lies inside its own cell, so an unaligned coordinate
    # must not be pushed into the next cell (or past the grid edge).
    for (x, y), score in zip(coordinates, scores):
        grid_x = int(np.floor(x / cell_size))
        grid_y = int(np.floor(y / cell_size))
        heatmap[grid_y, grid_x] = score  # a later tile in the same cell overwrites an earlier one

    thumbnail_size = wsi.level_dimensions[-1]  # (width, height) of the lowest-resolution level
    # cv2.resize takes dsize as (width, height), the same order OpenSlide uses.
    heatmap_upscaled = cv2.resize(heatmap, thumbnail_size, interpolation=cv2.INTER_LINEAR)
    thumbnail = wsi.get_thumbnail(thumbnail_size)

    return heatmap_upscaled, thumbnail

In [None]:
heatmap, thumbnail = compute_heatmap_optimized(wsi, coordinates, -distances[:, 0], tile_size=256, tile_level=0, rescale=True)

In [None]:
# Ensure the heatmap and thumbnail are the same size
heatmap_rescaled = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(thumbnail.size, Image.BICUBIC)

# Convert the heatmap to RGBA for transparency
heatmap_rgba = heatmap_rescaled.convert("RGBA")
heatmap_array = np.array(heatmap_rgba)
heatmap_array[..., 3] = (heatmap_array[..., 0] * 0.5).astype(np.uint8)  # Adjust alpha for transparency
heatmap_rgba = Image.fromarray(heatmap_array)

# Overlay the heatmap on the thumbnail
overlay = Image.alpha_composite(thumbnail.convert("RGBA"), heatmap_rgba)

# Display the result
plt.figure(figsize=(10, 10))
plt.imshow(overlay)
plt.axis("off")
plt.show()

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(heatmap, cmap="bwr")
plt.colorbar()

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(thumbnail)