In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# load_dotenv() (python-dotenv) is expected to copy the variables from a .env
# file into the process environment, so it must run before they are read.
load_dotenv()

# Work from the parent of the directory the kernel started in. The starting
# directory is remembered so that re-running this cell lands in the same place
# instead of climbing one more level on every execution.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, ".."))

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

# Fail here with a readable message rather than later inside os.path.join,
# where an unset variable (None) surfaces as an opaque TypeError.
_missing_env = [
    name
    for name, value in (("DATA_DIR", DATA_DIR), ("OUTPUT_DIR", OUTPUT_DIR))
    if not value
]
if _missing_env:
    raise EnvironmentError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + ". Set them in the shell or in the .env file."
    )

In [None]:
# Read the label table (with folds) alongside the slide embeddings.
import numpy as np

from utils.load_data import load_data
from data_models.Label import Label

# Both input locations are resolved from the directories configured above and
# handed to load_data together.
embedding_path = os.path.join(
    OUTPUT_DIR, "prism/slide_embeddings/prism_slide_embeds_perceiver.pkl"
)
label_path = os.path.join(DATA_DIR, "labels/labels.csv")

df = load_data(embedding_path=embedding_path, label_path=label_path)

# Distinct ids, in order of first appearance. `slide_id` is read after
# reset_index(), so it may live in the index rather than in a column.
slides = [*df.reset_index()["slide_id"].unique()]
specimens = [*df["specimen_id"].unique()]

In [None]:
# specimen_id -> list of the index labels of that specimen's rows.
# The index labels are presumably the slide (WSI) ids — confirm against load_data.
slides_by_specimen = {
    specimen: list(row_labels)
    for specimen, row_labels in df.groupby("specimen_id").groups.items()
}

In [None]:
# Collapse the slide-level table to one row per specimen, indexed by
# specimen_id. This rebinds `df`, so on a second execution `slide_id` is already
# gone; errors="ignore" makes the drop a no-op in that case and keeps the cell
# re-runnable instead of raising KeyError.
df = (
    df.reset_index()
    .drop(columns=["slide_id"], errors="ignore")
    .drop_duplicates(subset=["specimen_id"])
    .set_index("specimen_id")
)

# specimen_id -> numpy array holding that specimen's values in the columns
# named after the Label members (same order as Label._member_names_).
labels_onehot = df[Label._member_names_].to_dict(orient="split", index=True)
labels_onehot = {
    specimen_id: np.array(row_values)
    for specimen_id, row_values in zip(
        labels_onehot["index"], labels_onehot["data"]
    )
}

# specimen_id -> integer class label. Iterating the single "label" column
# avoids building a full Series per row the way iterrows() does.
labels_dict = {
    specimen_id: int(label) for specimen_id, label in df["label"].items()
}

In [None]:
from data_models.Label import Label

# Fraction of specimens (among non-null entries) that are positive for each
# label column. It is computed as the share of values equal to 1 rather than
# by position in value_counts(), whose rows are ordered by frequency: picking
# position 1 returns the negative-class share whenever positives are the
# majority, and raises IndexError when a column holds a single value.
# Assumes the label columns are 0/1 indicators — TODO confirm.
class_freqs = {
    label: df[label].dropna().eq(1).mean() for label in Label._member_names_
}

In [None]:
import torch

# Default to GPU index 2, but respect a CUDA_VISIBLE_DEVICES value that is
# already present in the environment instead of overwriting it.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in
# this process — confirm nothing imported earlier touches the GPU.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

# Fall back to the CPU when no CUDA device is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

In [None]:
import pandas as pd
from torch.utils.data import DataLoader

from data_models.datasets import (
    SlideClassificationDataset,
    collate_slide_embeds,
)
from models.MLP import MLP
from models.utils.train import val_epoch


# Settings identifying which trained checkpoint to evaluate.
num_labels = 4
embed_dim = 1280
foundation_model = "prism"
aggregator = "PRISM"
best_fold = 3

# Checkpoint location; it is joined onto OUTPUT_DIR when the weights are loaded.
model_name = "chkpts/{fm}/{fm}-{agg}-fold-{fold}.pt".format(
    fm=foundation_model, agg=aggregator, fold=best_fold
)

# CSV that receives the per-sample predictions.
save_fname = os.path.join(
    OUTPUT_DIR,
    "{fm}/preds/{fm}-{agg}-fold-{fold}.csv".format(
        fm=foundation_model, agg=aggregator, fold=best_fold
    ),
)

# Rebuild the classifier and restore the trained weights.
# [1024, 512, 256] are presumably the hidden layer widths — confirm against MLP.
model = MLP(embed_dim, [1024, 512, 256], num_labels).to(device)
model.load_state_dict(
    torch.load(
        os.path.join(OUTPUT_DIR, model_name),
        # Remap stored tensors onto the active device so a checkpoint saved
        # from a GPU still loads when this notebook has fallen back to the CPU.
        map_location=device,
        weights_only=True,
    )
)
# Inference only: make sure layers such as dropout/batch-norm are in eval mode.
# Harmless if val_epoch already does this itself.
model.eval()

# One sample per batch, in a fixed order, so predictions line up with the ids.
# NOTE(review): `slides` holds slide ids while `labels_dict` is keyed by
# specimen_id; presumably the dataset resolves that mapping — confirm.
dl = DataLoader(
    SlideClassificationDataset(embedding_path, slides, labels_dict),
    batch_size=1,
    shuffle=False,
    collate_fn=collate_slide_embeds,
)

# Run one evaluation pass; the first value returned by val_epoch is not needed.
_, ground_truth, probs, ids = val_epoch(
    model=model,
    dataloader=dl,
    device=device,
    input_keys=["slide_embed"],
    label_key="label",
)

# Swap the two axes so that probs[i] is the column for the i-th Label member.
# Presumably probs arrives as (n_samples, n_classes) — confirm against val_epoch.
probs = probs.transpose(0, 1)
probs = {k: probs[i].tolist() for i, k in enumerate(Label._member_names_)}

# Create the destination folder on first use; to_csv does not create
# missing parent directories.
os.makedirs(os.path.dirname(save_fname), exist_ok=True)

# One row per sample: id, ground-truth label, then one probability column
# per Label member.
pd.DataFrame(
    {"id": ids, "ground_truth": ground_truth.tolist()} | probs
).to_csv(save_fname, index=False)