In [None]:
# Automatically reload imported modules before executing each cell, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

load_dotenv()

# Relative paths later in this notebook (e.g. "outputs/...") are resolved
# against the parent of the kernel's starting directory. Remember that
# directory on the first run so re-executing this cell does not keep climbing
# the tree with each os.chdir("..").
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")
# Fail early with a clear message instead of a TypeError from os.path.join later.
if DATA_DIR is None or OUTPUT_DIR is None:
    raise RuntimeError(
        "DATA_DIR and OUTPUT_DIR must be set in the environment or a .env file"
    )

In [None]:
# Foundation models whose trained aggregator checkpoints are evaluated below.
fms = ["gigapath", "uni", "prism"]
MODEL_DIRS = [os.path.join(OUTPUT_DIR, f"chkpts/{fm}") for fm in fms]
# Every entry is later parsed as a "<fm>-<agg>-<heads>-<x>-<fold>" checkpoint
# filename, so skip hidden entries (e.g. .DS_Store) and subdirectories.
# Sorting makes the evaluation order (and results row order) deterministic.
model_files = [
    os.path.join(fm_dir, fname)
    for fm_dir in MODEL_DIRS
    for fname in sorted(os.listdir(fm_dir))
    if not fname.startswith(".") and os.path.isfile(os.path.join(fm_dir, fname))
]

In [None]:
# load the labels data with folds
import numpy as np

from utils.load_data import load_data
from data_models.Label import Label

label_path = os.path.join(DATA_DIR, "labels/labels.csv")
fold_path = os.path.join(DATA_DIR, "folds.json")

df = load_data(label_path=label_path, fold_path=fold_path)
df = df.set_index("specimen_id")
# Every lookup below is keyed by specimen ID; duplicates would silently collapse.
assert df.index.is_unique, "specimen_id values must be unique"

# specimen_id -> one-hot label vector (columns in Label member order).
labels_onehot = {
    spec: np.array(row)
    for spec, row in zip(df.index, df[Label._member_names_].to_numpy())
}
# specimen_id -> integer class label (vectorised instead of iterrows).
labels_dict = {spec: int(lbl) for spec, lbl in df["label"].items()}

In [None]:
tile_embed_dirs = {
    fm: os.path.join(OUTPUT_DIR, f"{fm}/tile_embeddings_sorted") for fm in fms
}
# Build the specimen-ID lookup once rather than once per filename.
known_specimens = set(df.index)
# Keep pickled tile embeddings whose filename starts with a known specimen ID.
# Assumes the specimen ID is the first 6 characters of the filename — TODO confirm.
# Sorted so slide order (and hence loader order) is deterministic.
tile_embed_paths = {
    fm: [
        os.path.join(tile_embed_dirs[fm], fname)
        for fname in sorted(os.listdir(tile_embed_dirs[fm]))
        if fname.endswith(".pkl") and fname[:6] in known_specimens
    ]
    for fm in fms
}

In [None]:
# Specimen IDs grouped by CV fold: entry i lists the specimens of the i-th
# fold value in groupby order (pandas sorts group keys by default).
# NOTE(review): the evaluation loop passes int(fold) parsed from checkpoint
# filenames as val_fold, which presumably indexes this list — that assumes
# fold values are 0..k-1; confirm.
specimens_by_fold = [
    list(fold_specimens) for fold_specimens in df.groupby("fold").groups.values()
]

In [None]:
# For each foundation model, map every specimen in df to the list of its slide
# embedding paths. Specimens with no slides keep an empty list. The specimen ID
# is taken as the first 6 characters of the filename with its 4-character
# extension (".pkl") removed.
slides_by_specimen = {}
for fm in fms:
    spec_to_slides = {spec: [] for spec in df.index}
    for slide_path in tile_embed_paths[fm]:
        spec_id = os.path.basename(slide_path)[:-4][:6]
        if spec_id in spec_to_slides:
            spec_to_slides[spec_id].append(slide_path)
    slides_by_specimen[fm] = spec_to_slides

In [None]:
# Same module path as the label-loading cell: "data_models.label" differs only
# in case and fails to resolve on case-sensitive filesystems.
from data_models.Label import Label

# Per class: fraction of non-null rows whose one-hot column equals 1.
# The previous value_counts(normalize=True).iloc[1] picked the *second most
# frequent* value. That equals the positive rate only when positives are the
# minority, and it raised IndexError for a constant column. Comparing to 1
# explicitly is correct in every case.
class_freqs = {
    label: df[label].dropna().eq(1).mean()
    for label in Label._member_names_
}

In [None]:
# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
# initialised, so set it ahead of importing torch. "1" exposes only the GPU
# with index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

In [None]:
from typing import Dict, List, Tuple

from torch.utils.data import DataLoader

from data_models.datasets import SlideEncodingDataset, collate_tile_embeds
from utils.split import train_val_split_slides, train_val_split_labels


