In [None]:
# reload imported modules automatically when their source files change,
# so edits to project code are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# read DATA_DIR / OUTPUT_DIR (and any other settings) from a .env file
load_dotenv()

# move to the parent directory (presumably the project root, so the project
# imports below resolve). Guarded so re-running this cell does not keep
# walking further up the directory tree.
if not globals().get("_CWD_MOVED_TO_PARENT", False):
    os.chdir("..")
    _CWD_MOVED_TO_PARENT = True

DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")
# fail fast: both are joined into paths throughout the notebook
assert DATA_DIR is not None, "DATA_DIR must be set (e.g. in .env)"
assert OUTPUT_DIR is not None, "OUTPUT_DIR must be set (e.g. in .env)"

In [None]:
from collections import Counter
import os
import pickle
from functools import partial
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from seaborn import histplot
from sklearn.metrics import RocCurveDisplay, roc_curve
import torch

from data_models.Label import NCLabel, Label
from models.nearest_centroid.nearest_centroid import NearestCentroid
from utils.load_data import SpecimenData
from utils.slide_utils import plot_image

### Initialization and utility methods

In [None]:
# label index -> set of specimen ids carrying that label; sets keep the
# repeated `slide[:6] in ...` membership checks below cheap
_label_csv_path = os.path.join(DATA_DIR, "labels/labels.csv")
specimens_by_label = list(
    map(set, SpecimenData(label_path=_label_csv_path).specimens_by_label)
)

In [None]:
def extract_tiles(embeds: dict, tiles_to_filter: torch.Tensor) -> torch.Tensor:
    """
    Keep only the tile embeddings whose coordinates appear in tiles_to_filter.

    embeds: dict with "coords" (one coordinate row per tile) and "tile_embeds"
        (one embedding row per tile, aligned with "coords")
    tiles_to_filter: coordinate rows to keep

    Returns the rows of embeds["tile_embeds"] whose coordinate row exactly
    matches any row of tiles_to_filter. Memory is O(n_tiles * n_filter).
    """
    # broadcast every tile coordinate against every filter coordinate
    pairwise_eq = embeds["coords"].unsqueeze(1) == tiles_to_filter
    # a tile is kept if all of its coordinate components match some filter row
    keep = pairwise_eq.all(dim=-1).any(dim=1)
    return embeds["tile_embeds"][keep]

In [None]:
def get_preds(
    slides: list, models: list, tiles_to_filter: dict = None
) -> dict:
    """
    Score every slide with every model.

    slides: embedding file names; the last 4 characters (presumably ".pkl")
        are stripped to form the slide id used as the output key
    models: dicts with "embedding_dir" (directory holding each slide's pickled
        embeddings) and "model" (object exposing `.predict(x, mode=...)`)
    tiles_to_filter: optional {slide_id: coordinate tensor}; when truthy only
        the tiles at those coordinates are scored

    Returns {slide_id: tensor stacked as (M, S, C)}: softmax over the last
    dimension of each model's output, with models stacked along dim 0.
    """

    def _score(slide: str, model_info: dict) -> torch.Tensor:
        # load this model's embeddings for the slide and score each tile
        with open(os.path.join(model_info["embedding_dir"], slide), "rb") as fh:
            slide_embeds = pickle.load(fh)
        if tiles_to_filter:
            tile_embeds = extract_tiles(slide_embeds, tiles_to_filter[slide[:-4]])
        else:
            tile_embeds = slide_embeds["tile_embeds"]
        scores = model_info["model"].predict(tile_embeds.float(), mode="dot_product")
        return scores.softmax(dim=-1)  # (S, C)

    preds = {}
    for slide in slides:
        # dim 0 of the stacked tensor is the model dimension
        preds[slide[:-4]] = torch.stack(
            [_score(slide, model_info) for model_info in models]
        )  # (M, S, C)
    return preds

