In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# Load variables from a .env file into the process environment.
load_dotenv()

# Move the working directory up one level exactly once per kernel session.
# Without the guard, re-running this cell would keep walking up the directory
# tree and silently break every relative path used later in the notebook.
if not globals().get("_CWD_MOVED_TO_PARENT", False):
    os.chdir("..")
    _CWD_MOVED_TO_PARENT = True

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

# Fail fast with a readable message: later cells pass these values straight
# into os.path.join / string concatenation, where None only surfaces as an
# opaque TypeError.
_missing_env = [
    name
    for name, value in (("DATA_DIR", DATA_DIR), ("OUTPUT_DIR", OUTPUT_DIR))
    if value is None
]
if _missing_env:
    raise EnvironmentError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + ". Set them in the shell or in a .env file."
    )

In [None]:
import copy
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from sklearn.metrics import average_precision_score, roc_auc_score
from torch import nn
from torch.optim import AdamW

from data_models.datasets import (
    EnsembleDataset,
    SlideClassificationDataset,
    SlideEncodingDataset,
)
from data_models.Label import Label
from models.ensemble import EnsembleClassifier
from utils.eval import Evaluator
from utils.load_data import load_data
from utils.split import train_val_split_labels, train_val_split_slides

In [None]:
# locations of the label table and the fold assignment file
label_path = os.path.join(DATA_DIR, "labels/labels.csv")
fold_path = os.path.join(DATA_DIR, "folds.json")

# load the table and index it by specimen id
df = load_data(label_path=label_path, fold_path=fold_path).set_index(
    "specimen_id"
)

# specimen id -> array holding that specimen's values in the Label columns
onehot_cols = df[Label._member_names_]
labels_onehot = {
    spec: np.array(list(row))
    for spec, row in zip(
        onehot_cols.index.tolist(),
        onehot_cols.itertuples(index=False, name=None),
    )
}

# specimen id -> integer value of the "label" column
labels_dict = {spec: int(lbl) for spec, lbl in df["label"].items()}

In [None]:
# Group the specimen ids (the index of df) by their "fold" value and keep one
# plain list of ids per fold.
# NOTE(review): later code addresses folds by list position, which presumably
# matches the fold ids in folds.json - confirm.
specimens_by_fold = [
    list(members) for members in df.groupby("fold").groups.values()
]

In [None]:
# Fraction of specimens that are positive for each class.
#
# The previous implementation used value_counts(normalize=True).iloc[1].
# value_counts orders its result by descending frequency, so position 1 is the
# share of the *second most common* value, not necessarily of the positive
# value: any class present in more than half of the specimens would report its
# negative share instead, and a column holding a single value would raise
# IndexError. Comparing against 1 and averaging gives the positive share
# directly.
# NOTE(review): assumes the Label columns are 0/1 (or boolean) indicators -
# confirm against load_data.
class_freqs = {
    label: (df[label] == 1).mean()
    for label in Label._member_names_
}

In [None]:
# Restrict CUDA to the device with index 1; this must happen before the first
# CUDA call in the process to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

