### Imports

In [None]:
import sys
sys.path.append("..") # root of repo
sys.path.append("../src/")
import os
import numpy as np
import pandas as pd
import pickle as pkl
from collections import Counter, defaultdict
import os.path as osp
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from src.data.datasets import MultiModalHblDataset, ResampledHblDataset
from src.data.labels import LabelDecoder
from src.utils import array2gif, draw_trajectory
from src.metrics import postprocess_predictions, average_mAP, postprocess_peaks_only
import torchvision
import torch
import itertools
import torchvision.transforms as t
import multimodal_transforms as mmt
import pytorchvideo.transforms as ptvt
from lit_data import collate_function_builder

from utils import * # debug import

### Plot class occurrences

In [None]:
# Read the training-set JSONL file (one JSON record per line) and show a
# handful of random rows as a sanity check.
train_index_path = "dataset_train_sql=16_sr=2_nooverlap.jsonl"
df = pd.read_json(train_index_path, lines=True)
df.sample(n=5)
# TODO np diff over idx with actions at sr=1 gives us frame distance between actions, bar chart for intervals

In [None]:
# TODO Get frames and matches
# Build the validation dataset without loading frames; only the per-match
# event tables (`data.event_dfs`) are read below.
data = ResampledHblDataset(
    meta_path="/nfs/home/rhotertj/datasets/hbl/meta30_valid.csv",
    idx_to_frame="/nfs/home/rhotertj/datasets/hbl/resampled/balanced/True/overlap/True/sql_sr/16x2/mode/matches/meta30_val.jsonl",
    seq_len=16,
    sampling_rate=2,
    load_frames=False,
    label_mapping=LabelDecoder(3)
)
ld = LabelDecoder(3)
shot_frames = []
pass_frames = []
# The loop variable is intentionally not named `df`: that name holds a
# notebook-level DataFrame and must not be overwritten by this loop.
for event_df in data.event_dfs:
    frames = event_df.index.to_numpy()
    label = event_df.labels.apply(lambda x: ld(x)).to_numpy()
    # Presumably decoded label 2 == shot and 1 == pass (per the variable
    # names) -- TODO confirm against LabelDecoder.
    shot_frames.append(frames[label == 2])
    pass_frames.append(frames[label == 1])
# np.concatenate raises on an empty list, so fall back to empty arrays when
# the dataset exposes no event tables.
shot_frames = np.concatenate(shot_frames) if shot_frames else np.array([], dtype=int)
pass_frames = np.concatenate(pass_frames) if pass_frames else np.array([], dtype=int)
print(len(shot_frames), len(pass_frames))


### Debugging idx -> Frame index mapping

In [None]:
# Toy example for the idx -> frame mapping: find every start position from
# which a full window (`sql` samples taken every `rate`-th frame) consists
# of available frames only.
sql = 4
rate = 2
hql = sql // 2
sr = sql * rate  # raw frames spanned by one window
hr = hql * rate  # raw frames spanned by half a window
kernel = np.ones(sr)

# 13 available frames followed by 2 unavailable ones.
availables = [True] * 13 + [False] * 2
pos = np.arange(len(availables))

# A full convolution with a box filter counts the available frames ending at
# each position; a window is valid where the count equals the filter length.
cv = np.convolve(availables, kernel)
print(f"{cv}")
window_ends = np.flatnonzero(cv == sr)
idxs = window_ends - (sr - 1)  # subtract filter length - 1 to get window starts
print("idx for valid sequences:", idxs)

# Resolve one query index to its centre frame and the sampled frame positions.
q_idx = 0
f_idx = idxs[q_idx] + hr
sequence = pos[slice(f_idx - hr, f_idx + hr, rate)]
print(f"Idx for sequence {q_idx}: {sequence}")

### Visualizing each class

In [None]:
# Validation dataset with frame loading enabled; used for the per-class
# visualisations.
val_meta_path = "/nfs/home/rhotertj/datasets/hbl/meta30_valid.csv"
val_index_path = "/nfs/home/rhotertj/datasets/hbl/resampled/balanced/True/overlap/True/sql_sr/16x2/mode/matches/meta30_val.jsonl"
dataset = ResampledHblDataset(
    meta_path=val_meta_path,
    idx_to_frame=val_index_path,
    seq_len=16,
    sampling_rate=2,
    load_frames=True,
    label_mapping=LabelDecoder(3),
)