In [None]:
def get_labeled_preds(preds: dict) -> dict:
    """
    Collapse softmax predictions to integer class labels.

    Each value's last (class) dimension is replaced by its argmax,
    e.g. (M, S, C) -> (M, S). Keys are preserved in order.
    """
    return {slide_id: probs.argmax(dim=-1) for slide_id, probs in preds.items()}

In [None]:
def count_confusion_scores(labeled_preds: dict) -> Counter:
    """
    Tally confusion scores over every tile of every slide.

    A tile's confusion score is the number of distinct labels the models
    assigned to it: 1 means all models agree; with three models, 3 means
    every model predicted something different.

    labeled_preds: {slide_id: integer label tensor with models on dim 0 and
        tiles on dim 1}
    Returns Counter(confusion score -> number of tiles).
    """
    tally = Counter()
    for label_stack in labeled_preds.values():
        per_tile = label_stack.transpose(0, 1)  # one row of model labels per tile
        tally.update(len(tile_labels.unique()) for tile_labels in per_tile)
    return tally

In [None]:
def confusion_by_gt(labeled_preds: dict, n_models: int = None) -> dict:
    """
    Count confusion scores overall and per ground-truth label.

    labeled_preds: {slide_id: integer label tensor with models on dim 0 and
        tiles on dim 1}
    n_models: number of stacked models; defaults to dim 0 of the first
        prediction tensor (3 when labeled_preds is empty)

    Returns a dict with key "aggregate" (all slides) plus one Counter per
    `Label` member. A label's counter covers only slides whose first 6
    characters (presumably the specimen id) are in
    `specimens_by_label[label.value]`. Every Counter, "aggregate" included,
    has keys 1..n_models, zero-filled where absent, so the results tabulate
    into aligned columns.
    """
    if n_models is None:
        first = next(iter(labeled_preds.values()), None)
        n_models = first.shape[0] if first is not None else 3

    confusion_counters = {"aggregate": count_confusion_scores(labeled_preds)}
    for label in Label:
        specimens = specimens_by_label[label.value]
        # reuse the aggregate tally on just this label's slides
        confusion_counters[label] = count_confusion_scores(
            {
                slide: pred
                for slide, pred in labeled_preds.items()
                if slide[:6] in specimens
            }
        )

    # zero-fill missing confusion levels (`counter[k] += 0` inserts key k)
    for counter in confusion_counters.values():
        for level in range(1, n_models + 1):
            counter[level] += 0
    return confusion_counters

In [None]:
def per_slide_confusion(labeled_preds: dict) -> dict:
    """
    Confusion-score counts computed separately for each slide.

    Returns {slide_id: Counter(confusion score -> number of tiles)}. A tile's
    score is the number of distinct labels the models gave it.
    """
    return {
        slide: Counter(
            len(tile_labels.unique()) for tile_labels in pred.transpose(0, 1)
        )
        for slide, pred in labeled_preds.items()
    }

In [None]:
def count_disagreement(confusion_by_slide: dict, n_models: int = 3) -> list:
    """
    Fraction of each slide's tiles on which all models disagree.

    confusion_by_slide: {slide_id: Counter(confusion score -> tile count)},
        e.g. the output of per_slide_confusion
    n_models: number of models. "Total disagreement" is a confusion score
        equal to this value (every model predicted a different label).

    Returns one fraction per slide, in the dict's iteration order. A slide
    with no tiles yields NaN instead of raising ZeroDivisionError.
    """
    fractions = []
    for counter in confusion_by_slide.values():
        n_tiles = sum(counter.values())
        fractions.append(
            counter.get(n_models, 0) / n_tiles if n_tiles else float("nan")
        )
    return fractions

### Get predictions from each model

