### Imports

In [None]:
from pathlib import Path
import os
from tqdm import tqdm
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import pandas as pd
import numpy as np
import pickle as pkl
from sklearn.manifold import TSNE
from omegaconf import OmegaConf as omcon
from sklearn.metrics import f1_score, precision_recall_curve, confusion_matrix, ConfusionMatrixDisplay
from torchmetrics.classification import MulticlassAveragePrecision

import torch

import sys
sys.path.append("..") # root of repo
sys.path.append("/nfs/home/rhotertj/Code/thesis/src")
import src.multimodal_transforms as mmt
from src.lit_models import LitModel, weighted_cross_entropy, unweighted_cross_entropy, twin_head_loss, Cache
from src.video_models import make_kinetics_mvit
from src.graph_models import GAT, PositionTransformer, GIN
from src.multimodal_models import MultiModalModel
from src.lit_data import LitMultiModalHblDataset, LitResampledHblDataset, collate_function_builder
from src.data import LabelDecoder
from src.utils import get_proportions_df


### Load model and dataset

In [None]:
# Experiment directory holding the Lightning checkpoint(s), the training
# config and the cached test results. Swap the path to analyse another run.
ckpt_dir = Path("/nfs/home/rhotertj/Code/thesis/experiments/input_format/posiformer_indicator_shuffle_long")
# ckpt_dir = Path("/nfs/home/rhotertj/Code/thesis/experiments/architecture/mvit_twin")

# os.listdir() returns entries in arbitrary order, so sort before taking the
# last one to make the choice reproducible across machines/filesystems.
# NOTE(review): the order is lexicographic ("epoch=9" sorts after "epoch=10");
# confirm this matches the checkpoint naming scheme in use.
ckpt_files = sorted(f for f in os.listdir(ckpt_dir) if f.endswith(".ckpt"))
if not ckpt_files:
    raise FileNotFoundError(f"No .ckpt file found in {ckpt_dir}")
ckpt_file = ckpt_files[-1]
print("Loading", ckpt_file)
config = omcon.load(ckpt_dir / "config.yaml")
cache = Cache()
cache.load(ckpt_dir / "test_results.pkl")

# remove pretrained model weights, they are overwritten anyways by lightning checkpoint
if config.model.name == "MultiModalModel":
    config.model.params.video_model_ckpt = ""
    config.model.params.graph_model_ckpt = ""

# SECURITY NOTE: eval() executes whatever string the config contains. That is
# acceptable only because config.yaml is written by our own training runs;
# never point this notebook at a config from an untrusted source.
loss_func = eval(config.loss_func)
label_decoder = LabelDecoder(config.num_classes)
# Analyse one sample at a time and make sure video frames are loaded.
config.data.params.batch_size = 1
config.data.params.load_frames = True

lit_dataset = eval(config.data.name)(**config.data.params, label_mapping=label_decoder)
lit_dataset.setup("test")
# Despite its name, this loader iterates the *test* split.
val_loader = lit_dataset.test_dataloader()
dataset = lit_dataset.data_test

# All artefacts of this analysis (pickles, plots, gifs) are written here.
base_path = Path(".") / "analysis" / ckpt_dir.name
base_path.mkdir(parents=True, exist_ok=True)

model = eval(config.model.name)(**config.model.params,  num_classes=config.num_classes, batch_size=config.data.params.batch_size)
print(config.model.name)
lit_model = LitModel.load_from_checkpoint(
    ckpt_dir / ckpt_file,
    optimizer=None,
    scheduler=None,
    loss_func=loss_func,
    model=model,
    label_mapping=label_decoder,
    experiment_dir=base_path
)
lit_model = lit_model.eval()
if torch.cuda.is_available():
    print("GPU!")
    lit_model.cuda()
else:
    print("CPU")



### Run prediction on the evaluation data (the test split, named `val_*` below) and save results