In [None]:
def get_datasets(
    val_fold: int,
    specimens_by_fold: List[List[str]],
    slides_by_specimen: Dict[str, List[str]],
    labels_by_specimen: Dict[str, int],
    tile_embed_paths: List[Dict[str, str]],
    slide_embed_paths: Optional[List[str]] = None,
) -> Tuple[EnsembleDataset, EnsembleDataset]:
    """
    Builds the train and validation ensemble datasets for one fold.

    Parameters
    ----------
    val_fold : int
        Fold identifier forwarded to the split helpers to select the
        validation fold

    specimens_by_fold : List[List[str]]
        The specimen ids belonging to each fold

    slides_by_specimen : Dict[str, List[str]]
        The slide ids belonging to each specimen

    labels_by_specimen : Dict[str, int]
        The integer label of each specimen

    tile_embed_paths : List[Dict[str, str]]
        A list of tile embedding paths for different models. Each item in
        the list must be a dict with slide ids as keys and paths to tile
        embed pkl files as values

    slide_embed_paths : List[str], optional
        A list of slide embedding paths for different models. Each item in
        the list must be a path to a slide embed pkl file. If None, no slide
        classification datasets are created

    Returns
    -------
    Tuple[EnsembleDataset, EnsembleDataset]
        Train, val datasets
    """
    # get the train and val splits
    train, val = train_val_split_slides(
        val_fold=val_fold,
        specimens_by_fold=specimens_by_fold,
        slides_by_specimen=slides_by_specimen,
    )
    train_labels, val_labels = train_val_split_labels(
        val_fold=val_fold,
        labels_by_specimen=labels_by_specimen,
        specimens_by_fold=specimens_by_fold,
    )

    def _build_split(slide_ids, split_labels) -> EnsembleDataset:
        """Assemble the ensemble dataset for one split (train or val)."""
        # Index the path dicts explicitly instead of using itemgetter:
        # itemgetter(*ids)(d) returns a bare value (not a tuple) when there is
        # exactly one id, so list(...) would split that single path string
        # into characters, and it raises TypeError when there are no ids.
        encoder_datasets = [
            SlideEncodingDataset(
                [paths[slide_id] for slide_id in slide_ids], split_labels
            )
            for paths in tile_embed_paths
        ]
        if slide_embed_paths is not None:
            classifier_datasets = [
                SlideClassificationDataset(path, slide_ids, split_labels)
                for path in slide_embed_paths
            ]
        else:
            classifier_datasets = []
        return EnsembleDataset(encoder_datasets, classifier_datasets)

    # construct the ensemble sets
    train_set = _build_split(train, train_labels)
    val_set = _build_split(val, val_labels)
    return train_set, val_set

In [None]:
def prep_model_input(
    sample: Tuple[List[Dict[str, Any]], List[Dict[str, Any]]],
    device: torch.device,
) -> Tuple[
    List[torch.Tensor],
    List[torch.Tensor],
    List[torch.Tensor],
    torch.Tensor,
    str,
]:
    """
    Converts one dataset sample into batched model inputs on `device`.

    Parameters
    ----------
    sample : Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]
        A (tile entries, slide entries) pair. Tile entries are read through
        the keys "tile_embeds", "pos", "label" and "id"; slide entries
        through "slide_embed", "label" and "id"

    device : torch.device
        The device to send the tensors to

    Returns
    -------
    Tuple
        Tile embeddings, tile coordinates and slide embeddings (each a list
        of tensors with a leading batch dimension of 1), the label as a
        one-element tensor, and the id of the last entry that was read
    """
    tile_entries, slide_entries = sample

    def _as_batch(tensor: torch.Tensor) -> torch.Tensor:
        # add a leading batch dim and move to the target device
        return tensor.unsqueeze(0).to(device)

    tile_embeds, coords, slide_embeds = [], [], []
    label, slide_id = None, None

    for entry in tile_entries:
        tile_embeds.append(_as_batch(entry["tile_embeds"]))
        coords.append(_as_batch(entry["pos"]))
        label, slide_id = entry["label"], entry["id"]

    # slide entries are read last, so their label/id win when both exist
    for entry in slide_entries:
        slide_embeds.append(_as_batch(entry["slide_embed"]))
        label, slide_id = entry["label"], entry["id"]

    label = torch.tensor([label]).to(device)

    return tile_embeds, coords, slide_embeds, label, slide_id