In [None]:
# foundation models; each gets its own nearest-centroid model per experiment
fms = ["uni", "prism", "gigapath"]
# experiment names; param_patterns[i] is the parameter-path template for experiments[i]
experiments = [
    "old",
    "new",
    "sq_norm",
    "top_pct",
    "gaussian_mixture_separate",
    "gaussian_mixture_combined",
]
filtering_strategy = "agg"
# root of the saved few-shot centroid parameters; override via the
# FEW_SHOT_MODEL_DIR environment variable instead of editing the path here
MODEL_ROOT = os.getenv(
    "FEW_SHOT_MODEL_DIR", "/opt/gpudata/skin-cancer/models/few-shot"
)
# every entry is a plain str template with a single "{fm}" placeholder; the
# filtered experiments fill in their method and filtering strategy up front
param_patterns = [
    f"{MODEL_ROOT}/intersects/{{fm}}_param2.pkl",
    f"{MODEL_ROOT}/new/intersects/{{fm}}_param.pkl",
] + [
    f"{MODEL_ROOT}/new/filtered/{{fm}}_param-{method}-{filtering_strategy}.pkl"
    for method in experiments[2:]
]
assert len(param_patterns) == len(experiments), "one param pattern per experiment"
# tile-embedding directory template, formatted with root=OUTPUT_DIR and fm=<name>
embedding_dir = "{root}/{fm}/tile_embeddings_sorted"


def load_model(param_path):
    """
    Build a NearestCentroid classifier (mode="intersects") over NCLabel
    from centroid parameters pickled at `param_path`.

    NOTE: uses pickle.load, so only point this at trusted files.
    """
    with open(param_path, "rb") as fh:
        centroids = pickle.load(fh)
    return NearestCentroid(NCLabel, centroids=centroids, mode="intersects")


def load_models(param_path: "str | Callable[..., str]") -> list:
    """
    Load one NearestCentroid model per foundation model in `fms`.

    param_path: either a path template containing "{fm}", or a callable that
        takes `fm=<name>` and returns the path (e.g. a functools.partial
        wrapping str.format)

    Returns a list of {"fm", "embedding_dir", "model"} dicts in `fms` order.
    "embedding_dir" is the module-level `embedding_dir` template formatted
    with root=OUTPUT_DIR and the foundation-model name.

    Raises ValueError if `embedding_dir` has no "{fm}" placeholder (e.g. it was
    rebound to a concrete path); otherwise every model would silently point
    at the same directory.
    """
    if "{fm}" not in embedding_dir:
        raise ValueError(
            f"embedding_dir must be a template containing '{{fm}}', got {embedding_dir!r}"
        )
    models = []
    for fm in fms:
        # strings are format templates; anything else is called with fm=
        path = (
            param_path.format(fm=fm)
            if isinstance(param_path, str)
            else param_path(fm=fm)
        )
        models.append(
            {
                "fm": fm,
                "embedding_dir": embedding_dir.format(root=OUTPUT_DIR, fm=fm),
                "model": load_model(path),
            }
        )
    return models


# experiments and param_patterns are parallel lists: one pattern per experiment
assert len(experiments) == len(param_patterns), "one param pattern per experiment"
models = {
    exp: load_models(pattern)
    for exp, pattern in zip(experiments, param_patterns)
}

# slide embedding files available for the first foundation model. Restricted
# to .pkl files because downstream code strips a 4-character ".pkl" suffix and
# unpickles each entry. Sorted so slide order (and every per-slide array built
# from it) is reproducible rather than dependent on os.listdir order.
all_slides = sorted(
    fname
    for fname in os.listdir(embedding_dir.format(root=OUTPUT_DIR, fm=fms[0]))
    if fname.endswith(".pkl")
)

In [None]:
# per experiment: softmax predictions {slide_id: (M, S, C)} and their
# argmax labels {slide_id: (M, S)}
preds = {}
labeled_preds = {}
for exp in experiments:
    preds[exp] = get_preds(all_slides, models[exp])
    labeled_preds[exp] = get_labeled_preds(preds[exp])

In [None]:
# def get_roi_tiles(roi_dir: str) -> dict:
#     # get the roi tiles used for training the models
#     tmp = NearestCentroid(NCLabel)
#     tmp.fit(
#         tile_embed_dir=os.path.join(OUTPUT_DIR, "uni/tile_embeddings_sorted"),
#         roi_dir=roi_dir,
#     )

#     roi_tiles = tmp.roi_tiles
#     del tmp
#     return roi_tiles


