In [1]:
import h5py
# import openslide
import torch
import pandas as pd
from prototype_visualization_utils import get_panther_encoder, visualize_categorical_heatmap, get_mixture_plot, get_default_cmap

import sys
sys.path.append('../')
from mil_models.tokenizer import PrototypeTokenizer

In [2]:
### Load the PANTHER encoder
# Prototype file handed to the encoder. The active choice is the "5x uni" set;
# the "10x uni" and "2.5x uni" alternatives are kept below for quick switching.
proto_path = r'C:\Users\Vivian\Documents\PANTHER\PANTHER\src\splits\FA_PT_k=0\prototypes\prototypes_c16_uni_kmeans_num_1.0e+06.pkl' # 5x uni
# proto_path = r'C:\Users\Vivian\Documents\PANTHER\PANTHER\src\splits\FA_PT_10x_k=0\prototypes\prototypes_c16_uniextracted_mag10x_patch224_fp_kmeans_num_1.0e+06.pkl' # 10x uni
# proto_path = r'C:\Users\Vivian\Documents\PANTHER\PANTHER\src\splits\FA_PT_2.5x_k=0\prototypes\prototypes_c16_uniextracted_mag2.5x_patch224_fp_kmeans_num_1.0e+06.pkl' # 2.5x uni

# Build the encoder from the 'PANTHER_fa_pt' config found under config_dir.
panther_encoder = get_panther_encoder(
    in_dim=1024,
    p=16,
    proto_path=proto_path,
    config_dir=r'C:\Users\Vivian\Documents\PANTHER\PANTHER\src\configs',
    model_config='PANTHER_fa_pt',
    out_type='allcat',
)

📌 Loaded config path: C:\Users\Vivian\Documents\PANTHER\PANTHER\src\configs\PANTHER_fa_pt\config.json
📌 Loaded config_dict from JSON: {'in_dim': 1024, 'n_classes': 2, 'heads': 1, 'em_iter': 1, 'tau': 0.001, 'ot_eps': 0.1, 'n_fc_layers': 0, 'dropout': 0.25, 'out_type': 'allcat', 'out_size': 8, 'load_proto': False, 'proto_path': '.', 'fix_proto': False}


In [21]:
import os
import h5py
import torch
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from prototype_visualization_utils import get_default_cmap, get_mixture_plot  # make sure these exist

# Setup paths
feats_h5_dir = r'C:\Users\Vivian\Documents\PANTHER\PANTHER\features\test_slide'
save_root = r'C:\Users\Vivian\Documents\PANTHER\PANTHER\features\test_slide\visualizations'
os.makedirs(save_root, exist_ok=True)

# Settings
patch_size = 224  # side length of the square painted per patch, in coordinate units
alpha = 1         # opacity of the painted squares (1 = fully opaque)
scale = 0.25      # resize factor applied to the finished map (only used when < 1)

# The alpha channel is identical for every patch, so build it once up front.
alpha_channel = (int(255 * alpha),)

# For every .h5 feature file: run the encoder, paint a per-patch prototype map,
# and save the mixture plot. Sorted so that the processing order is reproducible.
for fname in sorted(os.listdir(feats_h5_dir)):
    if not fname.endswith('.h5'):
        continue

    # splitext removes only the trailing extension; str.replace('.h5', '') would
    # also strip any '.h5' that happens to occur inside the slide name.
    slide_id = os.path.splitext(fname)[0].replace(' ', '_')
    h5_path = os.path.join(feats_h5_dir, fname)

    print(f"Processing {slide_id}...")

    # Load features and coords
    with h5py.File(h5_path, 'r') as h5:
        coords = h5['coords'][:]
        feats = torch.Tensor(h5['features'][:])

    # Nothing to draw (and min()/max() below would raise) when a slide has no patches.
    if len(coords) == 0:
        print(f"⚠️ {slide_id} contains no patches; skipping.")
        continue

    slide_dir = os.path.join(save_root, slide_id)
    os.makedirs(slide_dir, exist_ok=True)

    # Inference
    with torch.inference_mode():
        info = panther_encoder.representation(feats.unsqueeze(0))
        # First batch item, last axis squeezed: rows are patches, columns are prototypes.
        qq = info['qq'][0, :, :, 0].cpu().numpy()
        out = info['repr']

        tokenizer = PrototypeTokenizer(out_type='allcat', p=qq.shape[1])
        mus, pis, sigmas = tokenizer.forward(out)
        mus = mus[0].detach().cpu().numpy()
        # Hard assignment: the highest-scoring prototype for each patch.
        cluster_labels = qq.argmax(axis=1)

    # Normalize coordinates so the top-left patch sits at (0, 0).
    # A new array is created instead of subtracting in place from the loaded data.
    coords = coords - coords.min(axis=0)
    max_x, max_y = coords.max(axis=0) + patch_size
    canvas_w, canvas_h = int(max_x), int(max_y)

    # Draw prototype map
    # NOTE(review): the canvas is allocated at full coordinate resolution and only
    # downscaled afterwards; for very large coordinate extents this may be memory-hungry.
    canvas = Image.new('RGBA', (canvas_w, canvas_h), (255, 255, 255, 255))
    draw = ImageDraw.Draw(canvas, 'RGBA')
    cmap = get_default_cmap(int(cluster_labels.max()) + 1)

    for (x, y), label in zip(coords, cluster_labels):
        rgba = cmap[label] + alpha_channel
        draw.rectangle([x, y, x + patch_size, y + patch_size], fill=rgba)

    if scale < 1.0:
        # Clamp to at least one pixel so a tiny canvas cannot request a zero-sized image.
        scaled_size = (max(1, int(canvas_w * scale)), max(1, int(canvas_h * scale)))
        canvas = canvas.resize(scaled_size)
    canvas.save(os.path.join(slide_dir, 'prototype_map.png'))

    # Save mixture plot
    mixture_plot = get_mixture_plot(mus)
    mixture_plot.savefig(os.path.join(slide_dir, 'mixture_plot.png'), bbox_inches='tight')
    plt.close(mixture_plot)  # to avoid plot accumulation