In [None]:
# Export up to `n_per_class` random examples (frame GIF + trajectory PNG) for
# every shot/pass label combination that occurs in the dataset index.
os.makedirs("../img/classes", exist_ok=True)
n_per_class = 3
shots = list(range(0, 9))
passes = ['A', 'B', 'C', 'D', 'E', 'X', 'O']
df = dataset.idx_to_frame_number
for s, p in itertools.product(shots, passes):
    events = df[(df['shot'] == s) & (df['pass'] == p)]
    if events.empty:
        continue
    # DataFrame.sample (without replacement) raises ValueError when asked for
    # more rows than exist, so cap the sample size for rare combinations.
    events = events.sample(min(n_per_class, len(events)))
    for i, idx in enumerate(events.index):
        instance = dataset[idx]
        fname = f"../img/classes/{s}_{p}_{i}"
        array2gif(instance["frames"], fname + ".gif", 10)
        # savefig writes the current matplotlib figure -- presumably the one
        # created by draw_trajectory.
        draw_trajectory(instance["positions"])
        plt.savefig(fname + ".png")
    

### Calculating Mean and Standard Deviation

In [None]:
# Compute per-frame, per-channel mean and standard deviation of the pixel
# values (scaled to [0, 1]) by stepping through the dataset one full
# sequence at a time.
sql = 16
sr = 1
dataset_img = MultiModalHblDataset(
    meta_path="/nfs/home/rhotertj/datasets/hbl/meta3d.csv",
    seq_len=sql,
    sampling_rate=sr,
    load_frames=True
)
print(len(dataset_img))
n = 0
# Collect statistics in lists instead of a preallocated
# (len(dataset_img) + sql, 3) array: rows of such an array that are never
# written stay zero and would bias the averages computed from `means`/`stds`.
frame_means = []
frame_stds = []

for i in tqdm(range(0, len(dataset_img), sql)):
    frames = torch.tensor(dataset_img[i]["frames"]) / 255

    for frame in frames:
        # Reduce over dims 1 and 2 of each frame; assumes frames are
        # (channel, height, width) with 3 channels -- TODO confirm.
        frame_means.append(frame.mean([1, 2]).numpy())
        frame_stds.append(frame.std([1, 2]).numpy())
        n += 1

# One row per processed frame; reshape keeps a (0, 3) shape for empty input.
means = np.array(frame_means, dtype=np.float64).reshape(-1, 3)
stds = np.array(frame_stds, dtype=np.float64).reshape(-1, 3)
print(n)
print(means.shape)
# Maybe we want to calculate this per match instead of over the whole dataset


In [None]:
# Persist the per-frame statistics to the working directory.
for out_name, stats in (("means.npy", means), ("std.npy", stds)):
    np.save(out_name, stats)

In [None]:
# Average the per-frame statistics over all rows (one value per column).
print(means.mean(axis=0), stds.mean(axis=0))

### Test plotting

In [None]:
# Dataset without frame loading; only the positions are needed for the
# trajectory plots.
plot_dataset_kwargs = {
    "meta_path": "/nfs/home/rhotertj/datasets/hbl/meta3d.csv",
    "seq_len": 16,
    "sampling_rate": 2,
    "load_frames": False,
}
dataset_img = MultiModalHblDataset(**plot_dataset_kwargs)

In [None]:
# Fetch one fixed sample and draw its positions.
sample_idx = 11345
instance = dataset_img[sample_idx]
draw_trajectory(instance["positions"])


In [None]:
# Mirror the sample's positions (horizontal only) and draw the result.
positions = instance["positions"]
poscon = positions.mirror_again(horizontal=True, vertical=False)
draw_trajectory(poscon)

### Test transforms and augmentation

In [None]:

# Transform pipelines used to inspect individual augmentations, followed by
# the training dataset and the collate functions they are tested with.
frame_size = (224, 224)
jitter_params = {"brightness": 0.2, "hue": .2, "contrast": 0.2, "saturation": 0.2}

# Colour jitter only.
jitter_steps = [
    mmt.FrameSequenceToTensor(),
    mmt.TimeFirst(),
    mmt.ColorJitter(**jitter_params),
    mmt.ChannelFirst(),
    mmt.Resize(size=frame_size),
]
transforms_jitter = t.Compose(jitter_steps)

