In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# Load variables from a .env file (if one is found) into the process
# environment before reading them below.
load_dotenv()

# Work from the parent of the directory the kernel started in (presumably the
# project root -- confirm). The starting directory is remembered so that
# re-running this cell does not climb one level further each time.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, ".."))

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

# Fail fast with a readable message: both values are joined into paths by
# later cells, where a missing value would only surface as a TypeError.
_missing_env = [
    name
    for name, value in (("DATA_DIR", DATA_DIR), ("OUTPUT_DIR", OUTPUT_DIR))
    if not value
]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_env)
    )

In [None]:
# load the labels data with folds
import numpy as np

from utils.load_data import load_data
from data_models.Label import Label

# Locations of the label table and the fold assignment file under DATA_DIR.
label_path = os.path.join(DATA_DIR, "labels/labels.csv")
fold_path = os.path.join(DATA_DIR, "folds.json")

# One row per specimen, indexed by specimen ID.
df = load_data(label_path=label_path, fold_path=fold_path).set_index(
    "specimen_id"
)

# specimen ID -> per-class label vector taken from the Label member columns
onehot_split = df[Label._member_names_].to_dict(orient="split", index=True)
labels_onehot = {
    spec_id: np.array(class_row)
    for spec_id, class_row in zip(onehot_split["index"], onehot_split["data"])
}

# specimen ID -> integer value of the "label" column
labels_dict = {spec_id: int(row["label"]) for spec_id, row in df.iterrows()}

In [None]:
# get the tile embedding paths for each tile encoder model
def get_tile_embed_paths(tile_embed_dir, valid_specimens=None):
    """
    Collects the tile embedding files of one tile encoder.

    Parameters
    ----------
    tile_embed_dir : str
        Directory containing one ``.pkl`` file per slide

    valid_specimens : iterable of str, optional
        Specimen IDs to keep. If not provided, the index of the global
        labels dataframe ``df`` is used

    Returns
    -------
    dict
        Maps the file name without its ``.pkl`` extension to the file path,
        in sorted file name order. Only files whose first six characters
        (treated as the specimen ID) are in ``valid_specimens`` are kept
    """
    if valid_specimens is None:
        valid_specimens = df.index
    # build the lookup set once instead of once per file name
    valid_specimens = set(valid_specimens)

    # return as dict to act as an ordered set
    return {
        fname[:-4]: os.path.join(tile_embed_dir, fname)
        for fname in sorted(os.listdir(tile_embed_dir))
        if fname.endswith(".pkl") and fname[:6] in valid_specimens
    }


# tile encoder models whose tile embeddings are stored under OUTPUT_DIR
tile_encoders = ["uni", "gigapath"]
tile_embed_path_base = OUTPUT_DIR + "/{model}/tile_embeddings_sorted"

# tile encoder name -> {slide name -> tile embedding file path}
tile_embed_paths = {}
for encoder_name in tile_encoders:
    encoder_dir = tile_embed_path_base.format(model=encoder_name)
    tile_embed_paths[encoder_name] = get_tile_embed_paths(encoder_dir)

# single pickle file holding the slide-level embeddings used for "prism"
slide_embeds_path = os.path.join(
    OUTPUT_DIR, "prism/slide_embeddings/prism_slide_embeds_perceiver.pkl"
)

In [None]:
# get list of specimens within each fold
# one list of specimen IDs (dataframe index values) per distinct "fold" value,
# in the order the groupby reports the folds
fold_groups = df.groupby("fold").groups
specimens_by_fold = [list(fold_groups[fold]) for fold in fold_groups]

In [None]:
# map specimens to slides
# every labelled specimen starts with an empty slide list
slides_by_specimen = {spec: [] for spec in df.index}
for slide_name in tile_embed_paths["uni"]:
    # the first six characters of a slide name are its specimen ID
    spec_slides = slides_by_specimen.get(slide_name[:6])
    if spec_slides is None:
        # slide belongs to a specimen without a label row; leave it out
        continue
    spec_slides.append(slide_name)

In [None]:
from data_models.Label import Label

# Fraction of specimens that are positive for each class.
# Assumes each Label member column is a 0/1 (or boolean) indicator -- confirm
# against load_data. The positive value is compared explicitly because
# value_counts() orders by frequency, so a positional lookup would return the
# negative-class share whenever positives are the majority and would fail for
# a column holding a single value.
class_freqs = {
    label: df[label].dropna().eq(1).mean()
    for label in Label._member_names_
}

In [None]:
import torch

