In [None]:
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from histoseg_plugin.storage.factory import build_embedding_store_from_dir
import openslide
from scipy.ndimage import affine_transform
import pandas as pd

from luadseg.eval.anorak import LinearSoftmaxHead
from luadseg.data.constants import ANORAK_CLASS_MAPPING

In [None]:
# Read the DHMC release metadata table from its CSV file.
dhmc_metadata_csv = "/home/valentin/workspaces/luadseg/data/raw/DHMC/MetaData_Release_1.0.csv"
dhmc_metadata = pd.read_csv(dhmc_metadata_csv)

In [None]:
# Derive the slide identifier by dropping only a trailing ".tif" extension.
# A plain substring replace would also remove ".tif" occurring in the middle
# of a file name (and would mangle e.g. a ".tiff" extension).
dhmc_metadata["slide_id"] = dhmc_metadata["File Name"].str.removesuffix(".tif")

In [None]:
# Show the metadata table, including the derived "slide_id" column.
dhmc_metadata

In [None]:
def get_class_from_slide_id(slide_id: str) -> str:
    """Return the "Class" entry of the single metadata row matching ``slide_id``.

    Looks the slide up in the notebook-level ``dhmc_metadata`` frame and fails
    an assertion unless exactly one row has that ``slide_id``.
    """
    is_match = dhmc_metadata["slide_id"] == slide_id
    matching_classes = dhmc_metadata.loc[is_match, "Class"]
    # exactly one row per slide is expected; zero or duplicates is an error
    assert len(matching_classes) == 1
    return matching_classes.values[0]

In [None]:
# Checkpoint of the fold-0 head stored among the MLflow run artifacts.
weights_path = (
    "/home/valentin/workspaces/luadseg/mlflow/839028616188564930/"
    "f29f1b20c47f419c9ee99ac016ceb9c7/artifacts/heads/fold_0.pt"
)

In [None]:
# Load the checkpoint onto the CPU so this cell also works on machines that do
# not have the device the checkpoint was saved from. The head is never moved
# off the default device and is fed CPU tensors further down.
# Despite its name, `linear_head_state_dict` is the whole checkpoint mapping:
# it holds "in_dim" next to the actual "state_dict".
linear_head_state_dict = torch.load(weights_path, map_location="cpu")

linear_head = LinearSoftmaxHead(
    in_dim=linear_head_state_dict["in_dim"],
    n_classes=7,
)
linear_head.load_state_dict(linear_head_state_dict["state_dict"])
# switch to evaluation mode before running inference
linear_head.eval()

In [None]:
# Inspect the loaded checkpoint. Despite the name it is the whole checkpoint
# mapping: its "in_dim" and "state_dict" entries are read when the head is built.
linear_head_state_dict

In [None]:
# Root folder of the DHMC slides and folder holding their stored embeddings.
slides_dir = "/home/valentin/workspaces/luadseg/data/raw/DHMC/DHMC_LUAD_corrected"
embedding_dir = "/home/valentin/workspaces/luadseg/data/embeds/DHMC/224_10x/uni2"

# Store object used below to list slide ids and load per-slide arrays.
embedding_store = build_embedding_store_from_dir(
    slides_root=slides_dir,
    root_dir=embedding_dir,
)


In [None]:
# Identifiers known to the embedding store (not referenced again in the cells
# visible here).
slides_id = embedding_store.slide_ids()
# Slide analysed in the remaining cells.
slide_id = "DHMC_0140"
# Display the metadata "Class" recorded for that slide.
get_class_from_slide_id(slide_id)

In [None]:
# Load what the store holds for the selected slide. Later cells use `feats` as
# the head input, `coords` as tile positions and `attrs` for per-slide settings.
feats, coords, attrs = embedding_store.load(slide_id)

In [None]:
# Inspect the slide attributes; "relative_wsi_path", "patch_size" and
# "patch_level" are read from this mapping in later cells.
attrs

In [None]:
# Shape of the embedding array that is fed to the head.
feats.shape