In [None]:
# Per-sample predictions are expensive to compute, so they are cached on disk
# and only recomputed when the pickle does not exist yet.
val_res_name = base_path / "val_results.pkl"
if not val_res_name.exists():

    val_results = [] # list of dicts with all info but frames, we can load them later via query idx!
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Pure inference: disabling autograd avoids building a graph per sample.
    with torch.no_grad():
        for instance in tqdm(val_loader):
            if "frames" in instance:
                instance["frames"] = instance["frames"].to(device)
            instance["positions"] = instance["positions"].to(device)
            pred = lit_model.forward(instance)
            # Models with two heads return (class logits, offset prediction).
            if isinstance(pred, tuple):
                pred, pred_offset = pred
                pred = pred.detach().cpu()
                pred_offset = pred_offset.detach().cpu()
                loss = loss_func(pred, pred_offset, instance["label"], instance["label_offset"])
            else:
                pred = pred.detach().cpu()
                loss = loss_func(pred, instance["label"], instance["label_offset"])
            # pred already lives on the CPU and is detached at this point.
            pred_np = pred.numpy()
            res = {
                "query_idx" : instance["query_idx"].item(),
                "frame_idx" : instance["frame_idx"].item(),
                "match_number" : instance["match_number"].item(),
                "label" : instance["label"].item(),
                "label_offset" : instance["label_offset"].item(),
                # NOTE(review): this duplicates "label_offset"; it looks like it
                # was meant to read a separate label index field - confirm
                # against the dataset's batch schema before relying on it.
                "label_idx" : instance["label_offset"].item(),
                "prediction" : pred_np.argmax(-1),
                "confidences" : pred_np,
                "loss" : loss.item()
            }
            val_results.append(res)

    with open(val_res_name, "wb") as f:
        pkl.dump(val_results, f)

else:
    # Only ever load pickles produced by this notebook; unpickling runs
    # arbitrary code from the file.
    with open(val_res_name, "rb") as f:
        val_results = pkl.load(f)

### Metrics

In [None]:
# Per-class F1 / average precision and a row-normalised confusion matrix,
# computed from the cached test results.
gt = cache.data["ground_truths"]
preds = cache.data["predictions"]
confs = torch.stack(cache.data["confidences"])

# Display names, ordered by class id; their count drives every metric below.
labels = ["Background", "Pass", "Shot"]
n_classes = len(labels)

f1 = f1_score(gt, preds, average=None)
print("f1:", f1)
ap_metric = MulticlassAveragePrecision(num_classes=n_classes, average="none")
gt_tensor = torch.tensor(gt)
print("AP:", ap_metric(confs, gt_tensor))

# Pin the label set so the matrix is always n_classes x n_classes, even when
# a class never occurs in the ground truth or the predictions.
confmat = confusion_matrix(gt, preds, labels=list(range(n_classes))).astype(np.int64)

# Normalise each row by its ground-truth count; rows without any sample stay
# at 0 instead of turning into NaN through a division by zero.
row_totals = confmat.sum(axis=1, keepdims=True)
confmat_frac = np.divide(
    confmat,
    row_totals,
    out=np.zeros(confmat.shape, dtype=np.float64),
    where=row_totals > 0,
)

# Cell colour encodes the per-row fraction, the annotation shows raw counts.
fig, ax = plt.subplots(figsize=(8,6), dpi=300)
sns.heatmap(confmat_frac, annot=confmat, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels, cbar=True, square=True, ax=ax)
ax.set(ylabel="Ground Truth", xlabel="Predicted")

In [None]:
from wandb import util
from wandb.plots.utils import test_missing, test_types
import wandb


