In [None]:
%load_ext autoreload
%autoreload 2

### Extract attention weights from the models (run this section with the `uni` environment)

In [None]:
import os
from dotenv import load_dotenv

# Pull variables from a `.env` file (if one is found) into the process
# environment before they are read below.
load_dotenv()

# Work from the parent of the directory the notebook was started in
# (presumably the repository root, so `models.agg` and the relative
# "attn_comps/..." paths resolve -- confirm against the repo layout).
# A bare `os.chdir("..")` climbs one more level every time this cell is
# re-run; anchoring on the directory recorded at first execution keeps the
# cell idempotent.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

# Both are None when the variable is not defined in the environment/.env.
DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

In [None]:
import pickle

import torch
from transformers import AutoModel

from models.agg import MILClassifier

In [None]:
# MIL classifier checkpoints read by `load_model` for the "gigapath" and
# "uni" model names.
gigapath_best_model = (
    "/opt/gpudata/skin-cancer/outputs/chkpts/gigapath/"
    "gigapath-abmil-1_head-fold-1.pt"
)
uni_best_model = (
    "/opt/gpudata/skin-cancer/outputs/chkpts/uni/"
    "uni-abmil-1_head-fold-3.pt"
)

# Local snapshot directory of the Prism model and the cache directory passed
# to `AutoModel.from_pretrained`.
prism_local_model = (
    "/opt/gpudata/skin-cancer/models/models--paige-ai--Prism/snapshots/"
    "cd2eae7b1e6e51f3664e1a575c5bfe7045cc37d4"
)
hf_cache_dir = "/opt/gpudata/skin-cancer/models"

# Template: `{foundation_model}` is filled in with the model name via
# `str.format` when the per-slide embeddings pickle is located.
tile_embeddings_dir = (
    "/opt/gpudata/skin-cancer/outputs/"
    "{foundation_model}/tile_embeddings_sorted"
)

In [None]:
# Expose only GPU index 2 to CUDA for this process; this must be set before
# the first CUDA call so that it takes effect.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

# Fall back to the CPU when no CUDA device is usable.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)


def load_model(model_name):
    """Build the model selected by ``model_name``, ready for inference.

    Args:
        model_name: ``"uni"`` or ``"gigapath"`` builds a ``MILClassifier`` and
            loads the matching checkpoint; any other value (``"prism"`` in
            this notebook) loads the Prism snapshot through ``AutoModel``.

    Returns:
        The model, switched to eval mode. The MIL classifiers are already on
        ``device``; the Prism model is returned on whatever device
        ``from_pretrained`` placed it.
    """
    # model name -> (first MILClassifier argument, checkpoint path).
    # The first argument is presumably the tile-embedding width -- confirm
    # against models.agg.MILClassifier.
    mil_configs = {
        "uni": (1024, uni_best_model),
        "gigapath": (1536, gigapath_best_model),
    }

    if model_name in mil_configs:
        embed_dim, checkpoint_path = mil_configs[model_name]
        model = MILClassifier(embed_dim, 4, 1, False).to(device)
        # map_location remaps tensors that were saved from a different device
        # (e.g. another GPU index, or a GPU when running on CPU); without it
        # torch.load tries to restore them onto the original device and fails
        # if that device is not visible here.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        model = AutoModel.from_pretrained(
            prism_local_model,
            cache_dir=hf_cache_dir,
            trust_remote_code=True,
        )

    # A freshly constructed module starts in training mode; eval mode turns
    # off train-only behaviour (e.g. dropout) so the extracted attention
    # weights are deterministic.
    model.eval()
    return model

