In [None]:
# Re-import edited project modules (utils, data_models, models, evaluation)
# automatically before each cell executes, so edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# Read DATA_DIR / OUTPUT_DIR (and any other settings) from a .env file.
load_dotenv()

# Work from the repository root so that the project packages imported in later
# cells (utils, data_models, models, evaluation) and relative paths such as
# "outputs/..." presumably resolve against it. The notebook's own directory is
# remembered on the first run so re-running this cell does not keep climbing
# the directory tree (a bare os.chdir("..") is not idempotent).
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

# Fail fast with a clear message instead of a TypeError from os.path.join
# later (OUTPUT_DIR is otherwise first touched only after a fold has trained).
_missing_env = [
    name
    for name, value in (("DATA_DIR", DATA_DIR), ("OUTPUT_DIR", OUTPUT_DIR))
    if not value
]
if _missing_env:
    raise EnvironmentError(
        f"Set {', '.join(_missing_env)} in the environment or the .env file"
    )

In [None]:
# load the labels data with folds
import numpy as np

from utils.load_data import load_data
from data_models.Label import Label

label_path = os.path.join(DATA_DIR, "labels/labels.csv")
fold_path = os.path.join(DATA_DIR, "folds.json")

# Label table indexed by specimen id.
df = load_data(label_path=label_path, fold_path=fold_path).set_index(
    "specimen_id"
)
# Every lookup table built from df is keyed by specimen id; a duplicate id
# would otherwise be silently collapsed to its last row.
assert df.index.is_unique, "label data contains duplicate specimen_id values"

# specimen id -> one-hot label vector (columns in Label member order).
# itertuples yields (index, col_1, ..., col_k); avoids the pandas>=2-only
# to_dict(orient="split", index=True) round trip.
labels_onehot = {
    spec: np.array(onehot_row)
    for spec, *onehot_row in df[Label._member_names_].itertuples(name=None)
}
# specimen id -> integer class label. Reading the single column avoids
# iterrows, which builds a Series per row and can upcast mixed dtypes.
labels_dict = {spec: int(label) for spec, label in df["label"].items()}

In [None]:
foundation_model = "uni"

# Root that holds one sub-directory per foundation model. Defaults to the
# original cluster location; export TILE_EMBED_ROOT (or set it in .env) to
# point elsewhere.
TILE_EMBED_ROOT = os.getenv(
    "TILE_EMBED_ROOT", "/opt/gpudata/skin-cancer/outputs"
)
# Slide file names start with the specimen id, which is this many characters.
SPECIMEN_ID_LEN = 6

# Absolute path of every slide's pickled tile embeddings whose specimen is in
# the label table. Sorted so the slide order is deterministic across runs
# (os.listdir order is arbitrary).
tile_embed_dir = os.path.join(
    TILE_EMBED_ROOT, foundation_model, "tile_embeddings_sorted"
)
fnames = sorted(os.listdir(tile_embed_dir))
known_specimens = set(df.index)  # built once, not once per file name
tile_embed_paths = [
    os.path.join(tile_embed_dir, fname)
    for fname in fnames
    if fname.endswith(".pkl") and fname[:SPECIMEN_ID_LEN] in known_specimens
]

In [None]:
# Specimen ids grouped by fold, in the order pandas reports the groups.
# A fold's position in this list is what get_loaders later receives as
# `val_fold`.
fold_groups = df.groupby("fold").groups  # fold value -> Index of specimen ids
specimens_by_fold = [list(fold_groups[fold]) for fold in fold_groups]

In [None]:
# Group the tile-embedding files by specimen; every specimen in df gets an
# entry, possibly an empty list.
slides_by_specimen = {spec: [] for spec in df.index}
for embed_path in tile_embed_paths:
    # strip the 4-character ".pkl" suffix, then keep the specimen-id prefix
    specimen_id = os.path.basename(embed_path)[:-4][:6]
    if specimen_id in slides_by_specimen:
        slides_by_specimen[specimen_id].append(embed_path)

In [None]:
from data_models.Label import Label

# Share of specimens that are positive for each one-hot label column (assumes
# 0/1 or boolean indicator columns). Passed to evaluator.finalize in the
# training cell — presumably as the no-skill baseline of the PR curves
# (TODO confirm in evaluation.eval).
# The column mean is used because value_counts() sorts by frequency: its
# .iloc[1] is the *second most common* value, i.e. the negative-class share
# whenever positives are the majority, and an IndexError for a constant column.
class_freqs = {
    label: float(df[label].mean())
    for label in Label._member_names_
}

In [None]:
# Pin the notebook to one GPU. CUDA reads CUDA_VISIBLE_DEVICES when it is
# first initialised, so set it before torch is imported; a value already
# exported in the environment (or loaded from .env) takes precedence over
# the default GPU "1".
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

