In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and source edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# Load variables from a .env file into the process environment.
load_dotenv()

# Work from the directory two levels above the one the kernel started in;
# later cells use paths relative to it (e.g. "outputs/...").
# The starting directory is remembered on first execution so that re-running
# this cell is idempotent instead of climbing two more levels every time.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, "..", ".."))

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

# Fail here with a readable message rather than later with an opaque
# TypeError from os.path.join(None, ...).
_missing_env = [
    name
    for name, value in (("DATA_DIR", DATA_DIR), ("OUTPUT_DIR", OUTPUT_DIR))
    if value is None
]
if _missing_env:
    raise EnvironmentError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + " (set them in the shell or in a .env file)"
    )

In [None]:
import copy
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader

from data_processing.datasets import (
    SlideClassificationDataset,
    collate_slide_embeds,
)
from data_processing.label import Label
from evaluation.eval import Evaluator
from models import MLP
from models.training.train import train_epoch, val_epoch
from models.training.load import get_loaders
from data_processing.load_data import load_data
from data_processing.split import (
    train_val_split_labels,
    train_val_split_slides,
)

In [None]:
# Locate the three inputs: labels and fold assignments under DATA_DIR, the
# slide-embedding pickle under OUTPUT_DIR.
embedding_path = os.path.join(
    OUTPUT_DIR, "prism/slide_embeddings/prism_slide_embeds_GAP.pkl"
)
fold_path = os.path.join(DATA_DIR, "folds.json")
label_path = os.path.join(DATA_DIR, "labels/labels.csv")

# Load the labels data together with the folds.
df = load_data(
    label_path=label_path,
    embedding_path=embedding_path,
    fold_path=fold_path,
)

In [None]:
# Map each specimen id to the list of index labels (the WSIs) of its rows.
slides_by_specimen = {
    specimen: list(slide_ids)
    for specimen, slide_ids in df.groupby("specimen_id").groups.items()
}

In [None]:
# Collapse the frame to one row per specimen: move the index back into a
# column, discard the slide id, keep the first row of every specimen and
# index the result by specimen_id.
specimen_rows = df.reset_index().drop(columns=["slide_id"])
df = specimen_rows.drop_duplicates(subset=["specimen_id"]).set_index(
    "specimen_id"
)

# specimen_id -> numpy array holding that specimen's values in the Label columns
onehot_split = df[Label._member_names_].to_dict(orient="split", index=True)
labels_onehot = {
    specimen: np.array(values)
    for specimen, values in zip(onehot_split["index"], onehot_split["data"])
}

# specimen_id -> integer value of the "label" column
labels_dict = {
    specimen: int(label) for specimen, label in df["label"].items()
}

In [None]:
# Build one list of specimen ids per fold (ordered as groupby orders the
# fold keys).
fold_assignments = df.reset_index()[["specimen_id", "fold"]].drop_duplicates()
fold_groups = fold_assignments.set_index("specimen_id").groupby("fold").groups
specimens_by_fold = [list(members) for members in fold_groups.values()]

In [None]:
# Fraction of specimens that are positive (value 1) for each label column.
#
# The share of 1s is computed directly. `value_counts()` orders its result by
# frequency, so taking the second entry by position only yields the positive
# share when that class is the minority, and fails outright for a constant
# column. Missing values are dropped first, matching what
# `value_counts(normalize=True)` uses as its denominator.
# NOTE(review): assumes the Label columns are 0/1 indicators — confirm.
class_freqs = {
    label: df[label].dropna().eq(1).mean()
    for label in Label._member_names_
}

In [None]:
# Display the DataFrame as this cell's output for a quick visual check.
df

In [None]:
# Expose only GPU index 1 to this process, then choose CUDA when it is
# available and fall back to the CPU otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