# roi_tiles_old = get_roi_tiles(
#     "/opt/gpudata/skin-cancer/models/few-shot/intersects"
# )
# roi_tiles_new = get_roi_tiles(
#     "/opt/gpudata/skin-cancer/models/few-shot/new/intersects"
# )

In [None]:
# def get_roi_tiles_unlabeled(roi_tiles: dict):
#     # separate the roi tiles from their labels and map slide
#     # ids directly to tensors of coords (just unpacking the
#     # dict returned by get_roi_tiles())
#     roi_tiles_unlabeled = {}

#     for label, slides in roi_tiles.items():
#         for slide, tiles in slides.items():
#             t = []
#             for x in tiles:
#                 # why did this happen?
#                 if not isinstance(x, tuple):
#                     print(label, x, slide)
#                 else:
#                     t.append(x)
#             roi_tiles_unlabeled[slide] = torch.tensor(t, dtype=torch.float32)
#     return roi_tiles_unlabeled


# roi_tiles_unlabeled_old = get_roi_tiles_unlabeled(roi_tiles_old)
# roi_tiles_unlabeled_new = get_roi_tiles_unlabeled(roi_tiles_new)

In [None]:
# def get_roi_preds(roi_tiles_unlabeled, models):
#     # get preds just for the tiles within the ROIs
#     roi_preds = get_preds(
#         [f"{slide}.pkl" for slide in roi_tiles_unlabeled.keys()],
#         models,
#         roi_tiles_unlabeled,
#     )
#     roi_labeled_preds = get_labeled_preds(roi_preds)
#     return roi_labeled_preds


# roi_labeled_preds_old = get_roi_preds(roi_tiles_unlabeled_old, old_models)
# roi_labeled_preds_new = get_roi_preds(roi_tiles_unlabeled_new, new_models)

In [None]:
def get_odd_model_ratios(labeled_preds):
    """
    Share of "odd-one-out" tiles attributable to each model.

    A tile is odd-one-out when exactly two distinct labels were predicted
    (confusion score 2): two models agree and the third differs. The
    comparison is written for exactly three models (rows 0, 1, 2 of each
    (3, n_tiles) label tensor).

    Returns a length-3 float array; entry k is the fraction of odd-one-out
    tiles on which model k was the odd one (0 for a model that never was).
    Returns NaNs when no tile is odd-one-out.
    """
    n_models = 3
    odd_model = []
    for preds in labeled_preds.values():
        for col in preds.T:
            if len(col.unique()) == 2:
                if col[0] == col[1]:
                    odd_model.append(2)
                elif col[1] == col[2]:
                    odd_model.append(0)
                else:
                    odd_model.append(1)
    # bincount keeps a slot for every model index, even ones that were never
    # odd, so entry k always corresponds to model k (np.unique would drop it)
    counts = np.bincount(np.asarray(odd_model, dtype=int), minlength=n_models)
    total = counts.sum()
    if total == 0:
        return np.full(n_models, np.nan)
    return counts / total


# {experiment: {foundation model: share of odd-one-out tiles}}
odd_models = {}
for _exp_name in experiments:
    _ratios = get_odd_model_ratios(labeled_preds[_exp_name])
    odd_models[_exp_name] = {fms[idx]: ratio for idx, ratio in enumerate(_ratios)}

In [None]:
# one line per experiment: its per-model odd-one-out shares
for exp in experiments:
    per_fm_shares = odd_models[exp]
    print(f"{exp}: {per_fm_shares}")

For every tile, compute a confusion score: the number of distinct classes predicted for that tile across the three models.  
1 = all models agree  
2 = one model disagrees  
3 = all models disagree  

In [None]:
# total confusion-score counts over all slides, per experiment
confusion_counts = dict()
for exp in experiments:
    counts_for_exp = count_confusion_scores(labeled_preds[exp])
    confusion_counts[exp] = counts_for_exp
    print(f"{exp}: {counts_for_exp}")

### Visualize the distribution of disagreement across all slides

