In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# Populate os.environ from a .env file (if one is found) before reading config.
load_dotenv()

# Move the working directory up one level — presumably to the project root so
# that the `utils`, `data_models` and `models` imports below resolve (confirm).
# NOTE: this relative chdir is not idempotent; re-running this cell in the same
# kernel moves up one more directory each time. Run it once per kernel session.
os.chdir("..")

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

# Fail fast with a clear message: an unset variable would otherwise surface
# later as an opaque TypeError from os.path.join(None, ...).
_missing_env = [
    name
    for name, value in (("DATA_DIR", DATA_DIR), ("OUTPUT_DIR", OUTPUT_DIR))
    if value is None
]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + ". Define them in the environment or in a .env file."
    )

In [None]:
# load the labels data with folds
import numpy as np

from utils.load_data import load_data
from data_models.Label import Label

label_path = os.path.join(DATA_DIR, "labels/labels.csv")

# Load the label table and key it by specimen so lookups below are by ID.
df = load_data(label_path=label_path)
df = df.set_index("specimen_id")

# specimen_id -> numpy vector holding that specimen's values for every
# Label member column (one column per member name, in enum order).
onehot_split = df[Label._member_names_].to_dict(orient="split", index=True)
labels_onehot = {
    spec_id: np.array(row)
    for spec_id, row in zip(onehot_split["index"], onehot_split["data"])
}

# specimen_id -> integer class label. Iterating the single "label" column
# avoids iterrows(), which builds a full Series per row and may upcast dtypes.
labels_dict = {spec_id: int(label) for spec_id, label in df["label"].items()}

In [None]:
foundation_model = "gigapath"

# Directory holding the tile-embedding files, one ".pkl" per slide.
# NOTE(review): this absolute path is hard-coded while other outputs in this
# notebook are resolved relative to OUTPUT_DIR — confirm whether it should be
# os.path.join(OUTPUT_DIR, foundation_model, "tile_embeddings_sorted").
tile_embed_dir = f"/opt/gpudata/skin-cancer/outputs/{foundation_model}/tile_embeddings_sorted"

# Build the specimen lookup once instead of once per filename.
specimen_ids = set(df.index)

# os.listdir returns entries in arbitrary order; sort so the slide order (and
# therefore the row order of anything derived from it) is reproducible.
fnames = sorted(os.listdir(tile_embed_dir))

# Absolute path of every embedding file whose first six filename characters
# match a specimen ID present in the label table.
tile_embed_paths = [
    os.path.join(tile_embed_dir, fname)
    for fname in fnames
    if fname.endswith(".pkl") and fname[:6] in specimen_ids
]

In [None]:
# map specimens to slides
# Group slide embedding paths under the specimen they belong to. Every
# specimen in the label table gets an entry, even if it ends up with no slides.
slides_by_specimen = {spec: [] for spec in df.index}
for slide in tile_embed_paths:
    # Drop the 4-character extension, then take the 6-character specimen prefix.
    slide_name = os.path.basename(slide)[: -len(".pkl")]
    spec = slide_name[:6]
    bucket = slides_by_specimen.get(spec)
    if bucket is None:
        # Slide does not belong to any known specimen; ignore it.
        continue
    bucket.append(slide)

In [None]:
from data_models.Label import Label

# Fraction of specimens that are positive (== 1) for each label column.
# Computed directly rather than via value_counts().iloc[1]: value_counts sorts
# by frequency, so position 1 is the *less common* value (not necessarily the
# positive class) and does not exist at all when a column has a single value.
# Missing values are dropped first so the fraction is over non-null rows,
# matching value_counts(normalize=True).
class_freqs = {
    label: df[label].dropna().eq(1).mean()
    for label in Label._member_names_
}

In [None]:
import torch

# Expose only GPU index 2 to this process. This is set before the first CUDA
# query below so that the restriction is in place when CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

In [None]:
import pandas as pd
from torch.utils.data import DataLoader

from data_models.datasets import SlideEncodingDataset, collate_tile_embeds
from models.agg import MILClassifier
from models.utils.train import val_epoch


# --- Inference configuration -------------------------------------------------
num_labels = 4
embed_dim = 1536
foundation_model = "gigapath"
aggregator = "abmil-1_head"
best_fold = 1

# Checkpoint path, relative to OUTPUT_DIR.
model_name = (
    f"chkpts/{foundation_model}/"
    f"{foundation_model}-{aggregator}-fold-{best_fold}.pt"
)
# Destination CSV for the per-slide predictions.
save_fname = os.path.join(
    OUTPUT_DIR,
    f"{foundation_model}/preds/"
    f"{foundation_model}-{aggregator}-fold-{best_fold}.csv",
)

# --- Model -------------------------------------------------------------------
# NOTE(review): the positional `1, False` arguments presumably select a single
# attention head (matching "abmil-1_head") and disable an optional feature —
# confirm against MILClassifier's signature.
model = MILClassifier(embed_dim, num_labels, 1, False).to(device)
model.load_state_dict(
    torch.load(
        os.path.join(OUTPUT_DIR, model_name),
        # Remap stored tensors onto the active device: the checkpoint may have
        # been saved from a CUDA index that is not visible in this process.
        map_location=device,
        weights_only=True,
    )
)
# Put dropout/normalisation layers into inference behaviour. val_epoch may
# already do this; calling it here is harmless and makes the cell explicit.
model.eval()

# --- Data --------------------------------------------------------------------
# One slide per batch, in the fixed order of tile_embed_paths.
dl = DataLoader(
    SlideEncodingDataset(tile_embed_paths, labels_dict),
    batch_size=1,
    shuffle=False,
    collate_fn=collate_tile_embeds,
)

# --- Predict -----------------------------------------------------------------
# The first returned value is discarded; the rest are the ground-truth labels,
# predicted probabilities and slide IDs.
_, ground_truth, probs, ids = val_epoch(
    model=model,
    dataloader=dl,
    device=device,
    input_keys=["tile_embeds", "pos"],
    label_key="label",
)
# Swap the two axes so that probs[i] is the column of probabilities for the
# i-th label — assumes probs is a 2-D tensor of (slides, labels); TODO confirm.
probs = probs.transpose(0, 1)
probs = {k: probs[i].tolist() for i, k in enumerate(Label._member_names_)}

# --- Save --------------------------------------------------------------------
# Create the output directory first so a missing folder does not discard the
# results of the inference run above.
os.makedirs(os.path.dirname(save_fname), exist_ok=True)
pd.DataFrame(
    {"id": ids, "ground_truth": ground_truth.tolist()} | probs
).to_csv(save_fname, index=False)