# RandAugment pipeline; the augmentation itself is currently disabled.
randaugment_steps = [
    mmt.FrameSequenceToTensor(),
    mmt.TimeFirst(),
    mmt.Resize(size=frame_size),
    #ptvt.RandAugment(num_layers=3, prob=0.5, magnitude=5),
    mmt.ChannelFirst(),
]
transforms_randaugment = t.Compose(randaugment_steps)

# Translation only.
translate_steps = [
    mmt.FrameSequenceToTensor(),
    mmt.Resize(size=frame_size),
    mmt.Translate(),
]
transforms_translate = t.Compose(translate_steps)

# Baseline: tensor conversion and resize, no augmentation.
raw_steps = [
    mmt.FrameSequenceToTensor(),
    mmt.Resize(size=frame_size),
]
transforms_raw = t.Compose(raw_steps)

# Flip + colour jitter combined.
full_steps = [
    mmt.FrameSequenceToTensor(),
    mmt.RandomHorizontalFlipVideo(p=0.5),
    mmt.TimeFirst(),
    mmt.ColorJitter(**jitter_params),
    #ptvt.RandAugment(num_layers=3, prob=0.5, magnitude=5),
    mmt.ChannelFirst(),
    mmt.Resize(size=frame_size),
]
transforms_full = t.Compose(full_steps)

train_meta_path = "/nfs/home/rhotertj/datasets/hbl/meta30_train.csv"
train_index_path = "/nfs/home/rhotertj/datasets/hbl/resampled/balanced/True/overlap/True/sql_sr/16x2/mode/matches/upsampled/True/meta30_train.jsonl"
dataset_img = ResampledHblDataset(
    meta_path=train_meta_path,
    idx_to_frame=train_index_path,
    label_mapping=LabelDecoder(3),
    load_frames=True,
    seq_len=16,
    sampling_rate=2,
    transforms=transforms_raw,
)

# Two collate functions: one with MixVideo (cutmix probability 0) and a plain one.
mix_video = ptvt.MixVideo(num_classes=3, cutmix_alpha=0.8, cutmix_prob=0)
collate_mixvideo = collate_function_builder(epsilon=7, load_frames=True, mix_video=mix_video)
collate_fn = collate_function_builder(epsilon=7, load_frames=True)

In [5]:
# Fetch the same samples under every transform pipeline, collate them into
# batches and export GIFs / trajectory plots for visual comparison.
out_dir = "../img/transforms"
# Create the output directory up front, as the other export cells do.
os.makedirs(out_dir, exist_ok=True)

raw_instances = []
jitter_instances = []
randaug_instances = []
combined_instances = []
translate_instances = []
idxs = [13456,23574,98533,64378,22546,324567,243343,9632] #random

# (pipeline, target list) pairs. The order is kept from the original cell so
# that samples are drawn in the same sequence for each index.
pipelines = [
    (transforms_raw, raw_instances),
    (transforms_translate, translate_instances),
    (transforms_jitter, jitter_instances),
    (transforms_randaugment, randaug_instances),
    (transforms_full, combined_instances),
]
for idx in idxs:
    for transforms, bucket in pipelines:
        dataset_img.transforms = transforms
        bucket.append(dataset_img[idx])
# Restore the default pipeline on the shared dataset object.
dataset_img.transforms = transforms_raw

raw_batch = collate_fn(raw_instances)
mix_batch = collate_mixvideo(raw_instances)
jitter_batch = collate_fn(jitter_instances)
translate_batch = collate_fn(translate_instances)
combined_batch = collate_mixvideo(combined_instances)

batches = {
    "raw": raw_batch,
    "mixvideo": mix_batch,
    "jitter": jitter_batch,
    "translate": translate_batch,
    "combined": combined_batch,
}
# Positions are plotted from the instances that were actually collated.
# Re-fetching the sample from the dataset (as done previously) applied the raw
# pipeline again, so the "translate" plot never showed translated positions.
# Assumes the collate functions keep the input order -- TODO confirm.
plotted_instances = {"raw": raw_instances, "translate": translate_instances}

raw_pos = raw_batch["positions"]
for i in range(len(idxs)):
    for name, batch in batches.items():
        # Frames are scaled back to 0..255 and cast to uint8 for GIF export.
        frames = batch["frames"][i].mul(255).to(torch.uint8).numpy()
        array2gif(frames, f"{out_dir}/{name}_transforms_{i}.gif", 10)
        if name == "translate":
            print(raw_pos.ndata["positions"][0], batch["positions"].ndata["positions"][0])
        if name in plotted_instances:
            pos = plotted_instances[name][i]["positions"]
            fig = draw_trajectory(pos)
            plt.savefig(f"{out_dir}/{name}_transforms_positions_{i}.png")