# Restrict the GPUs visible to this process to index "2".
# NOTE(review): the index is hard-coded and overwrites any value of
# CUDA_VISIBLE_DEVICES already set in the environment.
# NOTE(review): this presumably has to run before CUDA is first initialised
# to take effect, so keep it ahead of the availability check -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

In [None]:
from typing import Dict, List, Tuple
from operator import itemgetter

from data_models.datasets import (
    SlideEncodingDataset,
    EnsembleDataset,
    SlideClassificationDataset,
)
from utils.split import train_val_split_slides, train_val_split_labels


def get_datasets(
    val_fold: int,
    specimens_by_fold: List[List[str]],
    slides_by_specimen: Dict[str, List[str]],
    labels_by_specimen: Dict[str, int],
    tile_embed_paths: Dict[str, Dict[str, str]],
    slide_embeds_path: str,
) -> Tuple[EnsembleDataset, EnsembleDataset]:
    """
    Builds the training and validation ensemble datasets for one fold.

    Parameters
    ----------
    val_fold : int
        Index of the fold held out for validation

    specimens_by_fold : List[List[str]]
        The specimen IDs belonging to each fold

    slides_by_specimen : Dict[str, List[str]]
        The slide names belonging to each specimen

    labels_by_specimen : Dict[str, int]
        The integer label of each specimen

    tile_embed_paths : Dict[str, Dict[str, str]]
        For each tile encoder, the tile embedding file path of each slide

    slide_embeds_path : str
        Path to the file holding the slide-level embeddings

    Returns
    -------
    EnsembleDataset
        The training set, with one slide encoding dataset per tile encoder
        plus the slide classification dataset under the key "prism"

    EnsembleDataset
        The validation set, structured like the training set
    """
    # get the train and val splits; the returned slide collections are used
    # below as keys into tile_embed_paths[model]
    train, val = train_val_split_slides(
        val_fold=val_fold,
        specimens_by_fold=specimens_by_fold,
        slides_by_specimen=slides_by_specimen,
    )
    train_labels, val_labels = train_val_split_labels(
        val_fold=val_fold,
        labels_by_specimen=labels_by_specimen,
        specimens_by_fold=specimens_by_fold,
    )

    def _paths_for(model: str, slides) -> List[str]:
        # Explicit lookup always yields a list. itemgetter(*slides) returns a
        # bare string for a single slide and raises for an empty split.
        model_paths = tile_embed_paths[model]
        return [model_paths[slide] for slide in slides]

    # create the datasets
    train_slide_encoder_datasets = {
        model: SlideEncodingDataset(_paths_for(model, train), train_labels)
        for model in tile_embed_paths
    }
    val_slide_encoder_datasets = {
        model: SlideEncodingDataset(_paths_for(model, val), val_labels)
        for model in tile_embed_paths
    }
    train_slide_classifier_dataset = SlideClassificationDataset(
        slide_embeds_path, train, train_labels
    )
    val_slide_classifier_dataset = SlideClassificationDataset(
        slide_embeds_path, val, val_labels
    )

    # create the ensemble datasets
    train_set = EnsembleDataset(
        train_slide_encoder_datasets
        | {"prism": train_slide_classifier_dataset}
    )
    val_set = EnsembleDataset(
        val_slide_encoder_datasets | {"prism": val_slide_classifier_dataset}
    )

    return train_set, val_set

In [None]:
from typing import Optional

import torch.nn as nn
from torch.utils.data import Dataset