In [None]:
from typing import Dict, List, Tuple

from torch.utils.data import DataLoader

from data_models.datasets import SlideEncodingDataset, collate_tile_embeds
from utils.split import train_val_split_slides, train_val_split_labels


def get_loaders(
    val_fold: int,
    specimens_by_fold: List[List[str]],
    slides_by_specimen: Dict[str, List[str]],
    labels_by_specimen: Dict[str, int],
) -> Tuple[DataLoader, DataLoader]:
    """Build the train and validation loaders with one fold held out.

    Args:
        val_fold: Fold to hold out for validation (the training cell passes a
            position in ``specimens_by_fold``).
        specimens_by_fold: Specimen ids grouped by fold.
        slides_by_specimen: Slide embedding paths keyed by specimen id.
        labels_by_specimen: Integer class label keyed by specimen id.

    Returns:
        ``(train_loader, val_loader)``. Both yield one dataset item per batch
        and collate with ``collate_tile_embeds``; only the training loader
        shuffles.
    """
    train_slides, val_slides = train_val_split_slides(
        val_fold=val_fold,
        specimens_by_fold=specimens_by_fold,
        slides_by_specimen=slides_by_specimen,
    )
    train_targets, val_targets = train_val_split_labels(
        val_fold=val_fold,
        labels_by_specimen=labels_by_specimen,
        specimens_by_fold=specimens_by_fold,
    )

    def _make_loader(slides, targets, shuffle: bool) -> DataLoader:
        # batch_size=1: a single dataset item per batch
        return DataLoader(
            SlideEncodingDataset(slides, targets),
            batch_size=1,
            shuffle=shuffle,
            collate_fn=collate_tile_embeds,
        )

    # Tuple elements evaluate left to right: train loader first, as before.
    return (
        _make_loader(train_slides, train_targets, shuffle=True),
        _make_loader(val_slides, val_targets, shuffle=False),
    )

In [None]:
import copy
from operator import itemgetter

import pandas as pd
from sklearn.metrics import average_precision_score, roc_auc_score
from torch import nn
from torch.optim import AdamW

from models.agg import MILClassifier
from models.utils.train import train_epoch, val_epoch
from evaluation.eval import Evaluator


# --- Experiment configuration ------------------------------------------------
EPOCHS = 30
# Passed straight to train_epoch. The loaders yield one item per step
# (batch_size=1 in get_loaders), so this presumably sets how many slides are
# accumulated per optimizer step inside train_epoch — TODO confirm.
BATCH_SIZE = 16
PATIENCE = 10  # epochs without a new best validation loss before stopping
LEARNING_RATE = 1e-5
SEED = 42  # fold k is trained right after torch.manual_seed(SEED + k)
# Metric-column prefixes; assumed to follow the column order of
# Label._member_names_ (the one-hot columns) — TODO confirm.
CLASS_NAMES = ["benign", "bowens", "bcc", "scc"]
NUM_LABELS = len(CLASS_NAMES)
INPUT_KEYS = ["tile_embeds", "pos"]  # batch keys handed to train/val_epoch
TARGET_KEY = "label"  # batch key holding the class target
# Checkpoint path relative to OUTPUT_DIR, filled in per experiment and fold.
CHECKPOINT_TEMPLATE = (
    "chkpts/{foundation_model}-{aggregator}-{heads}_heads-fold-{i}.pt"
)

# One experiment per (gated, heads) pair.
gates = [False]
head_counts = [1]
assert len(gates) == len(head_counts), "gates and head_counts must pair up"

auroc_keys = [k + "_auroc" for k in CLASS_NAMES]
auprc_keys = [k + "_auprc" for k in CLASS_NAMES]
RESULT_COLUMNS = (
    ["foundation_model", "aggregator", "classifier", "fold"]
    + auroc_keys
    + auprc_keys
)


