In [None]:
import os
from dotenv import load_dotenv

load_dotenv()

# The notebook presumably lives one level below the project root. Resolve that
# root once and reuse it: a bare os.chdir("..") climbs one more directory every
# time this cell is re-run, silently breaking every relative path below.
PROJECT_ROOT = globals().get("PROJECT_ROOT") or os.path.abspath("..")
os.chdir(PROJECT_ROOT)

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

# Fail fast: later cells pass these straight to os.path.join, which would
# otherwise raise an unhelpful TypeError on None.
_missing_env = [
    name
    for name, value in (("DATA_DIR", DATA_DIR), ("OUTPUT_DIR", OUTPUT_DIR))
    if not value
]
if _missing_env:
    raise RuntimeError(
        f"Missing environment variable(s): {', '.join(_missing_env)} "
        "(define them in .env or the shell environment)"
    )

In [None]:
# load the labels data with folds
from utils.load_data import load_data

label_path = os.path.join(DATA_DIR, "labels/labels.csv")
fold_path = os.path.join(DATA_DIR, "folds.json")

# Key every row by specimen id so the lookups below (and .loc in later cells)
# work directly on specimen ids.
df = load_data(label_path=label_path, fold_path=fold_path).set_index("specimen_id")

# One-hot label columns, sliced per validation fold when plotting ROC/PRC.
labels_onehot = df[["bowens", "scc", "bcc", "na"]]

# specimen id -> integer class label
labels_dict = {specimen: int(label) for specimen, label in df["label"].items()}

In [None]:
# get the absolute path for each slide's set of tile embeddings
# Directory holding one pickle of tile embeddings per slide. Overridable via the
# TILE_EMBED_DIR environment variable; defaults to the original location.
tile_embed_dir = os.getenv(
    "TILE_EMBED_DIR", "/opt/gpudata/skin-cancer/outputs/uni/tile_embeddings"
)

# os.listdir returns entries in arbitrary order; sort so slide (and therefore
# dataset) order is reproducible across runs and machines.
fnames = sorted(os.listdir(tile_embed_dir))

# Build the membership set once rather than once per file name.
known_specimens = set(df.index)

# Keep only .pkl files whose first six characters are a specimen id in df
# (assumes file names start with the specimen id).
tile_embed_paths = [
    os.path.join(tile_embed_dir, fname)
    for fname in fnames
    if fname.endswith(".pkl") and fname[:6] in known_specimens
]

In [None]:
# get list of specimens within each fold
# One list of specimen ids per fold, in the order pandas returns the groups
# (sorted by fold key with groupby's default sort=True).
# NOTE(review): downstream code uses the *list position* as the fold index,
# not the fold value itself.
specimens_by_fold = list(map(list, df.groupby("fold").groups.values()))

In [None]:
# map specimens to slides
# Group slide embedding paths by specimen. Every specimen in df starts with an
# empty list, so specimens with no slides are still present as keys.
slides_by_specimen = {specimen_id: [] for specimen_id in df.index}
for path in tile_embed_paths:
    stem = os.path.basename(path)[:-4]  # drop the 4-char ".pkl" extension
    specimen_id = stem[:6]
    if specimen_id in slides_by_specimen:
        slides_by_specimen[specimen_id].append(path)

In [None]:
from data_models.Label import Label

# Specimen-level prevalence of each label: the fraction of (non-missing) rows
# whose one-hot column equals 1. Used later as the PRC chance level.
# value_counts(...).iloc[1] was wrong here: value_counts sorts by frequency, so
# for a label present in >50% of specimens it returned the share of 0s, and it
# raised IndexError when a column held a single value.
spec_freqs = {
    label: float(df[label].dropna().eq(1).mean())
    for label in Label._member_names_
}

In [None]:
# CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA runtime is
# initialised, so set it before importing torch rather than after.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

In [None]:
from torch import nn
from torch.utils.data import DataLoader


