### Imports

In [None]:
import sys
sys.path.append("..") # root of repo
import os
import numpy as np
import pandas as pd
import pickle as pkl
import ipywidgets as widgets
from collections import Counter, defaultdict
import os.path as osp
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from src.data import MultiModalHblDataset
from src.utils import array2gif, draw_trajectory
import torchvision
import torch

### Checking frequencies of labels across matches

In [None]:
# Counters filled by the label-frequency cell below.
ctr = Counter()
no_action = 0

# Dataset created with load_frames=False; below it is only used for its labels and event tables.
dataset = MultiModalHblDataset(
    meta_path="/nfs/home/rhotertj/datasets/hbl/meta3d.csv",
    seq_len=16,
    sampling_rate=2,
    load_frames=False,
)
print("Dataset size:", len(dataset))

In [None]:
# Count how often each (pass, shot) label combination occurs across all instances.
# Instances whose label is an empty dict carry no action and are counted in ``no_action``.
for instance in tqdm(dataset, total=len(dataset)):
    label = instance["label"]
    if label != {}:
        # Keep the two label components as a tuple: concatenating them into one string and
        # unpacking it character by character breaks as soon as a label has more than one char.
        ctr[(str(label['Pass']), str(label['Wurf']))] += 1
    else:
        no_action += 1

# Contingency table: rows are shot ("Wurf") classes, columns are pass ("Pass") classes.
# Combinations that never occur are left as NaN by .loc enlargement.
class_combinations = pd.DataFrame()
for (pass_cls, shot_cls), n in ctr.items():
    class_combinations.loc[shot_cls, pass_cls] = n
print("No action:", no_action)
class_combinations

In [None]:
# Sanity check: labelled combinations plus no-action instances should account for every instance.
len(dataset) == no_action + class_combinations.sum().sum()

In [None]:
# Persist the contingency table; "16_2" presumably mirrors seq_len=16 / sampling_rate=2 used above.
class_combinations.to_csv(path_or_buf="occurences_16_2.csv")

In [None]:
from collections import defaultdict
import pickle as pkl