def get_loaders(
    val_fold: int,
    specimens_by_fold: List[List[str]],
    slides_by_specimen: Dict[str, List[str]],
    labels_by_specimen: Dict[str, int],
) -> Tuple[DataLoader, DataLoader]:
    """Build the train and validation DataLoaders for one cross-validation fold.

    Args:
        val_fold: Fold held out for validation (passed to the project split
            helpers; presumably an index into ``specimens_by_fold``).
        specimens_by_fold: Specimen IDs grouped per fold.
        slides_by_specimen: Specimen ID -> slide embedding paths.
        labels_by_specimen: Specimen ID -> integer class label.

    Returns:
        ``(train_loader, val_loader)``. Both use batch size 1 and
        ``collate_tile_embeds``; only the training loader shuffles.
    """
    train_slides, val_slides = train_val_split_slides(
        val_fold=val_fold,
        specimens_by_fold=specimens_by_fold,
        slides_by_specimen=slides_by_specimen,
    )
    train_labels, val_labels = train_val_split_labels(
        val_fold=val_fold,
        labels_by_specimen=labels_by_specimen,
        specimens_by_fold=specimens_by_fold,
    )

    def _make_loader(slides, labels, shuffle):
        # One slide per batch; the collate fn presumably handles a variable
        # number of tiles per slide — confirm in data_models.datasets.
        return DataLoader(
            SlideEncodingDataset(slides, labels),
            batch_size=1,
            shuffle=shuffle,
            collate_fn=collate_tile_embeds,
        )

    train_loader = _make_loader(train_slides, train_labels, True)
    val_loader = _make_loader(val_slides, val_labels, False)
    return train_loader, val_loader

In [None]:
from operator import itemgetter

from torch import nn
import pandas as pd

from models.utils.train import val_epoch
from evaluation.eval import get_spec_level_probs

from sklearn.metrics import roc_auc_score, average_precision_score
from models.agg import MILClassifier


# Per-class metric column names. Assumed to match the order of the model's
# outputs and of the one-hot columns in labels_onehot — TODO confirm against
# Label._member_names_.
CLASS_NAMES = ["benign", "bowens", "bcc", "scc"]
NUM_LABELS = len(CLASS_NAMES)  # 4
EMBED_DIMS = {"gigapath": 1536, "uni": 1024, "prism": 2560}
# Aggregator tag in the checkpoint filename -> MILClassifier(gated=...).
GATED_BY_AGG = {"gabmil": True, "abmil": False}

auroc_keys = [k + "_auroc" for k in CLASS_NAMES]
auprc_keys = [k + "_auprc" for k in CLASS_NAMES]
RESULT_COLUMNS = (
    ["foundation_model", "aggregator", "classifier", "fold"] + auroc_keys + auprc_keys
)

# Stateless, so a single instance serves every checkpoint.
loss_fn = nn.CrossEntropyLoss()

# Collect one dict per checkpoint and build the DataFrame once at the end.
# Concatenating inside the loop is quadratic and coerces every column to object.
records = []
for m in model_files:
    # Filename stem is "<fm>-<agg>-<heads>-<ignored>-<fold>", where the part
    # of <heads> before the first "_" is the number of attention heads.
    fm, agg, heads, _, fold = os.path.basename(m).split(".")[0].split("-")
    n_heads_str = heads.split("_")[0]

    # An unrecognised aggregator used to fall through and silently evaluate
    # the model left over from the previous iteration; fail loudly instead.
    if agg not in GATED_BY_AGG:
        raise ValueError(f"Unknown aggregator {agg!r} in checkpoint {m}")
    model = MILClassifier(
        EMBED_DIMS[fm], NUM_LABELS, int(n_heads_str), gated=GATED_BY_AGG[agg]
    )
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(m, map_location=device))
    model.to(device)
    # Inference mode (dropout etc. off); harmless if val_epoch also sets it.
    model.eval()

    _, val_loader = get_loaders(
        int(fold), specimens_by_fold, slides_by_specimen[fm], labels_dict
    )

    _loss, _labels, probs, ids = val_epoch(model, val_loader, loss_fn, device)
    ids, probs = get_spec_level_probs(ids, probs)
    # np.stack keeps a (n_specimens, n_classes) array even for one specimen,
    # where itemgetter(*ids) would return a bare 1-D vector.
    labels_onehot_val = np.stack([labels_onehot[spec] for spec in ids])

    auroc = roc_auc_score(labels_onehot_val, probs, average=None, multi_class="ovr")
    auprc = average_precision_score(labels_onehot_val, probs, average=None)

    records.append(
        {
            "foundation_model": fm,
            "aggregator": agg + "-" + n_heads_str + "_heads",
            "classifier": "MLP",
            "fold": fold,
            **dict(zip(auroc_keys, auroc)),
            **dict(zip(auprc_keys, auprc)),
        }
    )

results = pd.DataFrame(records, columns=RESULT_COLUMNS)

In [None]:
# The path is relative to the current working directory (the parent of the
# kernel's starting directory after the os.chdir at the top). Create the
# folder so a first run does not fail with "non-existent directory".
os.makedirs("outputs", exist_ok=True)
results.to_csv("outputs/experiments_by_fold.csv", sep="|")