def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
) -> float:
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Each sample is a dict with ``"tile_embeds"``, ``"coords"`` and ``"label"``
    entries. Embeddings, coords and labels are moved to the module-level
    ``device`` before the forward pass / loss.

    Args:
        model: Aggregator called as ``model(tile_embeds, coords)``.
        dataloader: Iterable of sample dicts.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: Criterion called as ``loss_fn(logits, labels)``.

    Returns:
        The average of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    agg_loss = 0.0
    n_batches = 0

    model.train()
    for sample in dataloader:
        embeds = sample["tile_embeds"]
        label = sample["label"]
        # Add a leading batch dim when the collate fn returns an unbatched
        # tensor (presumably (n_tiles, embed_dim)).
        if embeds.dim() < 3:
            embeds = embeds.unsqueeze(0)

        optimizer.zero_grad()

        output = model(embeds.to(device), sample["coords"].to(device))
        # Restore the batch dim if the model squeezed it away for one slide.
        if output.dim() < 2:
            output = output.unsqueeze(0)
        loss = loss_fn(output, label.to(device))
        agg_loss += loss.item()
        loss.backward()

        optimizer.step()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_epoch received a dataloader with no batches")
    return agg_loss / n_batches

In [None]:
from typing import List, Tuple


def val_epoch(
    model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module
) -> Tuple[float, torch.Tensor, torch.Tensor, list]:
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Each sample is a dict with ``"tile_embeds"``, ``"coords"``, ``"label"`` and
    ``"id"`` entries. Inputs are moved to the module-level ``device``.

    Args:
        model: Aggregator called as ``model(tile_embeds, coords)``.
        dataloader: Iterable of sample dicts.
        loss_fn: Criterion called as ``loss_fn(logits, labels)``.

    Returns:
        ``(mean_loss, labels, probs, ids)``: the average per-batch loss, the
        concatenated labels as yielded by the loader, the concatenated softmax
        outputs (on CPU), and the sample ids in iteration order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    agg_loss = 0.0
    outputs = []
    labels = []
    ids = []
    n_batches = 0

    model.eval()
    with torch.no_grad():
        for sample in dataloader:
            embeds = sample["tile_embeds"]
            label = sample["label"]
            # Add a leading batch dim for unbatched embeddings.
            if embeds.dim() < 3:
                embeds = embeds.unsqueeze(0)

            output = model(embeds.to(device), sample["coords"].to(device))
            # Restore the batch dim if the model squeezed it away.
            if output.dim() < 2:
                output = output.unsqueeze(0)
            loss = loss_fn(output, label.to(device))
            outputs.append(torch.softmax(output.cpu(), dim=-1))
            labels.append(label)
            ids.extend(sample["id"])
            agg_loss += loss.item()
            n_batches += 1

    # Checked before torch.cat, which would fail on an empty list.
    if n_batches == 0:
        raise ValueError("val_epoch received a dataloader with no batches")
    return agg_loss / n_batches, torch.cat(labels), torch.cat(outputs), ids

In [None]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# Per-label accumulators filled by the CV loop: each list gets one entry per
# fold (interpolated TPR / precision curve, and the ROC AUC / average precision).
tprs, aucs, precisions, aps = (
    {label: [] for label in Label._member_names_} for _ in range(4)
)

# Common x-grids the per-fold curves are evaluated on (FPR for ROC, recall for
# PRC), and one 2x2 grid of axes per curve type (one panel per label).
mean_fpr = np.linspace(0, 1, 100)
mean_recall = np.linspace(0, 1, 100)
roc_fig, roc_axs = plt.subplots(2, 2, figsize=(12, 12))
prc_fig, prc_axs = plt.subplots(2, 2, figsize=(12, 12))

In [None]:
from typing import Dict

from torch.utils.data import DataLoader

from data_models.datasets import SlideEncodingDataset, collate_tile_embeds
from utils.split import train_val_split_slides, train_val_split_labels