def train_epoch(
    model: EnsembleClassifier,
    dataset: EnsembleDataset,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    grad_accum_steps: int,
    device: torch.device,
) -> float:
    """
    Trains an epoch.

    Samples are visited one at a time in a random order; gradients are
    accumulated over `grad_accum_steps` samples before each optimizer step.

    Parameters
    ----------
    model : nn.Module
        The model to train

    dataset : Dataset
        The dataset for the training data

    optimizer : torch.optim.Optimizer
        The optimizer

    loss_fn : nn.Module
        The loss function

    grad_accum_steps : int
        The number of batches/samples to accumulate gradients for before
        updating the model; must be at least 1

    device : torch.device
        The device to send the model and data to

    Returns
    -------
    float
        The average training loss for the epoch

    Raises
    ------
    ValueError
        If `grad_accum_steps` is smaller than 1 or `dataset` is empty
    """
    if grad_accum_steps < 1:
        raise ValueError(
            f"grad_accum_steps must be >= 1, got {grad_accum_steps}"
        )
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("train_epoch received an empty dataset")

    agg_loss = 0.0
    model = model.to(device) if device else model
    model.train()

    # start from clean gradients so nothing left over from earlier backward
    # passes leaks into the first accumulation window
    optimizer.zero_grad()

    order = np.random.permutation(n_samples)
    for n, i in enumerate(order):
        tile_embeds, coords, slide_embeds, label, _ = prep_model_input(
            dataset[i], device
        )

        # call the module itself (not .forward) so registered hooks run
        logits = model(tile_embeds, coords, slide_embeds)
        loss = loss_fn(logits, label)
        # NOTE: the loss is not divided by grad_accum_steps, so gradients
        # are summed (not averaged) over each accumulation window
        loss.backward()
        agg_loss += loss.item()

        # step once the window is full, and once more at the very end so a
        # trailing partial window is not dropped
        seen = n + 1
        if seen % grad_accum_steps == 0 or seen == n_samples:
            optimizer.step()
            optimizer.zero_grad()

    return agg_loss / n_samples


def val_epoch(
    model: EnsembleClassifier,
    dataset: EnsembleDataset,
    device: torch.device,
    loss_fn: Optional[nn.Module] = None,
) -> Tuple[float, torch.Tensor, torch.Tensor, List[str]]:
    """
    Validates an epoch.

    Parameters
    ----------
    model : nn.Module
        The model to validate

    dataset : Dataset
        The dataset for the validation data

    device : torch.device
        The device to send the model and data to

    loss_fn : nn.Module, optional
        The loss function; if not provided no loss will be calculated

    Returns
    -------
    float
        The average validation loss for the epoch; if loss_fn is not
        provided this will be 0

    torch.Tensor
        The labels for the validation data (on the CPU)

    torch.Tensor
        The model outputs for the validation data (softmaxed logits, on the
        CPU)

    List[str]
        The IDs for the validation data

    Raises
    ------
    ValueError
        If `dataset` yields no samples
    """
    agg_loss = 0.0
    outputs = []
    labels = []
    ids = []

    model.to(device)
    model.eval()
    with torch.inference_mode():
        for sample in dataset:
            tile_embeds, coords, slide_embeds, label, slide_id = (
                prep_model_input(sample, device)
            )

            logits = model(tile_embeds, coords, slide_embeds)
            if loss_fn is not None:
                loss = loss_fn(logits, label)
                agg_loss += loss.item()

            outputs.append(torch.softmax(logits.detach().cpu(), dim=-1))
            # keep labels on the same device as the outputs so callers can
            # combine the two without a device mismatch
            labels.append(label.detach().cpu())
            ids.append(slide_id)

    if not outputs:
        raise ValueError("val_epoch received an empty dataset")

    return agg_loss / len(outputs), torch.cat(labels), torch.cat(outputs), ids