In [None]:
def get_loaders(
    val_fold: int,
    specimens_by_fold: List[List[str]],
    slides_by_specimen: Dict[str, List[str]],
    labels_by_specimen: Dict[str, int],
    slide_embeds_path: str,
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation loaders for one fold.

    NOTE: a later cell runs ``from models.training.load import get_loaders``,
    which rebinds this name. The training loop in this notebook therefore
    calls the imported function (with a different signature), not this one.

    Args:
        val_fold: Fold index forwarded to the split helpers as the
            validation fold.
        specimens_by_fold: One list of specimen ids per fold.
        slides_by_specimen: Specimen id -> list of slide ids.
        labels_by_specimen: Specimen id -> integer label.
        slide_embeds_path: Path handed to ``SlideClassificationDataset`` as
            its first argument.

    Returns:
        ``(train_loader, val_loader)``. Both use ``batch_size=1`` and
        ``collate_slide_embeds``; only the training loader shuffles.
    """
    fold_kwargs = dict(val_fold=val_fold, specimens_by_fold=specimens_by_fold)
    train_slides, val_slides = train_val_split_slides(
        slides_by_specimen=slides_by_specimen, **fold_kwargs
    )
    train_targets, val_targets = train_val_split_labels(
        labels_by_specimen=labels_by_specimen, **fold_kwargs
    )

    def _make_loader(slides, targets, shuffle: bool) -> DataLoader:
        # One sample per batch; batches are assembled by collate_slide_embeds.
        dataset = SlideClassificationDataset(slide_embeds_path, slides, targets)
        return DataLoader(
            dataset,
            batch_size=1,
            shuffle=shuffle,
            collate_fn=collate_slide_embeds,
        )

    train_loader = _make_loader(train_slides, train_targets, True)
    val_loader = _make_loader(val_slides, val_targets, False)
    return train_loader, val_loader

In [None]:
from models.training.load import get_loaders
from functools import partial

In [None]:
# ---- experiment configuration ----
EPOCHS = 30  # maximum number of epochs per fold
BATCH_SIZE = 16  # forwarded to train_epoch (the loaders themselves are built elsewhere)
NUM_LABELS = 4  # output size of the MLP head
PATIENCE = 10  # epochs without validation-loss improvement before stopping

foundation_model = "prism"
aggregator = "PRISM_test"
# Checkpoint path template, relative to OUTPUT_DIR; `i` is the fold index.
model_name = (
    "chkpts/{foundation_model}/{foundation_model}-{aggregator}-fold-{i}.pt"
)
exp_name = "{foundation_model}/testing/{foundation_model}-3_hidden"

evaluator = Evaluator(Label)
auroc_keys = [k + "_auroc" for k in ["benign", "bowens", "bcc", "scc"]]
auprc_keys = [k + "_auprc" for k in ["benign", "bowens", "bcc", "scc"]]
results = pd.DataFrame(
    columns=["foundation_model", "aggregator", "classifier", "fold"]
    + auroc_keys
    + auprc_keys
)

# ---- cross-validation: train one model per fold, holding that fold out ----
for i in range(len(specimens_by_fold)):
    print(f"--------------------FOLD    {i + 1}--------------------")
    # get dataloaders and embed dims for loaded embeddings
    train_loader, val_loader = get_loaders(
        val_fold=i,
        specimens_by_fold=specimens_by_fold,
        slides_by_specimen=slides_by_specimen,
        labels_by_specimen=labels_dict,
        train_dataset_class=partial(
            SlideClassificationDataset, slide_embeds_path=embedding_path
        ),
        val_dataset_class=partial(
            SlideClassificationDataset, slide_embeds_path=embedding_path
        ),
        collate_fn=collate_slide_embeds,
    )

    # Peek at a single batch to read the embedding width. An empty loader is
    # reported explicitly instead of leaving `embed_dim` undefined (or stale
    # from the previous fold).
    try:
        sample = next(iter(train_loader))
    except StopIteration:
        raise RuntimeError(
            f"Training loader for fold {i} yielded no batches"
        ) from None
    embed_dim = sample["slide_embed"].shape[-1]

    model = MLP(embed_dim, [1024, 512, 256], NUM_LABELS).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optim = AdamW(model.parameters(), lr=1e-5)

    # Early-stopping state: best validation loss so far, the weights and
    # validation outputs that produced it, and the remaining patience.
    best_loss = float("inf")
    best_model_weights = None
    best_model_data = {"ids": None, "labels": None, "probs": None}
    patience = PATIENCE

    for epoch in range(EPOCHS):
        train_loss = train_epoch(
            model,
            train_loader,
            optim,
            loss_fn,
            BATCH_SIZE,
            ["slide_embed"],
            "label",
            device,
        )
        val_loss, labels, probs, ids = val_epoch(
            model, val_loader, device, ["slide_embed"], "label", loss_fn
        )

        if val_loss < best_loss:
            # New best: snapshot the weights (deep copy, since state_dict()
            # is a live view of the model) and this epoch's predictions.
            best_loss = val_loss
            best_model_weights = copy.deepcopy(model.state_dict())
            best_model_data["ids"] = ids
            best_model_data["labels"] = labels
            best_model_data["probs"] = probs
            patience = PATIENCE
        else:
            patience -= 1
            if patience == 0:
                break

        # Report progress every second epoch.
        if (epoch + 1) % 2 == 0:
            spaces = " " * (4 - len(str(epoch + 1)))
            print(
                f"--------------------EPOCH{spaces}{epoch + 1}--------------------"
            )
            print(f"train loss: {train_loss:0.6f}")
            print(f"val loss:   {val_loss:0.6f}")
            print()

    # No epoch ever improved on +inf (e.g. every validation loss was NaN):
    # stop here rather than writing an empty checkpoint.
    if best_model_weights is None:
        raise RuntimeError(
            f"Fold {i}: validation loss never improved; no weights to save"
        )

    # save the best model
    checkpoint_path = os.path.join(
        OUTPUT_DIR,
        model_name.format(
            foundation_model=foundation_model, aggregator=aggregator, i=i
        ),
    )
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(best_model_weights, checkpoint_path)

    # Hand the best epoch's validation ids/labels/probabilities to the evaluator.
    evaluator.fold(i, best_model_data, len(specimens_by_fold))

evaluator.finalize(class_freqs)
evaluator.save_figs(exp_name.format(foundation_model=foundation_model))

# Append this run's per-fold results to the shared, pipe-separated CSV
# (no header row is written in append mode).
os.makedirs("outputs", exist_ok=True)
evaluator.results.to_csv(
    "outputs/experiments_by_fold.csv", sep="|", mode="a", header=False
)