def get_loaders(
    val_fold: int,
    specimens_by_fold: List[List[str]],
    slides_by_specimen: Dict[str, List[str]],
    labels_by_specimen: Dict[str, int],
) -> Tuple[DataLoader, DataLoader]:
    """Build the (train, validation) DataLoaders for cross-validation fold ``val_fold``.

    Slides and labels are split by the ``utils.split`` helpers (presumably
    holding out ``specimens_by_fold[val_fold]`` for validation). Both loaders
    use ``batch_size=1`` with ``collate_tile_embeds``; only the training loader
    shuffles.
    """

    def _slide_loader(slides, labels, shuffle: bool) -> DataLoader:
        # Wrap one split in a SlideEncodingDataset and batch it one slide at a time.
        return DataLoader(
            SlideEncodingDataset(slides, labels),
            batch_size=1,
            shuffle=shuffle,
            collate_fn=collate_tile_embeds,
        )

    train_slides, val_slides = train_val_split_slides(
        slides_by_specimen=slides_by_specimen,
        specimens_by_fold=specimens_by_fold,
        val_fold=val_fold,
    )
    train_labels, val_labels = train_val_split_labels(
        specimens_by_fold=specimens_by_fold,
        labels_by_specimen=labels_by_specimen,
        val_fold=val_fold,
    )

    train_loader = _slide_loader(train_slides, train_labels, shuffle=True)
    val_loader = _slide_loader(val_slides, val_labels, shuffle=False)
    return train_loader, val_loader

In [None]:
import copy

from torch import nn
from torch.optim import AdamW

from models.vanilla_attn import VanillaAttentionAggregator
from utils.eval import get_spec_level_probs, plot_eval


EPOCHS = 100
NUM_LABELS = 4
PATIENCE = 10
# Base RNG seed: fold i is seeded with SEED + i so each fold's weight init and
# training-set shuffling are reproducible on Restart & Run All.
SEED = 0


# Cross-validation: for every fold, train a fresh aggregator on the other folds
# with early stopping on validation loss, checkpoint the best weights, and add
# that model's per-class ROC/PRC curves to the shared figures.
for i in range(len(specimens_by_fold)):
    print(f"--------------------FOLD    {i + 1}--------------------")
    torch.manual_seed(SEED + i)

    # get dataloaders and embed dims for loaded embeddings
    train_loader, val_loader = get_loaders(
        i, specimens_by_fold, slides_by_specimen, labels_dict
    )
    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        raise RuntimeError(f"Fold {i}: training loader yielded no batches")
    embed_dim = first_batch["tile_embeds"].shape[-1]

    # NOTE(review): the positional args (NUM_LABELS, 1, True) follow
    # VanillaAttentionAggregator's signature, which is defined elsewhere.
    model = VanillaAttentionAggregator(embed_dim, NUM_LABELS, 1, True).to(
        device
    )
    loss_fn = nn.CrossEntropyLoss()
    optim = AdamW(model.parameters(), lr=1e-5)

    best_loss = float("inf")
    best_model_weights = None
    best_model_data = {"ids": None, "labels": None, "probs": None}
    patience = PATIENCE

    for epoch in range(EPOCHS):
        train_loss = train_epoch(model, train_loader, optim, loss_fn)
        val_loss, labels, probs, ids = val_epoch(model, val_loader, loss_fn)

        if val_loss < best_loss:
            # New best: snapshot the weights (deepcopy, because state_dict
            # tensors are updated in place by later optimizer steps) together
            # with the matching validation outputs.
            best_loss = val_loss
            best_model_weights = copy.deepcopy(model.state_dict())
            best_model_data["ids"] = ids
            best_model_data["labels"] = labels
            best_model_data["probs"] = probs
            patience = PATIENCE
        else:
            # Stop after PATIENCE consecutive epochs without improvement.
            patience -= 1
            if patience == 0:
                break

        if (epoch + 1) % 10 == 0:
            spaces = " " * (4 - len(str(epoch + 1)))
            print(
                f"--------------------EPOCH{spaces}{epoch + 1}--------------------"
            )
            print(f"train loss: {train_loss:0.6f}")
            print(f"val loss:   {val_loss:0.6f}")
            print()

    if best_model_weights is None:
        # Only reachable if no epoch ran or every validation loss was NaN.
        raise RuntimeError(
            f"Fold {i}: validation loss never improved; no checkpoint to save"
        )

    # save the best model (torch.save does not create missing directories)
    chkpt_dir = os.path.join(OUTPUT_DIR, "chkpts")
    os.makedirs(chkpt_dir, exist_ok=True)
    torch.save(
        best_model_weights,
        os.path.join(chkpt_dir, f"uni-vanil-pool-fold-{i}"),
    )

    # extract relevant data from val results for best model
    ids, probs = get_spec_level_probs(
        best_model_data["ids"], best_model_data["probs"]
    )
    labels_onehot_val = labels_onehot.loc[ids]

    # for each predicted class, plot the current classifier's eval
    # and retain relevant data in dicts
    for j, class_of_interest in enumerate(Label._member_names_):
        class_labels = labels_onehot_val[class_of_interest]
        class_probs = probs[:, Label[class_of_interest].value]
        ax_row, ax_col = divmod(j, 2)

        interp_tpr, roc_auc = plot_eval(
            mean_x=mean_fpr,
            onehot_labels=class_labels,
            probs=class_probs,
            ax=roc_axs[ax_row][ax_col],
            plot_type="ROC",
            fold_idx=i,
            # draw the ROC chance line only once, on the last fold
            plot_chance_level=i == len(specimens_by_fold) - 1,
        )
        tprs[class_of_interest].append(interp_tpr)
        aucs[class_of_interest].append(roc_auc)

        interp_precision, average_precision = plot_eval(
            mean_x=mean_recall,
            onehot_labels=class_labels,
            probs=class_probs,
            ax=prc_axs[ax_row][ax_col],
            plot_type="PRC",
            fold_idx=i,
            # the PRC chance line (label prevalence) is added in the next cell
            plot_chance_level=False,
        )
        precisions[class_of_interest].append(interp_precision)
        aps[class_of_interest].append(average_precision)