def pr_curve(
    y_true=None,
    y_probas=None,
    labels=None,
    classes_to_plot=None,
    interp_size=21,
):
    """Compute the tradeoff between precision and recall for different thresholds.

    Local variant of the W&B PR-curve helper that returns the interpolated
    table instead of logging it.

    A high area under the curve represents both high recall and high precision, where
    high precision relates to a low false positive rate, and high recall relates to a
    low false negative rate. High scores for both show that the classifier is returning
    accurate results (high precision), and returning a majority of all positive results
    (high recall). PR curve is useful when the classes are very imbalanced.

    Arguments:
        y_true (arr): true sparse labels
        y_probas (arr): Target scores, can either be probability estimates,
            confidence values, or non-thresholded measure of decisions.
            shape: (*y_true.shape, num_classes)
        labels (list): Named labels for target variable (y). Makes plots easier to read
            by replacing target values with corresponding index. For example labels =
            ['dog', 'cat', 'owl'] all 0s are replaced by 'dog', 1s by 'cat'.
        classes_to_plot (list): unique values of y_true to include in the plot
        interp_size (int): the recall values will be fixed to `interp_size` points
            uniform on [0, 1] and the precision will be interpolated for these recall
            values.

    Returns:
        A pandas DataFrame with the columns "class", "precision" and "recall"
        (values rounded to 3 decimals), one row per class and interpolated
        recall point. Returns None when `test_missing` or `test_types` reject
        the inputs.

    Example:
        ```
        pr_df = pr_curve(y_true, y_probas, labels)
        ```
    """
    np = util.get_module(
        "numpy",
        required="roc requires the numpy library, install with `pip install numpy`",
    )
    pd = util.get_module(
        "pandas",
        required="roc requires the pandas library, install with `pip install pandas`",
    )
    sklearn_metrics = util.get_module(
        "sklearn.metrics",
        "roc requires the scikit library, install with `pip install scikit-learn`",
    )
    sklearn_utils = util.get_module(
        "sklearn.utils",
        "roc requires the scikit library, install with `pip install scikit-learn`",
    )

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    if not test_missing(y_true=y_true, y_probas=y_probas):
        return
    if not test_types(y_true=y_true, y_probas=y_probas):
        return

    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    precision = dict()
    # Recall grid in descending order, from 1 down to 0.
    interp_recall = np.linspace(0, 1, interp_size)[::-1]
    indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
    for i in indices_to_plot:
        # Integer class ids are translated to their display name, anything
        # else (e.g. string classes) is used verbatim.
        if labels is not None and isinstance(classes[i], (int, np.integer)):
            class_label = labels[classes[i]]
        else:
            class_label = classes[i]

        # NOTE(review): the score column is picked by the class's position in
        # `classes`, not by its id; the two only coincide when every class id
        # 0..num_classes-1 occurs in y_true - confirm for your data.
        cur_precision, cur_recall, _ = sklearn_metrics.precision_recall_curve(
            y_true, y_probas[:, i], pos_label=classes[i]
        )
        # smooth the precision (monotonically increasing) via a running maximum
        cur_precision = np.maximum.accumulate(cur_precision)

        # reverse order so that recall in ascending
        cur_precision = cur_precision[::-1]
        cur_recall = cur_recall[::-1]
        # For each grid point take the first curve point reaching that recall.
        indices = np.searchsorted(cur_recall, interp_recall, side="left")
        precision[class_label] = cur_precision[indices]

    df = pd.DataFrame(
        {
            "class": np.hstack([[k] * len(v) for k, v in precision.items()]),
            "precision": np.hstack(list(precision.values())),
            "recall": np.tile(interp_recall, len(precision)),
        }
    )
    df = df.round(3)

    if len(df) > wandb.Table.MAX_ROWS:
        wandb.termwarn(
            "wandb uses only %d data points to create the plots." % wandb.Table.MAX_ROWS
        )
        # different sampling could be applied, possibly to ensure endpoints are kept
        df = sklearn_utils.resample(
            df,
            replace=False,
            n_samples=wandb.Table.MAX_ROWS,
            random_state=42,
            stratify=df["class"],
        ).sort_values(["precision", "recall", "class"])

    return df

In [None]:
# Interpolated precision/recall table, one line per class in the plot below.
pr_df = pr_curve(gt, confs, labels=["Background", "Pass", "Shot"])

fig, ax = plt.subplots(figsize=(10,6), dpi=300)
ax = sns.lineplot(
    data=pr_df,
    x="recall",
    y="precision",
    hue="class",
    markers=True,
    dashes=False,
    palette="Set2",
    linestyle="--",
    ax=ax,
)
# Replace the column names seaborn uses as axis labels with capitalised ones.
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.grid()
plt.show()

