In [None]:
# Reload edited project modules (e.g. models/, utils/) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv

# Load environment variables (DATA_DIR, OUTPUT_DIR, ...) from a .env file.
load_dotenv()

# Run from two directories above the notebook (presumably the repository root,
# so the project imports below resolve). The target is resolved only once per
# kernel: a bare os.chdir("../..") would climb two more levels on every re-run
# of this cell.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(PROJECT_ROOT)

# NOTE(review): either value is None when the variable is missing from the
# environment; OUTPUT_DIR is needed later to locate the tile embeddings.
DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from seaborn import histplot
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
import torch

from data_models.Label import NCLabel, Label
from models.nearest_centroid.nearest_centroid import NearestCentroid
from utils.load_data import SpecimenData
from utils.slide_utils import plot_image

In [None]:
# Foundation models whose nearest-centroid parameter files are compared below.
fms = ["uni", "prism", "gigapath"]

# Path templates (formatted with ``fm``) for each variant of the parameter files.
old_model_dir = "/opt/gpudata/skin-cancer/models/few-shot/intersects/{fm}_param2.pkl"
new_model_dir = "/opt/gpudata/skin-cancer/models/few-shot/new/intersects/{fm}_param.pkl"
top_pct_model_dir = "/opt/gpudata/skin-cancer/models/few-shot/new/filtered/{fm}_param-top_pct-agg.pkl"
sq_norm_model_dir = "/opt/gpudata/skin-cancer/models/few-shot/new/filtered/{fm}_param-sq_norm-agg.pkl"
gm_sep_model_dir = "/opt/gpudata/skin-cancer/models/few-shot/new/filtered/{fm}_param-gaussian_mixture_separate-agg.pkl"


def open_param(fpath):
    """Unpickle and return the object stored at ``fpath``.

    NOTE(review): pickle.load can execute arbitrary code from the file; only use
    it on trusted, locally produced model files.
    """
    with open(fpath, "rb") as handle:
        return pickle.load(handle)


# Key suffix -> path template. The insertion order fixes the key order of
# ``params`` (variant-major, then foundation model), which the 3-D plot grid
# below relies on.
param_templates = {
    "old": old_model_dir,
    "new": new_model_dir,
    "top_pct": top_pct_model_dir,
    "sq_norm": sq_norm_model_dir,
    "gm_sep": gm_sep_model_dir,
}
params = {
    f"{fm}_{suffix}": open_param(template.format(fm=fm))
    for suffix, template in param_templates.items()
    for fm in fms
}

In [None]:
# Project each parameter matrix onto its own top-3 principal components.
# NOTE(review): every entry gets an independent PCA fit, so the axes of one
# entry are not comparable with another's; only within-entry geometry matters.
# The loop variables are deliberately not named ``fm``/``param``: a top-level
# loop variable leaks into the notebook namespace, and ``fm`` would overwrite the
# foundation-model selection used by later cells whenever this cell is re-run.
components = {}
explained_variance = []
for param_name, param_matrix in params.items():
    param_pca = PCA(n_components=3)
    components[param_name] = param_pca.fit_transform(param_matrix)
    # Kept as a list aligned with the insertion order of ``components``.
    explained_variance.append(param_pca.explained_variance_ratio_)

In [None]:
# Symmetric limits (-m, m), where m is the largest absolute PCA coordinate over
# every entry of ``components``, so that all 3-D panels share one scale.
largest_abs_coord = max(
    (np.abs(pcs).max() for pcs in components.values()),
    default=float("-inf"),
)
axis_lim = np.array((-largest_abs_coord, largest_abs_coord))

In [None]:
# One 3-D panel per parameter matrix. Each row of the matrix's PCA projection is
# drawn as an arrow from the origin and coloured by label. There is one column per
# foundation model, so (given the ordering of ``params``) each row is one variant.
n_labels = len(NCLabel)
# Evenly spaced inferno colours, one per label; the max() guards a single label.
colors = [plt.cm.inferno(i / max(n_labels - 1, 1)) for i in range(n_labels)]
legend_patches = [
    mpatches.Patch(color=color, label=label)
    for color, label in zip(colors, NCLabel._member_names_)
]

# Grid shape derived from the data rather than hard-coded.
n_cols = max(len(fms), 1)
n_rows = max(1, -(-len(components) // n_cols))  # ceiling division
fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(6 * n_cols, 6 * n_rows),
    subplot_kw={"projection": "3d"},
    squeeze=False,
)
# ``param_name`` (not ``fm``) so the global foundation-model selection is not clobbered.
for k, (param_name, pcs) in enumerate(components.items()):
    ax = axs[k // n_cols][k % n_cols]
    ax.set_xlim(*axis_lim)
    ax.set_ylim(*axis_lim)
    ax.set_zlim(*axis_lim)
    ax.set_title(param_name)
    ax.set_facecolor("lightblue")
    # All arrows start at the origin: one arrow per row of ``pcs``.
    origin = np.zeros((3, pcs.shape[0]))
    ax.quiver(*origin, pcs[:, 0], pcs[:, 1], pcs[:, 2], color=colors)
    ax.legend(handles=legend_patches, title="Labels", loc="upper right")
    # Explained-variance ratios in axes coordinates, so the note stays fixed while
    # the 3-D view rotates.
    ax.text2D(
        0.02,
        0.02,
        f"explained variance: {explained_variance[k].round(4)}",
        transform=ax.transAxes,
        fontsize=10,
        color="gray",
        alpha=0.5,
    )

# Hide grid cells left over when the panel count is not a multiple of n_cols.
for ax in axs.flat[len(components):]:
    ax.set_visible(False)
plt.tight_layout()
plt.show()

In [None]:
# Directory holding the per-label ROI tile selections ("<label>-roi.pkl").
roi_dir = "/opt/gpudata/skin-cancer/models/few-shot/new/intersects"

# Foundation model inspected from here on. It selects the per-slide tile
# embeddings under OUTPUT_DIR and the "{fm}_..." entries of params/components.
fm = "gigapath"
embeddings_path = os.path.join(OUTPUT_DIR, f"{fm}/tile_embeddings_sorted")

In [None]:
# For every label, load:
# - the per-slide ROI tile selections,
# - the full per-slide tile embeddings,
# - the embeddings of only the ROI tiles,
# - a 2-D PCA of those ROI embeddings.
roi_tiles = {}
embeds = {}
roi_embeds = {}
roi_components = {}

for label in NCLabel._member_names_:
    roi_path = os.path.join(roi_dir, f"{label}-roi.pkl")
    with open(roi_path, "rb") as f:
        label_tiles = pickle.load(f)
    roi_tiles[label] = label_tiles

    label_embeds = embeds[label] = {}
    label_roi_embeds = roi_embeds[label] = {}
    label_components = roi_components[label] = {}

    for slide_id in label_tiles:
        # Full tile embeddings of this annotated slide.
        # NOTE(review): these stay resident in ``embeds`` even though the global
        # appears unused after this cell; consider dropping it if memory is tight.
        with open(os.path.join(embeddings_path, f"{slide_id}.pkl"), "rb") as f:
            slide_embeds = pickle.load(f)
        label_embeds[slide_id] = slide_embeds

        # Keep only the tiles inside the ROI, stacked into one tensor.
        slide_roi_embeds = torch.stack(
            NearestCentroid._get_roi_embeds(slide_embeds, label_tiles[slide_id])
        )
        label_roi_embeds[slide_id] = slide_roi_embeds

        # NOTE(review): a fresh PCA is fit for every slide, so each slide's 2-D
        # coordinates are in their own basis. They are not comparable across slides
        # or with the centroid PCA stored in ``components``.
        pca = PCA(n_components=2)
        label_components[slide_id] = pca.fit_transform(slide_roi_embeds)

In [None]:
# One 2-D panel per label:
# - black: the per-slide PCA vectors of that label's ROI tiles,
# - red: the label's row of the "{fm}_gm_sep" centroid PCA.
# NOTE(review): the black vectors come from a separate PCA fit per slide and the
# red one from a PCA of the centroid matrix, so they live in different bases;
# compare them only loosely.
n_labels = len(NCLabel)
# Keys of ``params``/``components`` use an underscore ("{fm}_gm_sep"); the former
# "{fm}-gm_sep" lookup raised KeyError.
gm_sep_components = components[f"{fm}_gm_sep"]

# Symmetric limit shared by every panel. It is at least the 3-D plot's
# ``axis_lim`` and is widened so that no ROI or centroid arrow is clipped.
roi_extents = [
    float(np.abs(pcs).max())
    for label_components in roi_components.values()
    for pcs in label_components.values()
    if pcs.size
]
plot_lim = max(
    float(axis_lim[1]),
    float(np.abs(gm_sep_components[:, :2]).max()),
    *roi_extents,
)

fig, axs = plt.subplots(n_labels, 1, figsize=(6, 6 * n_labels), squeeze=False)
for label_enum in NCLabel:
    label = label_enum.name
    label_val = label_enum.value
    # Assumes NCLabel values are 0..n_labels-1, since they index both the panel and
    # the centroid row. TODO: confirm.
    ax = axs[label_val, 0]
    ax.set_xlim(-plot_lim, plot_lim)
    ax.set_ylim(-plot_lim, plot_lim)

    for slide_id, pcs in roi_components[label].items():
        # One arrow from the origin per ROI tile of this slide.
        origin = np.zeros((2, pcs.shape[0]))
        ax.quiver(
            *origin,
            pcs[:, 0],
            pcs[:, 1],
            angles="xy",
            scale_units="xy",
            scale=1,
            color="black",
        )

    ax.quiver(
        0,
        0,
        gm_sep_components[label_val, 0],
        gm_sep_components[label_val, 1],
        angles="xy",
        scale_units="xy",
        scale=1,
        color="red",
    )
    ax.set_title(label)

plt.tight_layout()
plt.show()

In [None]:
# Nearest-centroid model built from the unfiltered ("new") centroids of ``fm``.
model = NearestCentroid(NCLabel, params[f"{fm}_new"])

# Per label: dot-product scores of every ROI tile, concatenated over that label's
# slides (presumably one column per centroid, indexed by NCLabel value).
preds = {
    label: torch.cat(
        [
            model.predict(slide_embeds.float(), mode="dot_product")
            for slide_embeds in roi_embeds[label].values()
        ]
    )
    for label in NCLabel._member_names_
}

In [None]:
# Score every centroid against every centroid. Diagonal entry (i, i) is centroid i
# scored against itself (presumably its squared norm under a plain dot product).
# Rows 0 and 1 are the dermis and epidermis references used by sq_norm_filter.
centroids_dot = model.predict(model.centroids, mode="dot_product")
dermis_dot, epi_dot = (centroids_dot[idx, idx].item() for idx in (0, 1))

In [None]:
# Inspect the centroid-vs-centroid dot-product matrix (diagonal = each centroid's self score).
centroids_dot

In [None]:
# Distribution of score column 5 over every SCC ROI tile, shown as a step histogram.
# NOTE(review): column 5 is assumed to be the SCC centroid's column; confirm it
# matches NCLabel["scc"].value.
scc_scores = preds["scc"][:, 5]
histplot(scc_scores, bins=100, stat="proportion", element="step", fill=True)

**Idea:** build an ECDF of these dot-product scores for each label's ROIs to inform the `top_pct` filtering parameters (e.g., for SCC, filter out any tiles whose dot product exceeds 725), and use the histograms themselves to choose the `sq_norm` thresholds.

In [None]:
def filter(roi_embeds, filter_func, **kwargs):
    """Return the rows of ``roi_embeds`` (cast to float32) that ``filter_func`` keeps.

    The tiles are scored with the notebook-global nearest-centroid ``model``
    (``mode="dot_product"``). Only the first two score columns (the dermis and
    epidermis centroids) are passed to ``filter_func``, together with ``kwargs``.
    ``filter_func`` must return a boolean mask over the rows.

    NOTE: this name shadows the builtin ``filter`` for the rest of the notebook.
    """
    embeds_f32 = roi_embeds.float()
    normal_scores = model.predict(embeds_f32, mode="dot_product")[:, :2]
    keep_mask = filter_func(normal_scores, **kwargs)
    return embeds_f32[keep_mask]

In [None]:
# Ad hoc / arbitrary filtering methods. Each one receives the (n_tiles, 2)
# dermis / epidermis dot-product scores and returns a boolean keep-mask over tiles.
def sq_norm_filter(model_preds, dermis_thresh=0.7, epi_thresh=0.7, **kwargs):
    """Keep tiles scoring strictly below a fraction of each normal centroid's self-score.

    A tile survives only if both hold:
    - its dermis score (column 0) is ``< dermis_thresh * dermis_dot``;
    - its epidermis score (column 1) is ``< epi_thresh * epi_dot``.
    ``dermis_dot`` and ``epi_dot`` are notebook globals. Each tile is judged on its
    own, so the result does not depend on how tiles are grouped. Extra ``kwargs``
    are accepted and ignored so that every filter shares one call signature.
    """
    dermis_keep = model_preds[:, 0] < dermis_thresh * dermis_dot
    epi_keep = model_preds[:, 1] < epi_thresh * epi_dot
    return torch.logical_and(dermis_keep, epi_keep)


def top_pct_filter(model_preds, discard_top=0.7, **kwargs):
    """Discard the ``discard_top`` fraction of tiles most similar to each normal centroid.

    For each of score columns 0 and 1, the ``floor((1 - discard_top) * n_tiles)``
    lowest-scoring tiles are marked. A tile is kept only if it is marked in both
    columns.

    Args:
        model_preds: (n_tiles, >=2) score tensor; only columns 0 and 1 are used.
        discard_top: fraction in [0, 1] of highest-scoring tiles to drop per column.

    Returns:
        Boolean tensor of shape (n_tiles,) on the same device as ``model_preds``.

    Raises:
        ValueError: if ``discard_top`` is outside [0, 1].
    """
    if not 0.0 <= discard_top <= 1.0:
        raise ValueError(f"discard_top must be in [0, 1], got {discard_top!r}")
    n_tiles = model_preds.shape[0]
    # Round away float noise before flooring: (1 - 0.9) * 10 evaluates to
    # 0.9999999999999998, which a bare int() truncates to 0 instead of 1.
    n_keep = int(round((1.0 - discard_top) * n_tiles, 6))

    def lowest_k_mask(scores):
        # Build the mask on the scores' own device; a CPU arange + isin fails
        # when topk returns CUDA indices.
        mask = torch.zeros(n_tiles, dtype=torch.bool, device=scores.device)
        _, indices = torch.topk(scores, n_keep, largest=False)
        mask[indices] = True
        return mask

    return torch.logical_and(
        lowest_k_mask(model_preds[:, 0]), lowest_k_mask(model_preds[:, 1])
    )

In [None]:
def get_mixture_preds(model_preds, random_state=0):
    """Fit a 2-component Gaussian mixture and mask the points of its lower component.

    Args:
        model_preds: (n,) or (n, n_features) scores, as a tensor or array. 1-D input
            is treated as a single feature.
        random_state: seed passed to ``GaussianMixture`` so refits are reproducible.

    Returns:
        Boolean tensor of shape (n,). A point is True if it is assigned to the
        component whose mean, summed over features, is smallest. For the
        dermis/epidermis dot-product scores used here, that is presumably the
        cluster least similar to normal tissue. With fewer than two points no
        mixture can be fit, so every point is kept.
    """
    values = torch.as_tensor(model_preds).detach().cpu().numpy()
    if values.ndim == 1:
        values = values.reshape(-1, 1)
    n_points = values.shape[0]
    if n_points < 2:
        return torch.ones(n_points, dtype=torch.bool)
    mixture_model = GaussianMixture(n_components=2, random_state=random_state)
    assignments = mixture_model.fit_predict(values)
    # means_ is (n_components, n_features). Reduce over features before argmin:
    # a flat argmin over a 2-D means_ yields an index in [0, 2 * n_features),
    # which for multi-feature input can match no component and drop every point.
    component_of_interest = mixture_model.means_.sum(axis=1).argmin()
    return torch.from_numpy(assignments == component_of_interest)


def separate_mixture_filter(model_preds, **kwargs):
    """Keep tiles in the low component of both per-column mixtures (dermis AND epidermis)."""
    dermis_keep = get_mixture_preds(model_preds[:, 0])
    epi_keep = get_mixture_preds(model_preds[:, 1])
    return torch.logical_and(dermis_keep, epi_keep)


def combined_mixture_filter(model_preds, **kwargs):
    """Keep tiles in the low component of one mixture fit jointly on all score columns."""
    return get_mixture_preds(model_preds)

In [None]:
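# TODO: should filtering happen per slide or in aggregate? Does it matter?
# - sq_norm gives the same result either way: each tile is compared against fixed
#   thresholds, independently of the other tiles
# - the other methods (top_pct, Gaussian mixture) depend on the whole set of tiles
#   they receive, so they may differ; try both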

In [None]:
# Filtering on a per-slide basis.
# Method name -> filter function. ``filter_methods`` and ``filter_funcs`` are both
# derived from this one mapping so that the two lists cannot drift out of alignment.
filters_by_method = {
    "sq_norm": sq_norm_filter,
    "top_pct": top_pct_filter,
    "gaussian_mixture_separate": separate_mixture_filter,
}
filter_methods = list(filters_by_method)
filter_funcs = list(filters_by_method.values())


def per_slide_filtering():
    """Build one filtered centroid matrix per filtering method, filtering slide by slide.

    Rows 0-1 (the normal-tissue centroids) and the last row (artifact) are copied
    unchanged from the global ``model.centroids``. Every label in between (the
    cancerous labels) gets the mean of the kept tiles: the filter runs on each
    slide's ROI embeddings separately, and the survivors are pooled.

    Returns:
        dict mapping each name in ``filter_methods`` to its stacked centroid tensor.

    Raises:
        ValueError: if a filter keeps no tiles for some label, which would otherwise
            silently produce a NaN centroid.
    """
    centroids = {}
    for method, func in zip(filter_methods, filter_funcs):
        # The two normal-tissue centroids are not filtered.
        filtered_centroids = [*model.centroids[:2]]

        # Only create centroids for the cancerous labels.
        for label in NCLabel._member_names_[2:-1]:
            print(f"[{method}] {label}")
            kept = torch.cat(
                [filter(roi_embed, func) for roi_embed in roi_embeds[label].values()]
            )
            if kept.shape[0] == 0:
                raise ValueError(
                    f"{method} filtering kept no tiles for label {label!r}"
                )
            filtered_centroids.append(kept.mean(dim=0).float())

        # The artifact centroid is not filtered either.
        filtered_centroids.append(model.centroids[-1])
        centroids[method] = torch.stack(filtered_centroids, dim=0)

    return centroids

In [None]:
# Filtering on an aggregate basis.
def agg_filtering():
    """Build one filtered centroid matrix per filtering method, filtering all slides together.

    Rows 0-1 (the normal-tissue centroids) and the last row (artifact) are copied
    unchanged from the global ``model.centroids``. Every label in between (the
    cancerous labels) gets the mean of the tiles kept when the filter runs once on
    the concatenation of all of that label's per-slide ROI embeddings.

    Returns:
        dict mapping each name in ``filter_methods`` to its stacked centroid tensor.

    Raises:
        ValueError: if a filter keeps no tiles for some label, which would otherwise
            silently produce a NaN centroid.
    """
    centroids = {}
    for method, func in zip(filter_methods, filter_funcs):
        # The two normal-tissue centroids are not filtered.
        filtered_centroids = [*model.centroids[:2]]

        # Only create centroids for the cancerous labels.
        for label in NCLabel._member_names_[2:-1]:
            print(f"[{method}] {label}")
            # Local name chosen so it does not read like the global ``embeds``.
            label_embeds = torch.cat(list(roi_embeds[label].values()))
            kept = filter(label_embeds, func)
            if kept.shape[0] == 0:
                raise ValueError(
                    f"{method} filtering kept no tiles for label {label!r}"
                )
            filtered_centroids.append(kept.mean(dim=0).float())

        # The artifact centroid is not filtered either.
        filtered_centroids.append(model.centroids[-1])
        centroids[method] = torch.stack(filtered_centroids, dim=0)

    return centroids

In [None]:
# Persist one nearest-centroid model per filtering method (per-slide filtering).
centroids = per_slide_filtering()
for method in filter_methods:
    filtered_model = NearestCentroid(NCLabel, centroids[method])
    save_path = f"/opt/gpudata/skin-cancer/models/few-shot/new/filtered/{fm}_param-{method}-per_slide.pkl"
    filtered_model.save_model(save_path)

In [None]:
# Persist one nearest-centroid model per filtering method (aggregate filtering).
centroids = agg_filtering()
for method in filter_methods:
    filtered_model = NearestCentroid(NCLabel, centroids[method])
    save_path = f"/opt/gpudata/skin-cancer/models/few-shot/new/filtered/{fm}_param-{method}-agg.pkl"
    filtered_model.save_model(save_path)