# class2frame[comb][m] is the list of event indices of match m whose labels form the
# (pass, shot) combination ``comb``. One list (possibly empty) is appended per match, so the
# list position can be used as the match number when sampling examples per class.
class2frame = defaultdict(list)
valid_combs = ['A0', 'A1', 'B0', 'O0', 'C0', 'D0', 'X0', 'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'O7', 'O8']

for events in tqdm(dataset.event_dfs):
    # Extract the label components into local Series instead of adding columns to the
    # dataset's own event DataFrames.
    shot_labels = events["labels"].apply(lambda x: x["Wurf"])
    pass_labels = events["labels"].apply(lambda x: x["Pass"])
    for comb in valid_combs:
        pass_cls, shot_cls = comb
        comb_frames = events[(pass_labels == pass_cls) & (shot_labels == shot_cls)].index.tolist()
        # Append even when empty: skipping matches would shift every later position and break
        # the position -> match number correspondence relied on when visualizing classes.
        class2frame[comb].append(comb_frames)

with open("class2frame.pkl", "wb") as fh:
    pkl.dump(class2frame, fh)


### Adding mirroring flags to the meta file

In [None]:
sql, sr = 16, 2
# Frame-loading dataset. NOTE: this cell reads meta.csv, while the other cells read meta3d.csv.
dataset_img = MultiModalHblDataset(
    meta_path="/nfs/home/rhotertj/datasets/hbl/meta.csv",
    load_frames=True,
    seq_len=sql,
    sampling_rate=sr,
)
# Scratch notes kept from exploration:
# [(k, v) for (k,v) in dataset_img.event_dfs[0].head(20)["labels"].items()]
# 12126 O0
print("Match boundaries:", dataset_img.index_tracker)

In [None]:
idx = 109340
pos_offset = 0  # positive values move positions "into the past"; not used in this cell
example = dataset_img[idx]  # previously also tried: frame_idx=idx, match_number=0

# Bring the last axis of the frame array to the front before writing the GIF.
# NOTE(review): the class-visualization cell uses .transpose(1, 0, 2, 3) instead - confirm
# which layout array2gif expects.
frames = np.transpose(example["frames"], (3, 0, 1, 2))
positions = example["positions"]

print("Frames and positions shape:", frames.shape, positions.shape)
print("Action label frame number:", example["label_offset"])
print("Label:", example["label"])

gif_path = f"../img/instance_{idx}_{sql}x{sr}.gif"
array2gif(frames, gif_path, fps=10)
fig = draw_trajectory(positions)

# Observed misalignment between video frames and positions, per game:
# Game | positions offset
# 0s   | 8 frames in the future
# 1s   | 20 frames in the future, 109340 32x1
# 2s   | mirrored, 16 frames in the past 204000 32x1
# 3s   | looks good (319120 32x1 data)
# 4s   | 16 frames in the future, see 426000 32x1
# 5s   | mirrored, ok see 533450 32x1
# 6s   | mirrored, ok, 641300 31x1
# 7s   | 5 frames in the future, 745700 32x1
# 8s   | mirrored, 47 frames in the future 854000 32x1
# 9s   | 16 frames in the future 962400 32x1
# 10s  | mirrored, 8 in the future 1071800

In [None]:
p = "/nfs/home/rhotertj/datasets/hbl/meta3d.csv"
df = pd.read_csv(p, index_col="match_id")

# 0-based row positions (not match_id labels) of the matches whose positions must be mirrored.
# NOTE(review): the mismatch notes in the single-instance cell mark games 2, 5, 6, 8 and 10 as
# mirrored but not game 1 - confirm that 1 belongs in these lists.
vertical = [1, 2, 5, 6, 8, 10]
horizontal = [1, 2, 5, 6, 8, 10]

# Flag rows by position in one vectorized step. The previous per-row ``df.loc[label, ...]``
# wrote to every row sharing a match_id label, so duplicate labels got the wrong flags.
row_positions = np.arange(len(df))
df["mirror_vertical"] = np.isin(row_positions, vertical)
df["mirror_horizontal"] = np.isin(row_positions, horizontal)

# Overwrites the meta file in place.
df.to_csv(p)
df

### Plotting positions and frames per halftime

In [None]:
sql, sr = 16, 2
# Frame-loading dataset backed by the 3D meta file.
dataset_img = MultiModalHblDataset(
    meta_path="/nfs/home/rhotertj/datasets/hbl/meta3d.csv",
    load_frames=True,
    seq_len=sql,
    sampling_rate=sr,
)
# Scratch notes kept from exploration:
# [(k, v) for (k,v) in dataset_img.event_dfs[0].head(20)["labels"].items()]
# 12126 O0
print("Match boundaries:", dataset_img.index_tracker)



In [None]:
# match bounds [0, 106194, 212206, 318712, 425410, 531648, 640606, 745140, 853159, 960038, 1065755, 1172197]
events = dataset_img.event_dfs

# For every match, compare the first frame reachable via the index mapping with the first
# labelled event.
for match_no, match_events in enumerate(events):
    print(match_no, "---")
    first_frame = dataset_img.idx_to_frame_number[match_no][0]
    print("First frame", first_frame)
    print("First action", match_events.index[0])


### Debugging the idx -> frame index mapping

In [None]:
# Toy reproduction of the dataset's idx -> frame mapping.
sql = 4            # sampled frames per sequence
hql = sql // 2     # half the sequence length
rate = 2           # step between sampled frames
sr = sql * rate    # raw frames spanned by one sequence (kernel length)
hr = hql * rate    # half of that span
kernel = np.ones(sr)

# Availability mask: 13 available frames followed by 2 missing ones.
availables = [True] * 13 + [False] * 2
pos = np.arange(len(availables))

# Full convolution with a ones kernel counts the available frames in the window of length sr
# ending at each position; a count equal to sr means that whole window is available.
cv = np.convolve(availables, kernel)
print(f"{cv}")
# Turn window end positions into window start positions (filter length - 1).
idxs = np.flatnonzero(cv == sr) - (sr - 1)
print("idx for valid sequences:", idxs)

q_idx = 0
f_idx = idxs[q_idx] + hr  # centre of the queried sequence
sequence = pos[f_idx - hr : f_idx + hr : rate]
print(f"Idx for sequence {q_idx}: {sequence}")

### Visualizing each class

In [None]:
sql, sr = 16, 2
dataset_img = MultiModalHblDataset(
    meta_path="/nfs/home/rhotertj/datasets/hbl/meta3d.csv",
    load_frames=True,
    seq_len=sql,
    sampling_rate=sr,
)

# Reload the class -> per-match frame lists pickled by the class2frame cell.
with open("class2frame.pkl", "rb") as fh:
    class2frame = pkl.load(fh)

In [None]:
# Export up to n_per_class GIF + trajectory PNG pairs per label combination.
os.makedirs("../img/classes", exist_ok=True)
n_per_class = 3
for cls, frame_lists in class2frame.items():
    n = 0
    # NOTE(review): the position in ``frame_lists`` is used as the match number. That only holds
    # if class2frame stores one (possibly empty) list per match; a pickle whose producer skipped
    # matches without this class would shift the match numbers.
    for match_id, fl in enumerate(frame_lists):
        if n == n_per_class:
            break
        if not fl:
            continue
        # Sample from the middle of the match (don't take the first pass of the game) and shift
        # by n so each sample of this class uses a different offset.
        try:
            frame_id = fl[len(fl) // 2 + n]
        except IndexError:
            continue  # this match has too few events of this class for offset n
        example = dataset_img.__getitem__(0, frame_idx=frame_id, match_number=match_id)
        # NOTE(review): the single-instance cell uses np.transpose(..., (3, 0, 1, 2)) instead -
        # confirm which frame layout array2gif expects.
        frames = example["frames"].transpose(1, 0, 2, 3)
        stem = f"../img/classes/{cls}_{n}_{match_id}x{frame_id}"
        gifname = f"{stem}.gif"
        array2gif(frames, gifname, fps=10)

        positions = example["positions"]
        # NOTE(review): figures are saved but never closed; with many classes this may
        # accumulate open figures.
        fig = draw_trajectory(positions)
        figname = f"{stem}.png"
        fig.savefig(figname)
        print(figname)
        print(gifname)

        n += 1
            

### Calculating the mean and standard deviation of the frames

In [None]:
sql = 16
sr = 1
dataset_img = MultiModalHblDataset(
    meta_path="/nfs/home/rhotertj/datasets/hbl/meta3d.csv",
    seq_len=sql,
    sampling_rate=sr,
    load_frames=True
)
print(len(dataset_img))

# Per-frame statistics of the pixel values scaled to [0, 1].
# Step by sql so that consecutive visited instances cover disjoint windows (assumes each
# instance yields sql frames at sampling_rate=1 - TODO confirm).
n = 0
mean_rows = []
std_rows = []
for i in tqdm(range(0, len(dataset_img), sql)):
    frames = torch.tensor(dataset_img[i]["frames"]) / 255

    for frame in frames:
        # Reducing dims 1 and 2 is expected to leave 3 values per frame (stored as 3 columns).
        mean, std = frame.mean([1, 2]), frame.std([1, 2])
        mean_rows.append(mean.tolist())
        std_rows.append(std.tolist())
        n += 1

# Keep only the rows that were actually computed. Preallocating (len(dataset_img) + sql, 3)
# zeros left trailing all-zero rows that biased means.mean(0) / stds.mean(0) downwards.
# Note: the average of per-frame stds is not the std over all pixels of the dataset.
means = np.asarray(mean_rows, dtype=float).reshape(-1, 3)
stds = np.asarray(std_rows, dtype=float).reshape(-1, 3)
print(n)
print(means.shape)


In [None]:
# Persist the per-frame statistics computed in the previous cell.
for out_name, out_array in (("means.npy", means), ("std.npy", stds)):
    np.save(out_name, out_array)

In [None]:
# Column-wise averages of the per-frame mean and std rows.
print(means.mean(axis=0), stds.mean(axis=0))