In [None]:
def get_weights(model, model_name, slide_id):
    """Run ``model`` on one slide's tile embeddings and return its attention.

    The embeddings are unpickled from
    ``tile_embeddings_dir.format(foundation_model=model_name)/<slide_id>.pkl``;
    the pickle is read for its ``"coords"``, ``"tile_embeds"`` and (for
    non-Prism models) ``"pos"`` entries.

    Args:
        model: Model returned by ``load_model``; it is moved to ``device``.
        model_name: ``"prism"`` selects the Prism code path, anything else the
            MIL-classifier path.
        slide_id: Slide identifier used as the pickle's file stem.

    Returns:
        ``(weights, tile_coords)`` where ``weights`` is a NumPy array and
        ``tile_coords`` is a list of ``{"x": int, "y": int}`` dicts, one per
        entry of the pickle's ``"coords"``.
    """
    embeddings_path = os.path.join(
        tile_embeddings_dir.format(foundation_model=model_name),
        f"{slide_id}.pkl",
    )
    # NOTE: pickle executes arbitrary code on load; only read trusted files.
    with open(embeddings_path, "rb") as fh:
        slide_data = pickle.load(fh)

    tile_coords = [
        {"x": int(xy[0]), "y": int(xy[1])} for xy in slide_data["coords"]
    ]

    model.to(device)

    if model_name == "prism":
        with torch.autocast("cuda", torch.float16), torch.inference_mode():
            batched_embeds = slide_data["tile_embeds"].unsqueeze(0).to(device)
            out = model.slide_representations(tile_embeddings=batched_embeds)
            # Drop singleton dims, then keep the maximum over the first
            # remaining dimension.
            xattn = out["xattn_weights"].squeeze()
            reduced = xattn.max(dim=0).values
            weights = reduced.detach().cpu().numpy()
        return weights, tile_coords

    with torch.inference_mode():
        # Inputs get a leading dimension of size 1 before the forward pass;
        # the second returned value is taken as the attention weights.
        batched_embeds = slide_data["tile_embeds"].unsqueeze(0).to(device)
        batched_pos = slide_data["pos"].unsqueeze(0).to(device)
        _, attn = model(batched_embeds, batched_pos)
        weights = attn.squeeze().detach().cpu().numpy()
    return weights, tile_coords

In [None]:
# Slide sets used in earlier runs (swap in as needed):
# slide_ids = ["660529-6", "660370-5"] # bowens
# slide_ids = ["660042-3", "660506-6"]  # bcc
# slide_ids = ["660058-1"] # scc
slide_ids = [
    "660524-4",
    "660192-4",
    "660465-4",
    "660192-1",
    "660535-4",
    "660415-3",
    "660152-2",
    "660447-2",
    "660122-1",
    "660032-3",
]  # bowens round 2
models = ["prism", "uni", "gigapath"]

# Make sure the output directory exists so the first dump does not fail.
weights_dir = "attn_comps/weights"
os.makedirs(weights_dir, exist_ok=True)

for m in models:
    # Loading a model does not depend on the slide, so do it once per model
    # rather than once per (model, slide) pair.
    model = load_model(m)
    for slide_id in slide_ids:
        weights, coords = get_weights(model, m, slide_id)
        # One pickle per (model, slide): raw attention weights and the tile
        # coordinates they belong to.
        with open(f"{weights_dir}/{m}-{slide_id}-weights.pkl", "wb") as f:
            pickle.dump(
                {
                    "weights": weights,
                    "coords": coords,
                },
                f,
            )

### Plot the attention heatmaps (run this section with the `gigapath` environment)

In [None]:
import pickle

from gigapath.preprocessing.data.foreground_segmentation import LoadROId
from matplotlib import collections, patches, pyplot as plt
from monai.data.wsi_reader import WSIReader
import numpy as np

In [None]:
def load_slide(id):
    """Load the slide image for ``id`` through ``LoadROId``.

    Args:
        id: Slide identifier; the file read is
            ``/opt/gpudata/skin-cancer/data/slides/<id>.svs``.

    Returns:
        The sample returned by the ``LoadROId`` transform. ``plot_image``
        reads its ``"image"``, ``"scale"`` and ``"origin"`` entries.
    """
    # Level 0, no margin, and no explicit foreground threshold.
    roi_loader = LoadROId(
        WSIReader(backend="OpenSlide"),
        level=0,
        margin=0,
        foreground_threshold=None,
    )
    return roi_loader(
        {
            "image": f"/opt/gpudata/skin-cancer/data/slides/{id}.svs",
            "slide_id": id,
        }
    )