### Distribution over offsets

In [None]:
# The plots below need the video frames, so make sure the dataset loads them.
dataset.load_frames = True

# One row per cached test sample; "correct" flags matching label/prediction.
df = pd.DataFrame(cache.data)
df["correct"] = df["ground_truths"] == df["predictions"]

# Everything that is not class 0.
is_foreground = df["ground_truths"] != 0
df_nobg = df[is_foreground]


In [None]:
# Count of correct vs. incorrect predictions per label offset, written to
# offsets.png inside the analysis directory.
f, ax = plt.subplots(figsize=(20, 8))
sns_plot = sns.countplot(
    data=df_nobg,
    x="label_offsets",
    hue="correct",
    width=0.8,
    ax=ax,
)
# Tick labels run from 0 to 15.5 in steps of 0.5.
offset_ticks = list(np.arange(0, 16, 0.5))
sns_plot.set_xticklabels(offset_ticks)
sns_plot.set_xlabel("Position of annotated frame in window")
sns_plot.set_ylabel("Count")
legend = sns_plot.get_legend()
legend.set_title("Correct Prediction")
f.savefig(base_path / "offsets.png", dpi=300)

### Confusion per class

In [None]:
def renew_figure(n_cols=None):
    """Clear the current pyplot figure and create a fresh 2-row grid of axes.

    Arguments:
        n_cols (int): number of columns. When omitted, the module-level ``n``
            is used, which is how this helper was originally driven.

    Returns:
        Tuple ``(fig, axis)``; ``axis`` is always a 2-D array of shape
        ``(2, n_cols)``, also for a single column.
    """
    if n_cols is None:
        n_cols = n  # legacy fallback: global set by the calling cell
    plt.clf()
    fig, axis = plt.subplots(2, n_cols, figsize=(16,8), squeeze=False)
    return fig, axis

def animate_samples(dataset, df, n):
    """Animate up to ``n`` randomly drawn rows of ``df`` side by side.

    The top row shows the video frames of each sample, the bottom row a
    scatter plot of its positions, one animation step per timestep.

    Arguments:
        dataset: indexable dataset; ``dataset[idx]`` must provide "frames"
            (converted with ``.numpy()``, treated as channels-time-height-width)
            and "positions" (converted with ``.as_TNC(normalize=False)``).
        df (pandas.DataFrame): rows to draw from. The DataFrame index is used
            as dataset index; needs "ground_truths" and "predictions" columns.
        n (int): number of samples to show. Fewer are shown when ``df`` has
            fewer rows.

    Returns:
        matplotlib.animation.ArtistAnimation for the created figure.

    Raises:
        ValueError: if there is nothing to draw (empty ``df`` or ``n < 1``).
    """
    n_samples = min(n, len(df))
    if n_samples < 1:
        raise ValueError("animate_samples needs at least one sample to draw")

    picked = df.sample(n_samples)
    # Size the grid from the actual sample count, not from a global.
    fig, axis = renew_figure(n_samples)

    query_idx = list(picked.index)
    truths = picked["ground_truths"].tolist()
    predictions = picked["predictions"].tolist()

    sample_frames = []
    sample_positions = []
    for idx in query_idx:
        # Fetch once so frames and positions come from the same dataset item.
        item = dataset[idx]
        sample_frames.append(item["frames"].numpy()) # cthw
        sample_positions.append(item["positions"].as_TNC(normalize=False))

    sample_frames = np.stack(sample_frames) # ncthw
    # Channels last, as imshow expects.
    sample_frames = np.einsum("ncthw->nthwc", sample_frames)

    # Channel 0 of the positions selects the marker colour, channels 1 and 2
    # are plotted as x and y.
    # NOTE(review): presumably channel 0 is a team/agent-type index - confirm.
    colors = ["red", "green", "blue"]

    images = []
    for t in range(sample_frames.shape[1]):
        images_at_timestep = []
        for i in range(n_samples):
            # video
            im = axis[0, i].imshow(sample_frames[i, t])
            # positions
            positions = sample_positions[i]
            scatterplot = axis[1, i].scatter(
                x=positions[t, :, 1],
                y=positions[t, :, 2],
                c=[colors[int(m)] for m in positions[t, :, 0]]
            )
            images_at_timestep.append(im)
            images_at_timestep.append(scatterplot)
        images.append(images_at_timestep)

    # Static decoration is applied once instead of once per timestep, and each
    # column is titled with its own sample.
    for i in range(n_samples):
        for panel in axis[:, i]:
            panel.tick_params(
                top=False,
                bottom=False,
                left=False,
                right=False,
                labelleft=False,
                labelbottom=False
            )
        axis[0, i].set_title(f"label: {truths[i]}, pred: {predictions[i]}, {query_idx[i]}", fontsize='small', loc='left')

    ani = animation.ArtistAnimation(fig, images, interval=50, blit=False, repeat_delay=1000)
    return ani