In [None]:
# Run the head on every tile embedding with autograd disabled and convert the
# softmax-normalised output to a NumPy array.
# NOTE(review): softmax is applied to the head output here; if
# LinearSoftmaxHead already returns probabilities this normalises twice — confirm.
with torch.inference_mode():
    feats_tensor = torch.tensor(feats, dtype=torch.float32)
    head_output = linear_head(feats_tensor)
    predictions = head_output.softmax(dim=-1).cpu().numpy()

In [None]:
# Shape of the per-tile prediction array produced by the head.
predictions.shape

In [None]:
# Resolve the slide file relative to the slides root and open it with OpenSlide.
# NOTE(review): the slide handle is not closed in the cells visible here.
wsi_path = Path(slides_dir) / attrs['relative_wsi_path']
wsi = openslide.OpenSlide(str(wsi_path))

In [None]:
def compute_level_downsamples(wsi):
    """Return one (x, y) downsample pair per pyramid level, matching CLAM behavior.

    For each level the factors are measured from the level-0 dimensions
    divided by that level's dimensions. When the measured pair equals the
    factor reported in ``wsi.level_downsamples`` the reported value is stored
    for both axes; otherwise the measured pair is stored.
    """
    base_w, base_h = wsi.level_dimensions[0]
    factors = []
    for reported, (level_w, level_h) in zip(wsi.level_downsamples, wsi.level_dimensions):
        measured = (base_w / float(level_w), base_h / float(level_h))
        nominal = (reported, reported)
        factors.append(nominal if measured == nominal else measured)
    return factors

In [None]:
# A cell only displays its last expression, so gather the three values into
# one mapping to make all of them visible.
{
    "level_downsamples": wsi.level_downsamples,
    "mpp_x": wsi.properties[openslide.PROPERTY_NAME_MPP_X],
    "best_level_for_downsample_20": wsi.get_best_level_for_downsample(20),
}

In [None]:
# Dimensions of pyramid level 4, the level passed as `heatmap_level` further
# down. Indexing fails if the slide has fewer than five levels.
wsi.level_dimensions[4]

In [None]:

def compute_heatmaps(
    wsi,
    coords,                 # (N,2) top-left tile coords, in level-0 pixels
    predictions,            # (N, n_classes)
    patch_size=224,
    patch_level=0,
    n_classes=7,
    heatmap_level=4,
):
    """Rasterise per-tile class scores into a heatmap aligned with a WSI level.

    Each tile's score vector is written into a coarse grid holding one cell
    per tile footprint. The grid is then bilinearly resampled onto the pixel
    grid of ``wsi.level_dimensions[heatmap_level]``, with cell centres aligned
    to tile centres.

    Args:
        wsi: Slide-like object exposing ``level_downsamples`` and
            ``level_dimensions`` (for instance an ``openslide.OpenSlide``).
        coords: ``(N, 2)`` array-like of tile top-left ``(x, y)`` positions.
            They are interpreted as level-0 pixels: they are combined directly
            with the level-0 tile size and the level-0 downsample factors.
        predictions: ``(N, n_classes)`` array-like of per-tile scores. A 1-D
            input is treated as a single column, which NumPy broadcasts over
            all ``n_classes`` channels.
        patch_size: Tile edge length in pixels at ``patch_level``.
        patch_level: Pyramid level at which ``patch_size`` is expressed.
        n_classes: Number of channels of the returned heatmap.
        heatmap_level: Pyramid level whose dimensions define the output size.

    Returns:
        ``float32`` array of shape ``(H, W, n_classes)`` where ``(W, H)`` is
        ``wsi.level_dimensions[heatmap_level]``. Pixels further than one tile
        away from every tile are 0. With no tiles the result is all zeros.

    Raises:
        ValueError: If ``coords`` and ``predictions`` disagree on the number
            of tiles.

    Note:
        The grid has one cell per tile footprint, so tiles that overlap
        (stride smaller than the tile size) can land in the same cell and
        overwrite each other.
    """
    coords = np.asarray(coords, dtype=np.float64)
    predictions = np.asarray(predictions, dtype=np.float32)
    if predictions.ndim == 1:
        predictions = predictions[:, None]

    output_W, output_H = wsi.level_dimensions[heatmap_level]  # (W, H)
    out = np.zeros((output_H, output_W, n_classes), dtype=np.float32)

    # No tiles: nothing to paint (the min() calls below would raise on empty input).
    if coords.size == 0:
        return out
    if predictions.shape[0] != coords.shape[0]:
        raise ValueError(
            f"coords has {coords.shape[0]} tiles but predictions has "
            f"{predictions.shape[0]} rows"
        )

    # tile edge length expressed in level-0 pixels
    patch_size_level0 = patch_size * float(wsi.level_downsamples[patch_level])

    # 1) grid origin = top-left corner of the tight bbox around all tiles
    origin_x0 = float(coords[:, 0].min())
    origin_y0 = float(coords[:, 1].min())

    # 2) integer cell index of every tile (floor; never negative because the
    #    origin is the minimum coordinate)
    x_idx = np.floor((coords[:, 0] - origin_x0) / patch_size_level0).astype(np.int64)
    y_idx = np.floor((coords[:, 1] - origin_y0) / patch_size_level0).astype(np.int64)
    grid_w = int(x_idx.max()) + 1
    grid_h = int(y_idx.max()) + 1

    # 3) one cell per tile, surrounded by a one-cell border of zeros. The border
    #    lets the bilinear resampling fade border tiles out over their full
    #    footprint; sampling outside the array returns cval without interpolating.
    heatmaps = np.zeros((grid_h + 2, grid_w + 2, n_classes), dtype=np.float32)
    heatmaps[y_idx + 1, x_idx + 1, :] = predictions

    # 4) resample to the requested WSI level
    d_out = float(wsi.level_downsamples[heatmap_level])  # level-0 px per output px

    # affine_transform samples the input at `A @ o + offset` for every output
    # index o = (row, col). Integer indices denote pixel centres, so:
    #   centre of output pixel o  -> level-0 position (o + 0.5) * d_out
    #   centre of grid cell i     -> level-0 position origin + (i + 0.5) * patch
    # which gives i = (d_out / patch) * o + (0.5 * d_out - origin) / patch - 0.5,
    # plus 1 for the zero border. Without the centre terms the heatmap is
    # shifted by half a tile towards the top-left.
    scale = d_out / patch_size_level0  # tile cells per output pixel
    A = np.array([[scale, 0.0],
                  [0.0,  scale]], dtype=np.float64)
    offset = np.array([
        (0.5 * d_out - origin_y0) / patch_size_level0 + 0.5,  # row offset
        (0.5 * d_out - origin_x0) / patch_size_level0 + 0.5,  # col offset
    ], dtype=np.float64)

    # apply per-channel
    for c in range(n_classes):
        out[:, :, c] = affine_transform(
            heatmaps[:, :, c],
            matrix=A,
            offset=offset,
            output_shape=(output_H, output_W),
            order=1,                # bilinear
            mode='constant',
            cval=0.0,
        )

    return out


In [None]:
# Build the per-class heatmaps for the selected slide at pyramid level 4,
# taking the tile size and tile level from the slide attributes.
heatmaps = compute_heatmaps(
    wsi,
    coords,
    predictions,
    patch_size=attrs["patch_size"],
    patch_level=attrs["patch_level"],
    n_classes=7,
    heatmap_level=4,
)

In [None]:
# 2x4 figure: the slide thumbnail first, then one heatmap panel per class.
# One colormap per class so the panels are easy to tell apart.
cmaps = ['Blues', 'Reds', 'Greens', 'Purples', 'Oranges', 'viridis', 'plasma']

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

# Reference panel: low-resolution view of the slide
thumbnail = wsi.get_thumbnail((400, 400))
thumbnail_ax = axes[0]
thumbnail_ax.imshow(thumbnail)
thumbnail_ax.set_title('WSI Thumbnail')
thumbnail_ax.axis('off')

# Remaining panels: class heatmaps on a shared 0..1 colour scale, each with
# its own colorbar
for class_idx, (ax, cmap_name) in enumerate(zip(axes[1:], cmaps)):
    class_name = ANORAK_CLASS_MAPPING[class_idx]
    im = ax.imshow(heatmaps[..., class_idx], cmap=cmap_name, vmin=0, vmax=1)
    ax.set_title(f'Class {class_idx}: {class_name}')
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

In [None]:
# Show the slide's metadata "Class" again, next to the heatmaps above.
get_class_from_slide_id(slide_id)