In [None]:
from utils.eval import create_mean_curve

exp_name = "uni/vanilla/pool/uni-3_hidden"

# plot the mean eval curves, one panel per label
for j, class_of_interest in enumerate(Label._member_names_):
    ax_row, ax_col = divmod(j, 2)
    roc_ax = roc_axs[ax_row][ax_col]
    prc_ax = prc_axs[ax_row][ax_col]

    create_mean_curve(
        mean_fpr,
        tprs[class_of_interest],
        aucs[class_of_interest],
        roc_ax,
        "ROC",
        class_of_interest,
    )
    create_mean_curve(
        mean_recall,
        precisions[class_of_interest],
        aps[class_of_interest],
        prc_ax,
        "PRC",
        class_of_interest,
    )

    # add chance line for PRC == label freq at specimen level
    prc_ax.axhline(
        spec_freqs[class_of_interest],
        linestyle="--",
        label=r"Chance level (AP = %0.2f)" % (spec_freqs[class_of_interest]),
        color="black",
    )

    # Reorder legend: swap the last entry (the chance line just added) with the
    # third-from-last one.
    # NOTE(review): this depends on the number/order of legend entries that
    # create_mean_curve adds; revisit if that helper changes.
    handles, labs = prc_ax.get_legend_handles_labels()
    handles[-1], handles[-3] = handles[-3], handles[-1]
    labs[-1], labs[-3] = labs[-3], labs[-1]
    prc_ax.legend(handles=handles, labels=labs, loc="lower right")

# exp_name contains sub-directories and savefig does not create them.
# Paths are relative to the working directory set in the first cell.
fig_prefix = os.path.join("outputs", exp_name)
os.makedirs(os.path.dirname(fig_prefix), exist_ok=True)
roc_fig.savefig(f"{fig_prefix}-roc.png")
prc_fig.savefig(f"{fig_prefix}-prc.png")