In [None]:
# distribution over slides of the fraction of tiles on which all models
# disagree; one step histogram per experiment on a shared axis
# x: proportion of a slide's tiles in total disagreement
# y: proportion of slides at that level
fig, ax = plt.subplots()
for exp in experiments:
    disagreement = count_disagreement(per_slide_confusion(labeled_preds[exp]))
    histplot(
        disagreement,
        element="step",
        fill=False,
        stat="proportion",
        label=exp,
        ax=ax,
    )
# fixed viewing window; widen it if slides fall outside
ax.set_xlim(left=0, right=0.2)
ax.set_ylim(bottom=0, top=0.25)
ax.set_xlabel("Proportion of tiles in total disagreement")
ax.set_ylabel("Proportion of slides")
ax.legend()
plt.show()

In [None]:
# for i, roi_labeled_preds in enumerate(
#     [roi_labeled_preds_old, roi_labeled_preds_new]
# ):
#     confusion = per_slide_confusion(roi_labeled_preds)
#     disagreement_counts = count_disagreement(confusion)
#     # visualize the distribution of total disagreement
#     histplot(
#         disagreement_counts,
#         element="step",
#         fill=True,
#         stat="proportion",
#         label=model_version[i],
#     )
# plt.xlim(left=0)
# plt.xlabel("Proportion of tiles in disagreement")
# plt.legend()

### Visualize confusion levels across all slides (the ROI comparison is currently commented out)

In [None]:
# confusion_counters = [
#     [
#         confusion_by_gt(all_labeled_preds_old),
#         confusion_by_gt(all_labeled_preds_new),
#         confusion_by_gt(all_labeled_preds_filt),
#     ],
#     [
#         confusion_by_gt(roi_labeled_preds_old),
#         confusion_by_gt(roi_labeled_preds_new),
#     ],
# ]

In [None]:
# per-experiment confusion counts, overall and split by ground-truth label
confusion_counters = {}
for exp in experiments:
    confusion_counters[exp] = confusion_by_gt(labeled_preds[exp])

In [None]:
# visualize confusion levels: one panel per experiment; bars grouped by
# specimen label, one bar per confusion score, normalized per label
n_cols = 3
n_rows = -(-len(experiments) // n_cols)  # ceiling division
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), sharey=True, squeeze=False
)
for ax, exp in zip(axs.flat, experiments):
    conf_df = pd.DataFrame(confusion_counters[exp]).T.fillna(0)
    levels = sorted(conf_df.columns)  # confusion scores 1..n_models
    conf_df[levels].div(conf_df.sum(axis=1), axis=0).plot(
        kind="bar", ax=ax, legend=False
    )
    ax.set_xticks(
        ticks=list(range(conf_df.shape[0])),
        labels=[
            tick.name if isinstance(tick, Label) else tick
            for tick in conf_df.index
        ],
    )
    ax.set_xlabel("specimen label")
    ax.set_ylabel("proportion of tiles")
    ax.set_title(exp)
# hide panels left over when len(experiments) is not a multiple of n_cols
for ax in axs.flat[len(experiments):]:
    ax.set_visible(False)
axs.flat[0].legend(title="confusion score")
plt.tight_layout()
plt.show()

Count the tiles whose predicted class is obviously incorrect for the slide's ground-truth label (e.g., a BCC class predicted on a Bowen's slide).

In [None]:
# for each ground-truth specimen label, the NearestCentroid classes that would
# be an obviously wrong prediction, stored as tensors of class values so they
# can be passed straight to torch.isin
obv_incorrect_labels = {
    gt: torch.tensor([nc_label.value for nc_label in wrong_classes])
    for gt, wrong_classes in {
        Label.na: {
            NCLabel.bcc_nodular,
            NCLabel.bcc_superficial,
            NCLabel.bowens,
            NCLabel.scc,
        },
        Label.bcc: {NCLabel.bowens, NCLabel.scc},
        Label.bowens: {NCLabel.bcc_nodular, NCLabel.bcc_superficial},
        Label.scc: {NCLabel.bcc_nodular, NCLabel.bcc_superficial},
    }.items()
}