In [None]:
# get the tile embedding paths for each tile encoder model
def get_tile_embed_paths(
    tile_embed_dir: str,
    specimen_ids: Optional[List[str]] = None,
) -> Dict[str, str]:
    """
    Maps slide ids to the tile embedding files found in a directory.

    Parameters
    ----------
    tile_embed_dir : str
        Directory containing one ".pkl" tile embedding file per slide

    specimen_ids : collection of str, optional
        Specimen ids to keep. A file is kept only if the first 6 characters
        of its name are one of these ids. Defaults to the index of the
        notebook-level `df`

    Returns
    -------
    Dict[str, str]
        Slide id (the file name without its ".pkl" suffix) mapped to the
        full path of the file, in sorted file-name order
    """
    # build the lookup set once instead of once per file name
    valid_ids = set(df.index if specimen_ids is None else specimen_ids)

    tile_embed_paths = {
        fname[:-4]: os.path.join(tile_embed_dir, fname)
        for fname in sorted(os.listdir(tile_embed_dir))
        if fname.endswith(".pkl") and fname[:6] in valid_ids
    }

    # keys are slide ids, values are the paths
    return tile_embed_paths


# models in the ensemble, mapped to the kind of embedding each provides
included_models = {"gigapath": "tile", "prism": "slide"}

# experiment tag: "ensemble_" followed by the first letter of every model
foundation_model = "ensemble_" + "".join(
    name[0] for name in included_models
)

# get the paths for all included tile embeddings
tile_encoders = [
    name for name, kind in included_models.items() if kind == "tile"
]
tile_embed_path_base = OUTPUT_DIR + "/{model}/tile_embeddings_sorted"
tile_embed_paths = {
    name: get_tile_embed_paths(tile_embed_path_base.format(model=name))
    for name in tile_encoders
}

# get the path for the prism slide embeddings if included
slide_embeds_path = (
    os.path.join(
        OUTPUT_DIR, "prism/slide_embeddings/prism_slide_embeds_perceiver.pkl"
    )
    if "prism" in included_models
    else None
)

# map specimens to slides; the first 6 characters of a slide name are taken
# as its specimen id, and slides of unknown specimens are skipped
slides_by_specimen = {spec: [] for spec in df.index}
for slide_name in tile_embed_paths["gigapath"]:
    spec = slide_name[:6]
    if spec in slides_by_specimen:
        slides_by_specimen[spec].append(slide_name)

In [None]:
# training hyper-parameters
EPOCHS = 30
BATCH_SIZE = 16  # passed to train_epoch as the gradient accumulation window
NUM_LABELS = 4
PATIENCE = 10  # epochs without val-loss improvement before stopping a fold

# per-class metric column names and the (initially empty) results table
auroc_keys = [f"{cls}_auroc" for cls in ("benign", "bowens", "bcc", "scc")]
auprc_keys = [f"{cls}_auprc" for cls in ("benign", "bowens", "bcc", "scc")]
results = pd.DataFrame(
    columns=[
        "foundation_model",
        "aggregator",
        "classifier",
        "fold",
        *auroc_keys,
        *auprc_keys,
    ]
)

# set experiment-specific variables
gated, heads = False, 1
aggregator, save_directory = "abmil", "ungated"
# checkpoint file name template, filled in once per fold
model_name = "chkpts/{foundation_model}-{aggregator}-{heads}_heads-fold-{i}.pt"
# name under which the evaluation figures are saved
exp_name = f"{foundation_model}/abmil/{save_directory}/{foundation_model}-3_hidden-{heads}_heads"

evaluator = Evaluator(Label)

# for each fold, train and validate while saving the best models
# then evaluate the final outputs by generating curves using Evaluator
n_folds = len(specimens_by_fold)

# one metrics record per fold; turned into a DataFrame once, after the loop,
# instead of growing `results` with pd.concat on every iteration
fold_records = []