### Visualizing each match

In [None]:
# Dataset with frame loading enabled, used for the per-match visualisation.
match_dataset_kwargs = {
    "meta_path": "/nfs/home/rhotertj/datasets/hbl/meta30.csv",
    "seq_len": 16,
    "sampling_rate": 2,
    "load_frames": True,
}
dataset_img = MultiModalHblDataset(**match_dataset_kwargs)

In [None]:
# Export `n_per_match` example sequences (frame GIF + trajectory PNG) for each
# entry of `dataset_img.index_tracker` except the last one.
n_per_match = 2
mult = 100
# Presumably index_tracker holds the first dataset index of every match --
# TODO confirm against MultiModalHblDataset.
for border in tqdm(dataset_img.index_tracker[:-1]):
    for j in range(n_per_match):
        # Offset of the sample relative to the border; also used as the file
        # name so that the name reflects the sample that was actually drawn
        # (previously the name used (j+8)*mult while the index used (j+2)*mult).
        offset = (j + 2) * mult
        instance = dataset_img[border + offset]
        mn = dataset_img.meta_df.iloc[instance["match_number"]]["match_id"].split(":")[-1]
        # Create exactly the directory that is written to below; the previous
        # absolute path only matched when run from one specific location.
        match_dir = f"../img/matches/{mn}"
        os.makedirs(match_dir, exist_ok=True)
        fname = f"{match_dir}/{offset}"
        array2gif(instance["frames"], fname + ".gif", 10)
        draw_trajectory(instance["positions"])
        plt.savefig(fname + ".png")


### Visualizing average mAP

In [None]:
# Load pickled validation results and derive (a) predictions ordered by a
# globally unique frame number and (b) de-duplicated ground-truth anchors.
val_res_name = "/nfs/home/rhotertj/Code/thesis/dataset/analysis/blooming-hill-271/val_results.pkl"
# val_res_name = "/nfs/home/rhotertj/Code/thesis/dataset/analysis/copper-bush-8/val_results.pkl"
# NOTE(security): unpickling can execute arbitrary code -- only load result
# files that come from a trusted source.
with open(val_res_name, "rb") as fh:
    val_results = pkl.load(fh)

df = pd.DataFrame(val_results)
confidences = np.concatenate(df.confidences.to_numpy())
frame_numbers = df.frame_idx.to_numpy()
match_numbers = df.match_number.to_numpy()
label_idx = df.label_idx.to_numpy()
labels = df.label.to_numpy()
label_offsets = df.label_offset.to_numpy()

# boost frame numbers per game: shift each match into its own decimal range
# (one order of magnitude above the largest frame number) so that frame
# numbers from different matches cannot collide.
max_frame_magnitude = len(str(frame_numbers.max()))
frame_offset = 10**(max_frame_magnitude + 1)
frame_numbers = frame_numbers + (frame_offset * match_numbers)

correct_order = np.argsort(frame_numbers)
reordered_frames = frame_numbers[correct_order]
confidences = confidences[correct_order]

# labelidx solution: keep the first occurrence of every ground-truth anchor.
gt_labels_ll = []
gt_anchors_ll = []
offset_ll = []
seen_anchors = set()  # O(1) duplicate check instead of scanning the list
for i, l_idx in enumerate(label_idx):
    # Rows with label_idx == -1 are skipped (presumably "no label" -- TODO confirm).
    if l_idx == -1:
        continue
    anchor = frame_offset * match_numbers[i] + l_idx
    if anchor in seen_anchors:
        continue
    seen_anchors.add(anchor)
    gt_labels_ll.append(labels[i])
    gt_anchors_ll.append(anchor)
    offset_ll.append(label_offsets[i])

gt_anchors_ll = np.array(gt_anchors_ll)
gt_labels_ll = np.array(gt_labels_ll)

# Sort ground truth by anchor frame; labels follow the same permutation.
correct_order = np.argsort(gt_anchors_ll)
gt_anchors_ll = gt_anchors_ll[correct_order]
gt_labels_ll = gt_labels_ll[correct_order]
gt_anchors = gt_anchors_ll
gt_labels = gt_labels_ll