def train_epoch(
    model: nn.Module,
    dataset: Dataset,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    grad_accum_steps: int,
    device: Optional[torch.device] = None,
) -> float:
    """
    Trains an epoch.

    Samples are visited one at a time in a random order; gradients are
    accumulated over ``grad_accum_steps`` samples between optimizer updates.

    Parameters
    ----------
    model : nn.Module
        The model to train

    dataset : Dataset
        The dataset for the training data

    optimizer : torch.optim.Optimizer
        The optimizer

    loss_fn : nn.Module
        The loss function

    grad_accum_steps : int
        The number of batches/samples to accumulate gradients for before
        updating the model; must be at least 1

    device : Optional[torch.device]
        The device to send the model and data to. If not provided, the model
        and data will not be moved to a device

    Returns
    -------
    float
        The average training loss for the epoch

    Raises
    ------
    ValueError
        If ``grad_accum_steps`` is smaller than 1 or the dataset is empty
    """
    if grad_accum_steps < 1:
        raise ValueError(
            f"grad_accum_steps must be at least 1, got {grad_accum_steps}"
        )
    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("cannot train on an empty dataset")

    if device is not None:
        model = model.to(device)
    model.train()
    # start from clean gradients in case a previous run was interrupted
    optimizer.zero_grad()

    agg_loss = 0.0
    order = np.random.permutation(num_samples)
    for n, i in enumerate(order):
        sample = dataset[int(i)]
        # add a leading batch dimension of size 1 to every input
        x = {
            "uni": sample["uni"]["tile_embeds"].unsqueeze(0),
            "gigapath": sample["gigapath"]["tile_embeds"].unsqueeze(0),
            "prism": sample["prism"]["slide_embed"].unsqueeze(0),
        }
        coords = {
            "uni": sample["uni"]["pos"].unsqueeze(0),
            "gigapath": sample["gigapath"]["pos"].unsqueeze(0),
        }
        label = torch.tensor([sample["uni"]["label"]])

        if device is not None:
            x = {key: value.to(device) for key, value in x.items()}
            coords = {key: value.to(device) for key, value in coords.items()}

        logits = model(x, coords)
        # NOTE: the loss is not divided by grad_accum_steps, so the
        # accumulated gradient is the sum over the accumulated samples
        loss = loss_fn(logits, label.to(logits.device))
        loss.backward()
        agg_loss += loss.item()

        # update once grad_accum_steps samples are accumulated, and flush
        # whatever remains after the final sample
        is_last_sample = n + 1 == num_samples
        if (n + 1) % grad_accum_steps == 0 or is_last_sample:
            optimizer.step()
            optimizer.zero_grad()

    return agg_loss / num_samples


def val_epoch(
    model: nn.Module,
    dataset: Dataset,
    device: torch.device,
    loss_fn: Optional[nn.Module] = None,
) -> Tuple[float, torch.Tensor, torch.Tensor, List[str]]:
    """
    Validates an epoch.

    Parameters
    ----------
    model : nn.Module
        The model to validate

    dataset : Dataset
        The dataset for the validation data

    device : torch.device
        The device to send the model and data to

    loss_fn : nn.Module, optional
        The loss function; if not provided no loss will be calculated

    Returns
    -------
    float
        The average validation loss for the epoch; if loss_fn is not
        provided this will be 0

    torch.Tensor
        The labels for the validation data

    torch.Tensor
        The model outputs for the validation data (softmaxed logits)

    List[str]
        The IDs for the validation data

    Raises
    ------
    ValueError
        If the dataset is empty
    """
    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("cannot validate on an empty dataset")

    agg_loss = 0.0
    outputs = []
    labels = []
    ids = []

    model.to(device)
    model.eval()
    with torch.no_grad():
        # index by position, as train_epoch does, instead of relying on the
        # dataset being iterable
        for i in range(num_samples):
            sample = dataset[i]
            # add a leading batch dimension of size 1 to every input
            x = {
                "uni": sample["uni"]["tile_embeds"].unsqueeze(0),
                "gigapath": sample["gigapath"]["tile_embeds"].unsqueeze(0),
                "prism": sample["prism"]["slide_embed"].unsqueeze(0),
            }
            coords = {
                "uni": sample["uni"]["pos"].unsqueeze(0),
                "gigapath": sample["gigapath"]["pos"].unsqueeze(0),
            }
            label = torch.tensor([sample["uni"]["label"]])

            x = {key: value.to(device) for key, value in x.items()}
            coords = {key: value.to(device) for key, value in coords.items()}

            logits = model(x, coords)
            if loss_fn is not None:
                # follow the logits' device, matching train_epoch
                loss = loss_fn(logits, label.to(logits.device))
                agg_loss += loss.item()

            # results are collected on the CPU
            outputs.append(torch.softmax(logits.detach().cpu(), dim=-1))
            labels.append(label)
            ids.append(sample["uni"]["id"])

    outputs = torch.cat(outputs)
    labels = torch.cat(labels)
    return agg_loss / num_samples, labels, outputs, ids

In [None]:
import copy
from operator import itemgetter

import pandas as pd
from sklearn.metrics import average_precision_score, roc_auc_score
from torch import nn
from torch.optim import AdamW

from models.ensemble import EnsembleClassifier
from utils.eval import Evaluator


EPOCHS = 30
# passed to train_epoch as the number of samples to accumulate gradients over
BATCH_SIZE = 16
NUM_LABELS = 4
# number of consecutive epochs without a better validation loss before stopping
PATIENCE = 10

# per-class metric column names
# NOTE(review): the class order here must match the column order of the
# one-hot labels (Label._member_names_) -- confirm.
auroc_keys = [k + "_auroc" for k in ["benign", "bowens", "bcc", "scc"]]
auprc_keys = [k + "_auprc" for k in ["benign", "bowens", "bcc", "scc"]]
result_columns = (
    ["foundation_model", "aggregator", "classifier", "fold"]
    + auroc_keys
    + auprc_keys
)
# one record per fold; the frame is rebuilt from these records instead of
# being grown with pd.concat on an initially empty frame
result_rows = []
results = pd.DataFrame(columns=result_columns)

