### Imports

In [None]:
from pathlib import Path
import os
from tqdm import tqdm
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import pandas as pd
import numpy as np
import pickle as pkl

from omegaconf import OmegaConf as omcon


import torch

import sys
sys.path.append("..") # root of repo
sys.path.append("/nfs/home/rhotertj/Code/thesis/src")
import src.multimodal_transforms as mmt
from src.lit_models import LitModel, weighted_cross_entropy, unweighted_cross_entropy, twin_head_loss, Cache
from src.video_models import make_kinetics_mvit
from src.graph_models import GAT, PositionTransformer, GIN
from src.multimodal_models import MultiModalModel
from src.lit_data import LitMultiModalHblDataset, LitResampledHblDataset, collate_function_builder
from src.data import LabelDecoder
from src.utils import get_proportions_df


### Load model and dataset

In [None]:
ckpt_dir = Path("/nfs/home/rhotertj/Code/thesis/experiments/input_format/posiformer_indicator_shuffle_long")

# Pick the most recently written checkpoint. os.listdir() returns entries in
# arbitrary order, so taking [-1] of it did not reliably select the latest file.
ckpt_candidates = sorted(ckpt_dir.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
if not ckpt_candidates:
    raise FileNotFoundError(f"No .ckpt file found in {ckpt_dir}")
ckpt_file = ckpt_candidates[-1].name
print("Loading", ckpt_file)
config = omcon.load(ckpt_dir / "config.yaml")
cache = Cache()
cache.load(ckpt_dir / "val_results.pkl")

# remove pretrained model weights, they are overwritten anyways by lightning checkpoint
if config.model.name == "MultiModalModel":
    config.model.params.video_model_ckpt = ""
    config.model.params.graph_model_ckpt = ""

# NOTE(security): eval() resolves names stored in config.yaml (loss function,
# data module and model class). Only run this on experiment configs you trust.
loss_func = eval(config.loss_func)
label_decoder = LabelDecoder(config.num_classes)
# Analyse one instance at a time so per-sample results can be recorded.
config.data.params.batch_size = 1

lit_dataset = eval(config.data.name)(**config.data.params, label_mapping=label_decoder)
lit_dataset.setup("validate")
val_loader = lit_dataset.val_dataloader()
dataset = lit_dataset.data_val

# All analysis artefacts for this experiment go to ./analysis/<experiment name>.
base_path = Path(".") / "analysis" / ckpt_dir.name
base_path.mkdir(parents=True, exist_ok=True)

model = eval(config.model.name)(**config.model.params, num_classes=config.num_classes, batch_size=config.data.params.batch_size)
print(config.model.name)
lit_model = LitModel.load_from_checkpoint(
    ckpt_dir / ckpt_file,
    optimizer=None,
    scheduler=None,
    loss_func=loss_func,
    model=model,
    label_mapping=label_decoder,
    experiment_dir=base_path
)
lit_model = lit_model.eval()
if torch.cuda.is_available():
    print("GPU!")
    lit_model.cuda()
else:
    print("CPU")



In [None]:
# Preview the cached validation results loaded from the checkpoint directory.
pd.DataFrame(data=cache.data)

### Run prediction on validation data and save results

In [None]:
val_res_name = base_path / "val_results.pkl"
if not val_res_name.exists():
    # One dict per validation instance holding everything except the frames;
    # frames can be reloaded later from the dataset via the query index.
    val_results = []
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Pure inference: no_grad() avoids building an autograd graph for every
    # forward pass, which saves memory and time.
    with torch.no_grad():
        for instance in tqdm(val_loader):
            if "frames" in instance:
                instance["frames"] = instance["frames"].to(device)
            instance["positions"] = instance["positions"].to(device)
            pred = lit_model.forward(instance)
            if isinstance(pred, tuple):
                # Models with an offset head return (pred, pred_offset).
                pred, pred_offset = pred
                pred = pred.detach().cpu()
                pred_offset = pred_offset.detach().cpu()
                loss = loss_func(pred, pred_offset, instance["label"], instance["label_offset"])
            else:
                pred = pred.detach().cpu()
                loss = loss_func(pred, instance["label"], instance["label_offset"])
            confidences = pred.numpy()
            val_results.append({
                "query_idx": instance["query_idx"].item(),
                "frame_idx": instance["frame_idx"].item(),
                "match_number": instance["match_number"].item(),
                "label": instance["label"].item(),
                "label_offset": instance["label_offset"].item(),
                # NOTE(review): "label_idx" duplicates label_offset; confirm
                # whether a different field was intended here.
                "label_idx": instance["label_offset"].item(),
                "prediction": confidences.argmax(-1),
                "confidences": confidences,
                "loss": loss.item(),
            })

    with open(val_res_name, "wb") as f:
        pkl.dump(val_results, f)

else:
    # Only unpickle files this notebook produced itself (pickle is unsafe on untrusted data).
    with open(val_res_name, "rb") as f:
        val_results = pkl.load(f)

### Distribution over offsets

In [None]:
df = pd.DataFrame(cache.data)
# A row is correct when the predicted class equals the ground-truth class.
is_correct = df["ground_truths"] == df["predictions"]
df["correct"] = is_correct
# Drop rows whose ground truth is class 0 (presumably background — confirm).
background_mask = df["ground_truths"] == 0
df_nobg = df[~background_mask]
# Presumably makes dataset items include video frames for the animations below.
dataset.load_frames = True


In [None]:
# Count correct vs. incorrect predictions per label offset (class 0 rows excluded).
offset_ax = sns.countplot(data=df_nobg, x="label_offsets", hue="correct", width=0.8)
offset_ax.figure.savefig(base_path / "offsets.png")

### Confusion per class: true positives, false positives and false negatives

In [None]:
def renew_figure(n=None):
    """Close stale figures and create a fresh 2 x ``n`` grid of axes.

    Row 0 holds video frames and row 1 holds positions, one column per sample.

    Args:
        n: Number of columns. When omitted, falls back to the notebook-global
            ``n`` (the original zero-argument behaviour).

    Returns:
        ``(fig, axis)``; ``axis`` is always 2-D (``squeeze=False``), so
        ``axis[row, col]`` indexing also works when ``n == 1``.
    """
    if n is None:
        # Backward compatibility: the original version read the global ``n``.
        n = globals()["n"]
    # Each call used to leave its figure open; close previous ones to avoid
    # accumulating figures (and their frame data) across many animations.
    plt.close("all")
    fig, axis = plt.subplots(2, n, figsize=(16, 8), squeeze=False)
    return fig, axis


def _hide_ticks(ax):
    """Remove tick marks and tick labels from ``ax``."""
    ax.tick_params(
        top=False,
        bottom=False,
        left=False,
        right=False,
        labelleft=False,
        labelbottom=False
    )


def animate_samples(dataset, df, n):
    """Animate up to ``n`` randomly drawn rows of ``df`` side by side.

    For every drawn row, the dataset item at the row's index label is loaded.
    The top row of the figure shows its video frames, and the bottom row a
    scatter plot of its positions coloured by the first position channel.
    Each column is titled with that row's own ground truth and prediction.

    Args:
        dataset: Indexable; ``dataset[i]["frames"]`` must support ``.numpy()``
            and be laid out (C, T, H, W) (see the einsum below), and
            ``dataset[i]["positions"]`` must provide ``as_TNC``.
        df: DataFrame whose index labels are dataset indices, with
            ``ground_truths`` and ``predictions`` columns.
        n: Maximum number of samples. If ``df`` has fewer rows, all are used.

    Returns:
        ``matplotlib.animation.ArtistAnimation`` with one frame per time step.

    Raises:
        ValueError: If ``df`` is empty.
    """
    if len(df) == 0:
        raise ValueError("animate_samples needs at least one row to animate")
    # df.sample(n) raises when n > len(df), e.g. a class with few false positives.
    samples = df.sample(min(n, len(df)))
    fig, axis = renew_figure(n)

    sample_frames = []
    sample_positions = []
    for idx in samples.index:
        item = dataset[idx]  # load once; frames and positions come from the same item
        sample_frames.append(item["frames"].numpy())  # cthw
        sample_positions.append(item["positions"].as_TNC(normalize=False))
    # (n, C, T, H, W) -> (n, T, H, W, C) so imshow receives channel-last images.
    sample_frames = np.einsum("ncthw->nthwc", np.stack(sample_frames))

    # Titles and tick styling are static, so set them once per column. Use each
    # column's own row (the original titled every panel with the last sample).
    for col, (_, row) in enumerate(samples.iterrows()):
        axis[0, col].set_title(f"label: {row['ground_truths']}, pred: {row['predictions']}", fontsize='small', loc='left')
        _hide_ticks(axis[0, col])
        _hide_ticks(axis[1, col])
    # Blank out columns left over when fewer than n samples were available.
    for col in range(len(samples), axis.shape[1]):
        axis[0, col].axis("off")
        axis[1, col].axis("off")

    # Position channel 0 is used as an index into this list (values 0..2).
    colors = ["red", "green", "blue"]
    images = []
    for t in range(sample_frames.shape[1]):
        artists_at_t = []
        for col in range(sample_frames.shape[0]):
            artists_at_t.append(axis[0, col].imshow(sample_frames[col, t]))
            positions_t = sample_positions[col][t]
            artists_at_t.append(axis[1, col].scatter(
                x=positions_t[:, 1],
                y=positions_t[:, 2],
                c=[colors[int(m)] for m in positions_t[:, 0]]
            ))
        images.append(artists_at_t)

    return animation.ArtistAnimation(fig, images, interval=50, blit=False, repeat_delay=1000)

# Samples per animation. Kept as a module-level name because renew_figure()
# falls back to a global ``n`` when called without arguments.
n = 5
for label_int, c in enumerate(label_decoder.get_classnames()):
    is_true = df["ground_truths"] == label_int
    is_pred = df["predictions"] == label_int
    subsets = {
        "tp": df[is_true & is_pred],
        "fp": df[~is_true & is_pred],
        "fn": df[is_true & ~is_pred],
    }
    for kind, subset in subsets.items():
        # A class may have no (or fewer than n) TPs/FPs/FNs; sampling more rows
        # than exist would raise, so skip empty subsets and clamp the count.
        if subset.empty:
            print(f"Skipping {c}_{kind}: no samples")
            continue
        animate_samples(dataset, subset, min(n, len(subset))).save(base_path / f"{c}_{kind}.gif", fps=10)
        plt.close("all")  # free the figure once its GIF is written


### Lowest- and highest-loss predictions per class

In [None]:
# Samples per animation (module-level: renew_figure() may read the global ``n``).
n = 5
for label_int, c in enumerate(label_decoder.get_classnames()):
    c_df = df[df["ground_truths"] == label_int]
    if c_df.empty:
        print(f"Skipping {c}: no validation samples with this label")
        continue
    # Classes with fewer than n samples would make the sampling step raise.
    k = min(n, len(c_df))
    # k lowest-loss and k highest-loss samples whose ground truth is this class.
    topk = c_df.sort_values(by="loss", ascending=True).head(k)
    worstk = c_df.sort_values(by="loss", ascending=False).head(k)

    animate_samples(dataset, topk, k).save(base_path / f"{c}_topk.gif", fps=10)
    plt.close("all")  # free the figure once its GIF is written
    animate_samples(dataset, worstk, k).save(base_path / f"{c}_worstk.gif", fps=10)
    plt.close("all")