print("✅ All visualizations saved.")


Processing FA_47_B1...
Processing FA_57B...
Processing PT_41_B...
✅ All visualizations saved.


In [18]:
def visualize_top_k_patches_per_prototype_from_qq(
    h5_feats_fpath,
    patch_dir,
    panther_encoder,
    top_k=4,
    patch_size=224,
    save_path=None
):
    """
    Plot the highest-scoring patches of each active prototype for one slide.

    The slide's features are passed through ``panther_encoder.representation``
    and the returned ``qq`` scores are used to rank patches per prototype.
    Each active prototype becomes one row of the figure and its best ``top_k``
    patches become the columns. Patch images are looked up in ``patch_dir`` as
    ``*_x{X}_y{Y}.npy`` files, where (X, Y) is the patch's entry in the h5
    ``coords`` dataset.

    Args:
        h5_feats_fpath (str): Path to an .h5 file with 'coords' and 'features' datasets.
        patch_dir (str): Directory holding the slide's patch images as .npy files.
        panther_encoder: Object exposing ``representation(feats)`` whose result
            contains a 'qq' entry.
        top_k (int): Maximum number of patches shown per prototype (must be >= 1).
        patch_size (int): Not used by this function; kept so existing callers
            that pass it keep working.
        save_path (str | None): If given, the figure is written there at 300 dpi.

    Returns:
        None. The figure is always closed before returning.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    import glob
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import h5py
    import os

    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Load patch-level features
    with h5py.File(h5_feats_fpath, 'r') as h5:
        coords = h5['coords'][:]
        feats = torch.Tensor(h5['features'][:])

    # Get prototype assignments
    with torch.inference_mode():
        info = panther_encoder.representation(feats.unsqueeze(0))
        qq = info['qq'][0, :, :, 0].cpu().numpy()  # shape: (N, p)

    n_patches, n_protos = qq.shape

    # Select top-k patches per prototype. A prototype with no score above 1e-3
    # is treated as inactive and gets no row in the figure.
    proto_to_indices = {}
    for proto in range(n_protos):
        proto_scores = qq[:, proto]
        candidates = np.flatnonzero(proto_scores > 1e-3)
        if candidates.size == 0:
            continue

        ranking = np.argsort(-proto_scores[candidates])  # best score first
        proto_to_indices[proto] = candidates[ranking][:top_k]

    active_protos = sorted(proto_to_indices)
    if not active_protos:
        print("⚠️ No active prototypes found in this slide.")
        return

    # Escape the directory part so characters such as '[' or '*' in the path
    # are matched literally; only the file-name part is meant to be a pattern.
    escaped_patch_dir = glob.escape(patch_dir)

    # Plotting
    n_rows = len(active_protos)
    # squeeze=False always yields a 2-D axes array, so axs[row, col] also works
    # for a single row and/or a single column.
    fig, axs = plt.subplots(
        n_rows, top_k, figsize=(top_k * 3, n_rows * 3), squeeze=False
    )

    try:
        for row_idx, proto_idx in enumerate(active_protos):
            indices = proto_to_indices[proto_idx]
            for col_idx in range(top_k):
                ax = axs[row_idx, col_idx]
                ax.axis('off')

                # Fewer than top_k patches passed the threshold: leave the cell blank.
                if col_idx >= len(indices):
                    continue

                x, y = coords[indices[col_idx]]
                pattern = f'*_x{int(x)}_y{int(y)}.npy'
                matches = sorted(glob.glob(os.path.join(escaped_patch_dir, pattern)))
                if not matches:
                    continue

                patch = np.load(matches[0])
                ax.imshow(patch.astype(np.uint8))
                if col_idx == 0:
                    ax.set_title(f'Proto {proto_idx}')

        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300)
            print(f"Saved top-k patches visualization to {save_path}")
    finally:
        # Close even when plotting or saving fails, so figures do not pile up
        # when this function is called in a loop over slides.
        plt.close(fig)


In [19]:
import os
import pandas as pd
import torch

def run_all_patch_visualizations(
    feats_h5_dir,
    patch_root,
    csv_path,
    save_root,
    panther_encoder,
    top_k=4,
    patch_size=224
):
    """
    Automates prototype patch visualization across slides using Panther qq assignments.

    For every ``.h5`` file in ``feats_h5_dir`` the slide is looked up in the
    metadata CSV, its patch directory is resolved as
    ``<patch_root>/<Magnification>x/<Class>/<slide_id>``, and
    ``visualize_top_k_patches_per_prototype_from_qq`` is called with the output
    path ``<save_root>/<slide_id with spaces as underscores>/top_patches.png``.
    Slides missing from the CSV or without a patch directory are reported and
    skipped; an error while visualizing one slide is printed and does not stop
    the remaining slides.

    Args:
        feats_h5_dir (str): Path to the directory containing .h5 feature files.
        patch_root (str): Root path to the patch directory (organized by magnification/class).
        csv_path (str): Path to CSV file with columns ['Filename', 'Class', 'Magnification'].
        save_root (str): Directory where visualization images will be saved.
        panther_encoder (torch.nn.Module): Trained Panther encoder model.
        top_k (int): Number of top patches to visualize per prototype.
        patch_size (int): Size of each patch (default: 224).

    Raises:
        KeyError: If the CSV lacks any of the required columns.
    """
    os.makedirs(save_root, exist_ok=True)
    df = pd.read_csv(csv_path)

    # Fail up front with a readable message rather than mid-loop on a column lookup.
    required_columns = ('Filename', 'Class', 'Magnification')
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(
            f"Metadata CSV {csv_path} is missing required column(s): {missing_columns}"
        )

    # Cast to str first so purely numeric slide names can still be stripped and matched.
    df['Filename'] = df['Filename'].astype(str).str.strip()

    # Sorted so slides are processed in a reproducible order.
    for h5_file in sorted(os.listdir(feats_h5_dir)):
        if not h5_file.endswith('.h5'):
            continue

        # splitext drops only the trailing extension; str.replace('.h5', '')
        # would also remove '.h5' from the middle of a slide name.
        slide_id = os.path.splitext(h5_file)[0].strip()

        row = df[df['Filename'] == slide_id]
        if row.empty:
            print(f"⚠️ Slide {slide_id} not found in metadata CSV.")
            continue

        # When a slide is listed more than once, the first matching row is used.
        class_label = row['Class'].values[0]
        magnification = row['Magnification'].values[0]

        h5_feats_fpath = os.path.join(feats_h5_dir, h5_file)
        patch_dir = os.path.join(patch_root, f"{magnification}x", class_label, slide_id)

        if not os.path.isdir(patch_dir):
            print(f"❌ Patch directory not found: {patch_dir}")
            continue

        print(f"✅ Processing: {slide_id} | Class: {class_label} | Mag: {magnification}x")

        # Optional: organize outputs per slide
        slide_dir = os.path.join(save_root, slide_id.replace(' ', '_'))
        os.makedirs(slide_dir, exist_ok=True)
        save_path = os.path.join(slide_dir, 'top_patches.png')

        # Best effort per slide: report the failure and carry on with the next one.
        try:
            visualize_top_k_patches_per_prototype_from_qq(
                h5_feats_fpath=h5_feats_fpath,
                patch_dir=patch_dir,
                panther_encoder=panther_encoder,
                top_k=top_k,
                patch_size=patch_size,
                save_path=save_path
            )

        except Exception as e:
            print(f"❌ Error processing {slide_id}: {e}")


In [22]:
# Generate the top-k patch figure for every slide in the test_slide feature
# directory, writing each result into the shared visualizations folder.
# NOTE(review): patch_root points at a `patches_5x` folder while the recorded
# output below reports 40x/20x magnifications taken from the CSV — confirm
# that this matches the intended patch directory layout.
run_all_patch_visualizations(
    feats_h5_dir=r'C:\Users\Vivian\Documents\PANTHER\PANTHER\features\test_slide',
    patch_root=r'C:\Users\Vivian\Documents\CONCH\all_patches\patches_5x',
    csv_path=r'C:\Users\Vivian\Documents\PANTHER\PANTHER\src\visualization\slides_list.csv',
    save_root=r'C:\Users\Vivian\Documents\PANTHER\PANTHER\features\test_slide\visualizations',
    panther_encoder=panther_encoder,
    top_k=4
)


✅ Processing: FA 47 B1 | Class: FA | Mag: 40x
Saved top-k patches visualization to C:\Users\Vivian\Documents\PANTHER\PANTHER\features\test_slide\visualizations\FA_47_B1\top_patches.png
✅ Processing: FA 57B | Class: FA | Mag: 20x
Saved top-k patches visualization to C:\Users\Vivian\Documents\PANTHER\PANTHER\features\test_slide\visualizations\FA_57B\top_patches.png
✅ Processing: PT 41 B | Class: PT | Mag: 20x
Saved top-k patches visualization to C:\Users\Vivian\Documents\PANTHER\PANTHER\features\test_slide\visualizations\PT_41_B\top_patches.png