In [None]:
def get_vote_preds(all_labeled_preds) -> dict:
    """
    Majority vote across models for every tile.

    all_labeled_preds: {slide_id: integer label tensor with models on dim 0}
    Returns {slide_id: tensor of the most common label along dim 0}.

    NOTE(review): when no label has a majority (e.g. all models disagree),
    torch.mode picks one of the tied values; which one is not controlled here.
    """
    return {
        slide: label_stack.mode(dim=0).values
        for slide, label_stack in all_labeled_preds.items()
    }


# majority-vote labels per tile, for every experiment
vote_preds = {}
for _exp_name in experiments:
    vote_preds[_exp_name] = get_vote_preds(labeled_preds[_exp_name])

In [None]:
def count_incorrect(preds: Dict[str, torch.Tensor]) -> Dict[Label, Counter]:
    """
    Count tiles whose predicted class is obviously wrong for their slide's
    ground-truth label.

    preds: {slide_id: integer label tensor}. A slide belongs to a Label when
        its first 6 characters are in specimens_by_label[label.value].
    Returns {Label: Counter(predicted class value -> number of tiles)}. Only
    classes listed in obv_incorrect_labels[label] are counted.
    """
    tallies = {}
    for label in Label:
        tally = Counter()
        tallies[label] = tally
        for slide, pred in preds.items():
            if slide[:6] not in specimens_by_label[label.value]:
                continue
            is_wrong = torch.isin(pred, obv_incorrect_labels[label])
            tally.update(pred[is_wrong].tolist())
    return tallies

In [None]:
def get_incorrect_counts(
    all_labeled_preds: dict, vote_preds: dict = None
) -> dict:
    """
    Obviously-incorrect tile counts for each foundation model and,
    optionally, for the majority vote.

    all_labeled_preds: {slide_id: label tensor, models on dim 0 in `fms` order}
    vote_preds: optional {slide_id: voted label tensor}; added under the key
        "vote" when non-empty

    Returns {fm name or "vote": result of count_incorrect}.
    """
    incorrect_counts = {}
    for model_idx, fm in enumerate(fms):
        per_model = {
            slide: label_stack[model_idx]
            for slide, label_stack in all_labeled_preds.items()
        }
        incorrect_counts[fm] = count_incorrect(per_model)
    if vote_preds:
        incorrect_counts["vote"] = count_incorrect(vote_preds)
    return incorrect_counts


# {experiment: {fm or "vote": {Label: Counter(wrong class -> tiles)}}}
incorrect_counts = {}
for _exp_name in experiments:
    incorrect_counts[_exp_name] = get_incorrect_counts(
        labeled_preds[_exp_name], vote_preds[_exp_name]
    )

In [None]:
# inspect the obviously-incorrect counts for a single experiment
exp_to_inspect = "gaussian_mixture_combined"
incorrect_counts[exp_to_inspect]

In [None]:
# proportion of obviously incorrect classifications on benign (Label.na)
# slides, broken down by model and by the wrongly predicted class.
# Denominator: total number of tiles on benign ground-truth slides (sum of all
# confusion-score counts for Label.na). Every experiment scores the same
# slides, so the first experiment's counts serve for all of them.
benign_tile_total = (
    pd.DataFrame(confusion_counters[experiments[0]]).T.loc[Label.na].sum()
)
for exp in experiments:
    print(f"**{exp}**")
    for model, counts in incorrect_counts[exp].items():
        print(model)
        for incorrect, count in counts[Label.na].items():
            print(f"{NCLabel(incorrect).name}: {count / benign_tile_total}")
        print()
    print("------------------")