# set experiment-specific variables
foundation_model = "ensemble"
gated = False
heads = 1
aggregator = "abmil"
save_directory = "ungated"
model_name = "chkpts/{foundation_model}-{aggregator}-{heads}_heads-fold-{i}.pt"
exp_name = f"{foundation_model}/abmil/{save_directory}/{foundation_model}-3_hidden-{heads}_heads"

evaluator = Evaluator(Label)

# for each fold, train and validate while saving the best models
# then evaluate the final outputs by generating curves using Evaluator
for i in range(len(specimens_by_fold)):
    print(f"--------------------FOLD    {i + 1}--------------------")
    # get the train and val datasets with fold i held out for validation
    train_set, val_set = get_datasets(
        i,
        specimens_by_fold,
        slides_by_specimen,
        labels_dict,
        tile_embed_paths,
        slide_embeds_path,
    )

    # fresh model, loss and optimizer for every fold
    model = EnsembleClassifier().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optim = AdamW(model.parameters(), lr=1e-5)

    best_loss = float("inf")
    best_model_weights = None
    best_model_data = {"ids": None, "labels": None, "probs": None}
    patience = PATIENCE

    for epoch in range(EPOCHS):
        train_loss = train_epoch(
            model,
            train_set,
            optim,
            loss_fn,
            BATCH_SIZE,
            device,
        )
        val_loss, labels, probs, ids = val_epoch(
            model,
            val_set,
            device,
            loss_fn,
        )

        if val_loss < best_loss:
            # keep a snapshot of the weights and validation outputs of the
            # best epoch so far and reset the early-stopping counter
            best_loss = val_loss
            best_model_weights = copy.deepcopy(model.state_dict())
            best_model_data["ids"] = ids
            best_model_data["labels"] = labels
            best_model_data["probs"] = probs
            patience = PATIENCE
        else:
            patience -= 1
            if patience == 0:
                break

        if (epoch + 1) % 2 == 0:
            spaces = " " * (4 - len(str(epoch + 1)))
            print(
                f"--------------------EPOCH{spaces}{epoch + 1}--------------------"
            )
            print(f"train loss: {train_loss:0.6f}")
            print(f"val loss:   {val_loss:0.6f}")
            print()

    # a validation loss that never drops below inf (e.g. NaN) leaves nothing
    # to save or evaluate; stop with a clear message
    if best_model_weights is None:
        raise RuntimeError(
            f"fold {i}: no epoch produced a finite, improving validation loss"
        )

    # save the best model, creating the checkpoint directory if needed
    chkpt_path = os.path.join(
        OUTPUT_DIR,
        model_name.format(
            foundation_model=foundation_model,
            aggregator=aggregator,
            heads=heads,
            i=i,
        ),
    )
    os.makedirs(os.path.dirname(chkpt_path), exist_ok=True)
    torch.save(best_model_weights, chkpt_path)

    # extract relevant data from val results for best model
    ids, probs = Evaluator.get_spec_level_probs(
        best_model_data["ids"], best_model_data["probs"]
    )
    # explicit lookup keeps one row per specimen even when only a single
    # specimen is returned (itemgetter would return a bare array)
    labels_onehot_val = np.array([labels_onehot[spec_id] for spec_id in ids])

    evaluator.fold(probs, labels_onehot_val, i, len(specimens_by_fold))
    auroc = roc_auc_score(
        labels_onehot_val, probs, average=None, multi_class="ovr"
    )
    auroc_dict = {auroc_keys[k]: v for k, v in enumerate(auroc)}

    auprc = average_precision_score(labels_onehot_val, probs, average=None)
    auprc_dict = {auprc_keys[k]: v for k, v in enumerate(auprc)}

    model_details = {}
    model_details["foundation_model"] = foundation_model
    model_details["aggregator"] = aggregator
    model_details["classifier"] = "MLP"
    model_details["fold"] = i
    model_details = model_details | auroc_dict | auprc_dict
    result_rows.append(model_details)
    # rebuilt after every fold so partial results survive an interrupted run
    results = pd.DataFrame(result_rows, columns=result_columns)

evaluator.finalize(class_freqs)
evaluator.save_figs(exp_name)

In [None]:
# Append this run's per-fold metrics to the shared experiments file.
# The file is "|"-separated and rows are written without a header line.
results_csv_path = "outputs/experiments_by_fold.csv"
results.to_csv(results_csv_path, sep="|", mode="a", header=False)