In [None]:
def plot_image(model_name, slide_id):
    """Overlay one model's normalised attention weights on a slide and save it.

    Reads ``attn_comps/weights/<model_name>-<slide_id>-weights-norm.pkl`` and
    writes ``attn_comps/<model_name>-<slide_id>.png``. The colour limits come
    from the module-level ``vmin`` / ``vmax`` dicts keyed by slide id, so those
    must be populated before this function is called.

    Args:
        model_name: Model whose weights are plotted; also used in the title
            and file names.
        slide_id: Slide to load and annotate.
    """
    roi = load_slide(slide_id)
    norm_path = f"attn_comps/weights/{model_name}-{slide_id}-weights-norm.pkl"
    with open(norm_path, "rb") as fh:
        payload = pickle.load(fh)

    scale = roi["scale"]
    origin = roi["origin"]

    fig, ax = plt.subplots(figsize=(15, 5))
    # The image is transposed from (0, 1, 2) to (1, 2, 0) axis order for imshow.
    ax.imshow(roi["image"].transpose(1, 2, 0))

    # Tile coordinates are stored in original-image (level-0) coordinates,
    # while the displayed image is the selected ROI: shift by the ROI origin
    # and divide by the scale to land in displayed-image coordinates.
    tile_rects = [
        patches.Rectangle(
            (
                (tile["x"] - origin[0]) / scale,
                (tile["y"] - origin[1]) / scale,
            ),
            256,
            256,
        )
        for tile in payload["coords"]
    ]

    heatmap = collections.PatchCollection(tile_rects, alpha=1, cmap="inferno")
    heatmap.set_array(payload["weights"])
    heatmap.set_clim(vmin=vmin[slide_id], vmax=vmax[slide_id])
    ax.add_collection(heatmap)

    ax.axis("off")
    fig.colorbar(heatmap, ax=ax)
    ax.set_title(f"{model_name}, slide {slide_id}")
    fig.savefig(f"attn_comps/{model_name}-{slide_id}.png", dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
# Slide sets used in earlier runs (swap in as needed):
# slide_ids = ["660529-6", "660370-5"] # bowens
# slide_ids = ["660042-3", "660506-6"]  # bcc
# slide_ids = ["660058-1"] # scc
slide_ids = [
    "660524-4",
    "660192-4",
    # "660465-4",
    # "660192-1",
    # "660535-4",
    # "660415-3",
    # "660152-2",
    # "660447-2",
    # "660122-1",
    # "660032-3",
]  # bowens round 2
models = ["prism", "uni", "gigapath"]

# Per-slide colour limits shared by all models; read by `plot_image`.
vmin = {}
vmax = {}
for slide_id in slide_ids:
    clipped_by_model = {}
    coords_by_model = {}
    for m in models:
        with open(f"attn_comps/weights/{m}-{slide_id}-weights.pkl", "rb") as f:
            data = pickle.load(f)

        # Keep each model's own coords: pairing a model's weights with the
        # coords read from another model's file would misplace tiles whenever
        # the tile order or count differs between models.
        coords_by_model[m] = data["coords"]

        # clip weights to control for outliers
        min_threshold, max_threshold = np.percentile(data["weights"], [1, 99])
        clipped_by_model[m] = np.clip(
            data["weights"], min_threshold, max_threshold
        )

    # z-score with statistics pooled over all models so that the heatmaps of
    # one slide share a single colour scale.
    concat_weights = np.concatenate([clipped_by_model[m] for m in models])
    mean = concat_weights.mean()
    std = concat_weights.std()
    if std == 0:
        # Every clipped weight is identical; dividing by zero would fill the
        # output with NaN, so fall back to all-zero scores.
        std = 1.0
    normalized = (concat_weights - mean) / std
    vmin[slide_id] = normalized.min()
    vmax[slide_id] = normalized.max()

    for m in models:
        with open(
            f"attn_comps/weights/{m}-{slide_id}-weights-norm.pkl", "wb"
        ) as f:
            pickle.dump(
                {
                    "weights": (clipped_by_model[m] - mean) / std,
                    "coords": coords_by_model[m],
                },
                f,
            )


for m in models:
    for slide_id in slide_ids:
        print(f"processing {slide_id} for {m}")
        plot_image(m, slide_id)