In [None]:
def class_proportions(ground_truth: Label, preds: dict):
    """
    Per-slide share of predicted labels in each NearestCentroid class, for
    the slides of one ground-truth label.

    ground_truth: selects the files in `all_slides` whose first 6 characters
        are in specimens_by_label[ground_truth.value]
    preds: {slide_id: integer label tensor}, where slide_id is the file name
        without its 4-character suffix. The tensor may have any shape (e.g.
        per-tile vote labels, or a (n_models, n_tiles) stack), because all
        entries are pooled.

    Returns (props, slides). props has shape (n_classes, n_selected_slides),
    with props[c, k] the share of slide k's labels equal to class c. slides
    lists the selected slide ids in the same column order.
    """
    gt_specimens = specimens_by_label[ground_truth.value]
    slides = [fname[:-4] for fname in all_slides if fname[:6] in gt_specimens]

    props = np.zeros((len(slides), len(NCLabel._member_names_)))
    for row, slide in enumerate(slides):
        labels, label_counts = preds[slide].unique(return_counts=True)
        # scatter each present class's share into its column
        props[row, labels.tolist()] = (label_counts / label_counts.sum()).tolist()
    return props.T, slides


def proportions_hist(
    class_props: list,
    ax,
    ground_truth: Label,
    experiment_name: str = "",
    exclude_classes: tuple = (6,),
):
    """
    Overlay on `ax` one step histogram per NearestCentroid class, showing the
    per-slide proportion of tiles assigned to that class.

    class_props: indexed by class value; entry i holds the per-slide
        proportions for class i (e.g. the first element returned by
        class_proportions)
    ground_truth: ground-truth label of the slides, shown in the title
    experiment_name: experiment name, shown in the title
    exclude_classes: class values not drawn. NOTE(review): why class 6 is
        skipped by default is not visible here; document it.
    """
    bins = np.linspace(0, 1, 100)
    for i, pcts in enumerate(class_props):
        if i in exclude_classes:
            continue
        histplot(
            np.array(pcts),
            bins=bins,
            element="step",
            fill=False,
            stat="proportion",
            label=NCLabel(i).name,
            ax=ax,
        )
    # decorate once, after all classes are drawn; the title uses the
    # experiment_name parameter rather than any global loop variable
    ax.legend()
    ax.set_title(f"{ground_truth.name}-{experiment_name}")
    ax.set_xlabel("proportion of tiles per slide")
    ax.set_ylabel("proportion of slides")
    ax.set_xlim(left=0, right=0.15)


def plot_heatmap(
    slide_id: str,
    preds: torch.Tensor,
    out_name: str,
    title: str,
    embedding_dir: str,
    slide_dir: str = "/opt/gpudata/skin-cancer/data/slides",
):
    """
    Render per-tile predicted labels over a whole-slide image and save it.

    Does nothing if `out_name` already exists, so re-runs skip finished plots.
    Tile coordinates come from the "coords" entry of
    `{embedding_dir}/{slide_id}.pkl`; the image is `{slide_dir}/{slide_id}.svs`.
    Labels are drawn with one weight label per NCLabel member.
    """
    if os.path.exists(out_name):
        return

    with open(os.path.join(embedding_dir, f"{slide_id}.pkl"), "rb") as f:
        slide_data = pickle.load(f)

    fig, ax = plt.subplots(figsize=(10, 10), constrained_layout=True)
    try:
        plot_image(
            fpath=os.path.join(slide_dir, f"{slide_id}.svs"),
            ax=ax,
            tile_coords=slide_data["coords"],
            tile_weights=preds,
            weight_labels={label.name: label.value for label in NCLabel},
        )
        ax.set_title(title)
        fig.savefig(
            out_name,
            bbox_inches="tight",
            dpi=200,
        )
    finally:
        # close only this figure, and close it even on error, so a loop of
        # heatmaps neither leaks figures nor closes unrelated ones
        plt.close(fig)

In [None]:
# {experiment: {ground-truth Label: (per-class proportions of voted tile
# labels, slide ids)}}
props = {}
for _exp_name in experiments:
    props[_exp_name] = {
        gt: class_proportions(gt, vote_preds[_exp_name]) for gt in Label
    }