In [None]:
# setup plots and helpers
fig, ax = plt.subplots(figsize=(20, 8))
lm = LabelDecoder(3)
# Query the class names once instead of on every conversion.
classnames = list(lm.get_classnames())


def int2class(i):
    """Return the class name stored at position `i` of the decoder's class names."""
    return classnames[i]


def class2int(c):
    """Return the position of class name `c` in the decoder's class names."""
    return classnames.index(c)


def _confidence_records(frames, confs, with_idx=False):
    """Build one record per (frame, class) pair for long-format plotting.

    `frames` and `confs` are iterated in lockstep; each element of `confs`
    is iterated to obtain one confidence per class. When `with_idx` is true,
    the position of the frame within `frames` is stored under "idx".
    """
    records = []
    for idx, (frame, class_confs) in enumerate(zip(frames, confs)):
        for class_int, conf in enumerate(class_confs):
            record = {"frame": frame, "type": int2class(class_int), "confidence": conf}
            if with_idx:
                record["idx"] = idx
            records.append(record)
    return records


# put predictions into long format (one row per frame and class)
pred_list = _confidence_records(reordered_frames, confidences)
# Explicit columns keep the filters below working when there are no rows.
pred_df = pd.DataFrame(pred_list, columns=["frame", "type", "confidence"])
pred_df.sort_values(by="frame", inplace=True)

start_frame = 30000 # 5000 - 8000
end_frame = 308000
palette = sns.color_palette("husl")
class_palette = {c: color for c, color in zip(classnames, palette[:3])}
plot_data = pred_df[(pred_df["frame"] > start_frame) & (pred_df["frame"] < end_frame)]
sns.scatterplot(data=plot_data, x="frame", y="confidence", hue="type", palette=class_palette, ax=ax)

# plot ground truth
plotted_frames_in_gt = np.where((gt_anchors < end_frame) & (gt_anchors > start_frame))[0]
for c, f in zip(gt_labels[plotted_frames_in_gt], gt_anchors[plotted_frames_in_gt]):
    if c != 0: # fix div by 2 == 0 bug later
        ax.axvline(x=f, color=palette[c], linestyle="-")

# do post-processing

anchors, confs = postprocess_predictions(confidences, reordered_frames)
anchors = np.array(anchors)
# np.stack raises on an empty sequence; keep an empty (0, n_classes) array
# when post-processing yields no anchors.
confs = np.stack(confs) if len(confs) else np.empty((0, len(classnames)))
# anchors, confs = postprocess_peaks_only(confidences, reordered_frames, height=0.5, distance=8, width=0)
# confs[confs > 0.9] = 1
# array index, frame confidences; class int and confidence per class
postprocess_list = _confidence_records(anchors, confs, with_idx=True)

pp_df = pd.DataFrame(postprocess_list, columns=["frame", "type", "confidence", "idx"])

# plot predicted anchors from postprocessing
pp_in_plot = pp_df[(pp_df.frame < end_frame) & (pp_df.frame > start_frame)]
for _, row in pp_in_plot[pp_in_plot["type"] != "Background"].iterrows():
    # vlines takes ymin/ymax in data coordinates, so the dotted line ends at
    # the predicted confidence. axvline's ymax is a fraction of the axes
    # height and only matched the confidence for y-limits of exactly [0, 1].
    ax.vlines(
        x=row["frame"],
        ymin=0,
        ymax=row["confidence"],
        colors=[palette[class2int(row["type"])]],
        linestyles=":",
    )


In [None]:
# List "Pass" predictions with confidence above 0.5 inside the plotted
# window: first the raw predictions, then the post-processed anchors.
raw_pass_mask = (plot_data["type"] == "Pass") & (plot_data["confidence"] > 0.5)
print(plot_data.loc[raw_pass_mask, ["frame", "confidence"]])
pp_pass_mask = (pp_in_plot["type"] == "Pass") & (pp_in_plot["confidence"] > 0.5)
print(pp_in_plot.loc[pp_pass_mask, "frame"])

In [None]:
# Evaluate average mAP at several tolerances, expressed as multiples of the
# frame rate (presumably 1, 2, 3 and 10 seconds in frames -- TODO confirm).
fps = 29.97
tolerances = [multiple * fps for multiple in (1, 2, 3, 10)]
print(average_mAP(confs, anchors, gt_labels_ll, gt_anchors_ll, tolerances=tolerances))