# For every class, save animations of randomly drawn true positives, false
# positives and false negatives.
for label_int, c in enumerate(label_decoder.get_classnames()):
    n = 5

    is_label = df["ground_truths"] == label_int
    is_predicted = df["predictions"] == label_int

    tp = df[is_label & is_predicted]
    fp = df[~is_label & is_predicted]
    fn = df[is_label & ~is_predicted]

    outputs = (
        (tp, f"{c}_tp.gif"),
        (fp, f"{c}_fp.gif"),
        (fn, f"{c}_fn.gif"),
    )
    for subset, filename in outputs:
        animate_samples(dataset, subset, 5).save(base_path / filename, fps=10)


### Top-k loss predictions

In [None]:
# For every class, animate the n samples with the lowest loss ("topk") and
# the n samples with the highest loss ("worstk").
for label_int, c in enumerate(label_decoder.get_classnames()):
    n = 5
    c_df = df.loc[df["ground_truths"] == label_int]

    rankings = (
        (True, f"{c}_topk.gif"),
        (False, f"{c}_worstk.gif"),
    )
    for ascending, filename in rankings:
        selection = c_df.sort_values(by="loss", ascending=ascending).head(n)
        animate_samples(dataset, selection, n).save(base_path / filename, fps=10)


### tSNE

In [None]:
# Collect pooled representations plus label/prediction names for the first n
# samples of the loader; they are embedded with t-SNE in the next cell.
n = 2000
representations = []
labels = []
preds = []
# The model was moved to the GPU when one is available, so inputs must follow.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# maybe do this with multimodal model and use the representations from there?
with torch.no_grad():  # inference only, no autograd graph needed
    for i, instance in enumerate(tqdm(val_loader)):
        if i == n:
            break
        positions = instance["positions"].to(device)
        # First pass with the pooling head yields the representation ...
        lit_model.model.head_type = "pool"
        r = lit_model.model(positions)
        # ... second pass with the classification head yields the prediction.
        lit_model.model.head_type = "classify"
        pred = lit_model.model(positions)
        # Keep results on the CPU so they can be converted to numpy below.
        representations.append(r.detach().cpu())
        labels.append(label_decoder.class_names[instance["label"].item()])
        preds.append(label_decoder.class_names[pred.argmax().item()])

representations = torch.stack(representations).squeeze(1).numpy()

In [None]:
# Embed the collected representations with t-SNE (default settings) and print
# the array shape before and after.
tsne = TSNE()
print(representations.shape)
embedded_representations = tsne.fit_transform(representations)
print(embedded_representations.shape)

In [None]:
# One row per embedded sample: 2-D coordinates plus class names.
tsne_df = pd.DataFrame(embedded_representations[:, :2], columns=["x", "y"])
tsne_df["label"] = labels
tsne_df["prediction"] = preds
# TODO: ground truth color, prediction as shape
sns.scatterplot(data=tsne_df, x="x", y="y", style="prediction", hue="label")