In [None]:
# per-slide distribution of the proportion of tiles assigned to each class,
# for slides of one ground-truth label; one panel per experiment
label_of_interest = Label.bcc
n_cols = 3
n_rows = -(-len(experiments) // n_cols)  # ceiling division
fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(6 * n_cols, 6 * n_rows),
    sharey=True,
    sharex=True,
    squeeze=False,
)
for ax, exp in zip(axs.flat, experiments):
    # computed here rather than read from the `props` global, which a later
    # cell rebinds to a plain array (so re-running this cell would break)
    class_props, _ = class_proportions(label_of_interest, vote_preds[exp])
    proportions_hist(class_props, ax, label_of_interest, exp)
# hide panels left over when len(experiments) is not a multiple of n_cols
for ax in axs.flat[len(experiments):]:
    ax.set_visible(False)
plt.show()

In [None]:
# one representative slide per ground-truth class
sampled_specs = {
    "bowens": "660524-2",
    "bcc": "660369-6",
    "scc": "660109-1",
    "na": "660375-1",
}
# tile coordinates for every heatmap are read from the first experiment's
# first foundation model. NOTE(review): assumes coords are identical across
# foundation models; confirm. Uses its own name so the `embedding_dir` path
# template that load_models formats is not overwritten.
heatmap_embedding_dir = models[experiments[0]][0]["embedding_dir"]

# prediction sources: each foundation model's row of the label stack, then the vote
pred_sources = [(fm, idx) for idx, fm in enumerate(fms)] + [("vote", None)]

for exp in experiments:
    for source, model_idx in pred_sources:
        for slide_id in sampled_specs.values():
            name = f"{exp}-{source}-{slide_id}"
            print(name)
            if slide_id not in labeled_preds[exp]:
                print(f"  skipped: no predictions for {slide_id}")
                continue
            slide_preds = (
                vote_preds[exp][slide_id]
                if model_idx is None
                else labeled_preds[exp][slide_id][model_idx]
            )
            plot_heatmap(
                slide_id,
                slide_preds,
                f"{name}.png",
                name,
                heatmap_embedding_dir,
            )

In [None]:
# ground-truth one-hot matrix: one row per slide in all_slides; column j is set
# when the slide's first 6 characters are in specimens_by_label[j]
onehot = np.zeros((len(all_slides), len(Label._member_names_)))
for row, slide_id in enumerate(all_slides):
    hits = [
        col for col, specimens in enumerate(specimens_by_label)
        if slide_id[:6] in specimens
    ]
    onehot[row, hits] = 1

# per-slide class proportions of one experiment's vote predictions; column c
# is the share of the slide's tiles voted into class value c.
# NOTE(review): this rebinds `props`, replacing the per-experiment dict
# built earlier in the notebook.
roc_vote_source = vote_preds["gaussian_mixture_separate"]
props = np.zeros((len(all_slides), len(NCLabel._member_names_)))
for row, slide_id in enumerate(all_slides):
    labels, label_counts = roc_vote_source[slide_id[:-4]].unique(return_counts=True)
    props[row, labels.tolist()] = (label_counts / label_counts.sum()).tolist()

In [None]:
# one-vs-rest ROC per ground-truth class: each slide is scored by the share of
# its tiles voted into the matching class(es). onehot columns are indexed by
# Label value and props columns by NCLabel value, so index through the enums
# rather than hard-coded positions.
fig, ax = plt.subplots()
plot = RocCurveDisplay.from_predictions
# NOTE(review): props columns 0-1 are assumed to be the benign classes; confirm against NCLabel
plot(onehot[:, Label.na.value], props[:, :2].sum(axis=-1), name="benign", ax=ax)
plot(
    onehot[:, Label.bowens.value],
    props[:, NCLabel.bowens.value],
    name="bowens",
    ax=ax,
)
plot(
    onehot[:, Label.bcc.value],
    props[:, [NCLabel.bcc_nodular.value, NCLabel.bcc_superficial.value]].sum(axis=-1),
    name="bcc",
    ax=ax,
)
plot(onehot[:, Label.scc.value], props[:, NCLabel.scc.value], name="scc", ax=ax)
ax.set_title("Slide-level ROC from vote-prediction class proportions")
plt.show()