def _train_fold(fold_idx: int, heads: int, gated: bool):
    """Train one MILClassifier with fold ``fold_idx`` held out for validation.

    Reads the notebook-level specimens_by_fold, slides_by_specimen,
    labels_dict and device. Stops early once PATIENCE consecutive epochs fail
    to improve the best validation loss.

    Returns:
        ``(best_weights, best_val)``: a deep copy of the state_dict from the
        epoch with the lowest validation loss, and that epoch's validation
        ``ids``/``labels``/``probs`` as returned by val_epoch.

    Raises:
        RuntimeError: if no epoch's validation loss was below infinity (e.g.
            all NaN), so there is no checkpoint worth keeping.
    """
    # Seed per fold so each fold's init and shuffling is reproducible alone.
    torch.manual_seed(SEED + fold_idx)
    train_loader, val_loader = get_loaders(
        fold_idx, specimens_by_fold, slides_by_specimen, labels_dict
    )
    # Peek at one training sample to size the classifier's input layer.
    embed_dim = next(iter(train_loader))["tile_embeds"].shape[-1]

    model = MILClassifier(embed_dim, NUM_LABELS, heads, gated).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optim = AdamW(model.parameters(), lr=LEARNING_RATE)

    best_loss = float("inf")
    best_weights = None
    best_val = {"ids": None, "labels": None, "probs": None}
    epochs_left = PATIENCE

    for epoch in range(1, EPOCHS + 1):
        train_loss = train_epoch(
            model,
            train_loader,
            optim,
            loss_fn,
            BATCH_SIZE,
            INPUT_KEYS,
            TARGET_KEY,
            device,
        )
        val_loss, labels, probs, ids = val_epoch(
            model, val_loader, device, INPUT_KEYS, TARGET_KEY, loss_fn
        )

        if epoch % 2 == 0:
            print(f"--------------------EPOCH{epoch:>4}--------------------")
            print(f"train loss: {train_loss:0.6f}")
            print(f"val loss:   {val_loss:0.6f}")
            print()

        if val_loss < best_loss:
            best_loss = val_loss
            # state_dict() holds references to the live tensors, which later
            # epochs keep updating — copy to freeze this epoch's weights.
            best_weights = copy.deepcopy(model.state_dict())
            best_val = {"ids": ids, "labels": labels, "probs": probs}
            epochs_left = PATIENCE
        else:
            epochs_left -= 1
            if epochs_left == 0:
                print(f"early stopping after epoch {epoch}")
                break

    if best_weights is None:
        raise RuntimeError(
            f"fold {fold_idx}: no epoch produced a finite validation loss; "
            "nothing to save"
        )
    return best_weights, best_val


def _save_checkpoint(weights, aggregator: str, heads: int, fold_idx: int) -> str:
    """Save ``weights`` under OUTPUT_DIR (creating chkpts/ if needed); return the path."""
    path = os.path.join(
        OUTPUT_DIR,
        CHECKPOINT_TEMPLATE.format(
            foundation_model=foundation_model,
            aggregator=aggregator,
            heads=heads,
            i=fold_idx,
        ),
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(weights, path)
    return path


def _specimen_level_outputs(best_val):
    """Return ``(probs, onehot)`` at specimen level; row k of both is the same specimen."""
    ids, probs = Evaluator.get_spec_level_probs(
        best_val["ids"], best_val["probs"]
    )
    # np.stack instead of np.array(itemgetter(*ids)(...)): itemgetter returns
    # a bare array rather than a tuple when there is exactly one id, which
    # would collapse the label matrix to 1-D.
    onehot = np.stack([labels_onehot[spec] for spec in ids])
    return probs, onehot


# --- Run: for each experiment, train/validate every fold, keep the best -----
# --- checkpoint per fold, score it, then plot curves with Evaluator.     -----
result_rows = []
for exp_idx, (gated, heads) in enumerate(zip(gates, head_counts)):
    print(f"\n\n****EXPERIMENT {exp_idx + 1}****\n\n")
    aggregator = "gabmil_test" if gated else "abmil_test"
    exp_name = (
        f"{foundation_model}/testing/{foundation_model}-3_hidden-{heads}_heads"
    )
    evaluator = Evaluator(Label)
    n_folds = len(specimens_by_fold)

    for fold_idx in range(n_folds):
        print(f"--------------------FOLD    {fold_idx + 1}--------------------")
        best_weights, best_val = _train_fold(fold_idx, heads, gated)
        _save_checkpoint(best_weights, aggregator, heads, fold_idx)

        probs, onehot = _specimen_level_outputs(best_val)
        evaluator.fold(probs, onehot, fold_idx, n_folds)
        auroc = roc_auc_score(onehot, probs, average=None, multi_class="ovr")
        auprc = average_precision_score(onehot, probs, average=None)

        result_rows.append(
            {
                "foundation_model": foundation_model,
                "aggregator": aggregator,
                "classifier": "MLP",
                "fold": fold_idx,
                **dict(zip(auroc_keys, auroc)),
                **dict(zip(auprc_keys, auprc)),
            }
        )

    evaluator.finalize(class_freqs)
    evaluator.save_figs(exp_name)

# Build the results table once rather than growing it with pd.concat per fold
# (quadratic, and concatenating onto an empty frame yields object dtypes).
results = pd.DataFrame(result_rows, columns=RESULT_COLUMNS)

In [None]:
# Append this run's per-fold metrics (pipe-separated) to the shared results
# file. Column names are written only when the file is first created, so an
# existing file keeps its headerless layout and later appends stay aligned.
RESULTS_CSV = "outputs/experiments_by_fold.csv"
os.makedirs(os.path.dirname(RESULTS_CSV), exist_ok=True)
results.to_csv(
    RESULTS_CSV,
    sep="|",
    mode="a",
    header=not os.path.exists(RESULTS_CSV),
)