for i in range(n_folds):
    print(f"--------------------FOLD    {i + 1}--------------------")
    # build the train / val ensemble datasets for this fold
    train_set, val_set = get_datasets(
        i,
        specimens_by_fold,
        slides_by_specimen,
        labels_dict,
        list(tile_embed_paths.values()),
        [slide_embeds_path] if slide_embeds_path is not None else None,
    )

    # NOTE(review): 1536 / 1280 are hard-coded embedding widths, presumably
    # those of the gigapath tile and prism slide embeddings - confirm.
    model = EnsembleClassifier(
        tile_encoder_dim=[1536],
        slide_encoder_dim=1280,
        out_features=NUM_LABELS,
    ).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optim = AdamW(model.parameters(), lr=1e-5)

    best_loss = float("inf")
    best_model_weights = None
    best_model_data = {"ids": None, "labels": None, "probs": None}
    patience = PATIENCE

    for epoch in range(EPOCHS):
        train_loss = train_epoch(
            model,
            train_set,
            optim,
            loss_fn,
            BATCH_SIZE,
            device,
        )
        val_loss, labels, probs, ids = val_epoch(
            model,
            val_set,
            device,
            loss_fn,
        )

        if val_loss < best_loss:
            # new best epoch: snapshot the weights and the val predictions
            best_loss = val_loss
            best_model_weights = copy.deepcopy(model.state_dict())
            best_model_data["ids"] = ids
            best_model_data["labels"] = labels
            best_model_data["probs"] = probs
            patience = PATIENCE
        else:
            # early stopping once PATIENCE epochs pass without improvement
            patience -= 1
            if patience <= 0:
                break

        if (epoch + 1) % 2 == 0:
            spaces = " " * (4 - len(str(epoch + 1)))
            print(
                f"--------------------EPOCH{spaces}{epoch + 1}--------------------"
            )
            print(f"train loss: {train_loss:0.6f}")
            print(f"val loss:   {val_loss:0.6f}")
            print()

    # val_loss never beat +inf (e.g. NaN losses, or EPOCHS == 0): there is
    # nothing meaningful to save or evaluate, so stop with a clear message
    # instead of pickling None and failing later inside the evaluator
    if best_model_weights is None:
        raise RuntimeError(
            f"Fold {i}: validation loss never improved; "
            "no checkpoint or predictions to evaluate"
        )

    # save the best model, creating the checkpoint directory if needed
    chkpt_path = os.path.join(
        OUTPUT_DIR,
        model_name.format(
            foundation_model=foundation_model,
            aggregator=aggregator,
            heads=heads,
            i=i,
        ),
    )
    os.makedirs(os.path.dirname(chkpt_path), exist_ok=True)
    torch.save(best_model_weights, chkpt_path)

    # extract relevant data from val results for best model
    ids, probs = Evaluator.get_spec_level_probs(
        best_model_data["ids"], best_model_data["probs"]
    )
    # explicit indexing keeps the array 2-D even for a single specimen
    # (itemgetter returns a bare item, not a tuple, when given one key)
    labels_onehot_val = np.array([labels_onehot[spec] for spec in ids])

    evaluator.fold(probs, labels_onehot_val, i, n_folds)
    auroc = roc_auc_score(
        labels_onehot_val, probs, average=None, multi_class="ovr"
    )
    auprc = average_precision_score(labels_onehot_val, probs, average=None)

    # per-class metrics, keyed by the column names of `results`
    fold_records.append(
        {
            "foundation_model": foundation_model,
            "aggregator": aggregator,
            "classifier": "MLP",
            "fold": i,
            **dict(zip(auroc_keys, auroc)),
            **dict(zip(auprc_keys, auprc)),
        }
    )

# append this run's rows to the results table in a single step
fold_results = pd.DataFrame(fold_records, columns=results.columns)
results = (
    fold_results
    if results.empty
    else pd.concat([results, fold_results], ignore_index=True)
)

evaluator.finalize(class_freqs)
evaluator.save_figs(exp_name)

In [None]:
# Append this run's per-fold metrics to the shared experiment log: pipe
# separated, no header row. The path is relative to the current working
# directory, and re-running this cell appends the same rows again.
results.to_csv(
    path_or_buf="outputs/experiments_by_fold.csv",
    sep="|",
    mode="a",
    header=False,
)