In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
import os.path as op
import pandas as pd
import json
import base64

In [None]:
# Make the parent of the current working directory (normally the notebook's folder)
# importable, so the repository-local `datasets` and `models` packages resolve.
sys.path.append(op.abspath(op.pardir))

In [None]:
# cuBLAS needs a fixed workspace configuration to produce deterministic results once
# torch.use_deterministic_algorithms(True) is enabled; it must be set before CUDA is used.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":16:8"})

In [None]:
from collections import Counter
from itertools import chain

import torch
import multiprocessing
from scipy.spatial import distance_matrix
import numpy as np

In [None]:
from torch.utils.data import DataLoader

from datasets import SLREmbeddingDataset, collate_fn_padd
from datasets.dataset_loader import LocalDatasetLoader
from models import embeddings_scatter_plot_splits
from models import SPOTER_EMBEDDINGS

## Model and dataset loading

In [None]:
import random

# Fix every random number generator used by this notebook so results are reproducible.
seed = 43
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects processes
# launched after this point; it does not change hashing in the running kernel.
os.environ["PYTHONHASHSEED"] = str(seed)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

# Request deterministic kernels from cuDNN and PyTorch.
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# Dedicated, seeded generator that is later passed to the DataLoaders.
generator = torch.Generator()
generator.manual_seed(seed)

In [None]:
# Root folder of the datasets. It is also exported as an environment variable,
# presumably for the project's dataset loaders to read (TODO confirm).
BASE_DATA_FOLDER = '../data/'
os.environ["BASE_DATA_FOLDER"] = BASE_DATA_FOLDER

# Prefer the GPU when one is visible; otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# LOAD MODEL FROM CLEARML
# from clearml import InputModel
# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')
# checkpoint = torch.load(model.get_weights())

## Set your path to checkpoint here
CHECKPOINT_PATH = "../checkpoints/checkpoint_embed_992.pth"
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
# Presumably the argument namespace saved at training time (TODO confirm); it supplies
# the architecture hyper-parameters needed to rebuild the model.
train_args = checkpoint["config_args"]

model = SPOTER_EMBEDDINGS(
    features=train_args.vector_length,
    hidden_dim=train_args.hidden_dim,
    norm_emb=train_args.normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])

In [None]:
# Dataset selection table:
# key -> (dataset name passed to the dataset loader, number of classes,
#         split CSV filename template formatted with the split name)
DATASET_CONFIGS = {
    'wlasl': ("wlasl_mapped_mediapipe_only_landmarks_25fps", 100, "WLASL100_{}_25fps.csv"),
    'lsa': ("lsa64_mapped_mediapipe_only_landmarks_25fps", 64, "LSA64_{}.csv"),
}


def get_dataset_config(name):
    """Return ``(dataset_name, num_classes, split_dataset_path)`` for dataset ``name``.

    Raises:
        ValueError: if ``name`` is not a key of ``DATASET_CONFIGS`` (e.g. a typo),
            instead of silently falling back to another dataset.
    """
    try:
        return DATASET_CONFIGS[name]
    except KeyError:
        raise ValueError(
            f"Unknown SL_DATASET {name!r}; expected one of {sorted(DATASET_CONFIGS)}"
        ) from None


SL_DATASET = 'wlasl'  # or 'lsa'
dataset_name, num_classes, split_dataset_path = get_dataset_config(SL_DATASET)

In [None]:
def get_dataset_loader(loader_name=None):
    """Return the dataset-loader backend selected by ``loader_name``.

    ``'CLEARML'`` returns a ``ClearMLDatasetLoader``, whose module is imported only in
    that case. Any other value, including the default ``None``, returns a
    ``LocalDatasetLoader``.
    """
    if loader_name != 'CLEARML':
        return LocalDatasetLoader()
    from datasets.clearml_dataset_loader import ClearMLDatasetLoader
    return ClearMLDatasetLoader()

# Resolve the folder of the selected dataset through the loader backend
# (LocalDatasetLoader by default).
dataset_project = "Sign Language Recognition"
batch_size = 1  # one sample per batch
dataset_loader = get_dataset_loader()
dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)

In [None]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that reseeds numpy and ``random`` in a worker.

    The seed is derived from ``torch.initial_seed()`` (the worker's torch seed) and
    reduced to 32 bits, because ``np.random.seed`` only accepts 32-bit seeds.
    ``worker_id`` is required by the callback signature but is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

In [None]:
# Build one non-shuffled DataLoader per split, plus that split's CSV as a DataFrame.
# The embeddings computed later are attached to these DataFrames positionally, so the
# loader order must follow the CSV row order (hence shuffle=False).
splits = ['train', 'val']
dataloaders = {}
dfs = {}
for split in splits:
    split_set_path = op.join(dataset_folder, split_dataset_path.format(split))
    split_set = SLREmbeddingDataset(split_set_path, triplet=False, augmentations=False)
    dataloaders[split] = DataLoader(
        split_set,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn_padd,
        pin_memory=torch.cuda.is_available(),
        num_workers=multiprocessing.cpu_count(),
        worker_init_fn=seed_worker,  # reseed numpy / random inside each worker
        generator=generator,
    )
    dfs[split] = pd.read_csv(split_set_path)

# JSON text is UTF-8 (RFC 8259). Open it explicitly as such so label names decode
# identically on every OS instead of using the platform-default encoding.
with open(op.join(dataset_folder, 'id_to_label.json'), encoding='utf-8') as fid:
    # JSON object keys are always strings; convert them back to integer class ids.
    id_to_label = {int(key): value for key, value in json.load(fid).items()}

In [None]:
# Embed every sample of each split with the trained model.
# model.eval() switches off training-only behaviour (e.g. dropout, if the model has any)
# so the embeddings are deterministic; torch.no_grad() only disables autograd and does
# NOT change the module's train/eval mode.
model.eval()
embeddings_split = {}
splits = list(dataloaders.keys())
with torch.no_grad():
    for split, dataloader in dataloaders.items():
        embeddings = []
        for inputs, _labels, masks in dataloader:
            outputs = model(inputs.to(device), masks.to(device))
            # Keep index 0 along dim 1 for every batch element: presumably the
            # sequence-level embedding (TODO confirm against SPOTER_EMBEDDINGS).
            embeddings.extend(outputs[n, 0].cpu().numpy() for n in range(outputs.shape[0]))
        embeddings_split[split] = embeddings

In [None]:
# Sanity check: every split must yield exactly one embedding per CSV row, because the
# next cell attaches the embeddings to the DataFrames positionally.
for split in splits:
    assert len(embeddings_split[split]) == len(dfs[split]), (
        f"{split}: {len(embeddings_split[split])} embeddings vs {len(dfs[split])} CSV rows"
    )
len(embeddings_split['train']), len(dfs['train'])

In [None]:
# Attach each split's embeddings to its DataFrame as a new 'embeddings' column,
# one embedding array per row, in dataloader order.
for split in splits:
    dfs[split]['embeddings'] = embeddings_split[split]

## Compute metrics
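Compute top-1 and top-5 accuracy on the validation split with nearest-neighbour search in embedding space. Each validation embedding is classified either against one centroid (mean embedding) per training class, or against every individual training embedding. The all-embeddings setting also reports a 5-NN majority vote and a top-5 distinct-class match.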


In [None]:
# Nearest-neighbour classification of validation embeddings against training embeddings:
# first with one centroid per class, then with every individual training sample.
for use_centroids, str_use_centroids in zip([True, False],
                                           ['Using centroids only', 'Using all embeddings']):

    df_val = dfs['val']
    len_val_dataset = len(df_val)
    if len_val_dataset == 0:
        raise ValueError("Validation split is empty; accuracies are undefined.")

    df_train = dfs['train']
    if use_centroids:
        # One row per class: the element-wise mean of that class's training embeddings.
        # Stack explicitly rather than calling np.mean on an object-dtype Series, so the
        # reduction does not depend on pandas' object-dtype reduction behaviour.
        df_train = (
            dfs['train']
            .groupby('labels')['embeddings']
            .apply(lambda class_embs: np.vstack(class_embs.tolist()).mean(axis=0))
            .reset_index()
        )
    x_train = np.vstack(df_train['embeddings'])
    x_val = np.vstack(df_val['embeddings'])

    # d_mat[i, j] is the Euclidean distance between validation row i and training row j.
    d_mat = distance_matrix(x_val, x_train, p=2)

    # Loop-invariant, positional label arrays (aligned with the rows/columns of d_mat).
    labels = df_train['labels'].values
    val_labels = df_val['labels'].values

    top5_embs = 0
    top5_classes = 0
    knn = 0
    top1 = 0
    good_samples = []

    for i in range(d_mat.shape[0]):
        true_label = val_labels[i]
        argsort = np.argsort(d_mat[i])
        sorted_labels = labels[argsort]
        if sorted_labels[0] == true_label:
            top1 += 1
            if use_centroids:
                good_samples.append(df_val['video_id'].iloc[i])
            else:
                good_samples.append((df_val['video_id'].iloc[i],
                                     df_train['video_id'].iloc[argsort[0]],
                                     i,
                                     argsort[0]))

        # Majority vote among the 5 nearest embeddings; Counter breaks ties by first
        # occurrence, i.e. in favour of the closer label.
        if true_label == Counter(sorted_labels[:5]).most_common(1)[0][0]:
            knn += 1
        if true_label in sorted_labels[:5]:
            top5_embs += 1
        # The 5 closest *distinct* classes (dict.fromkeys keeps first-seen order).
        if true_label in list(dict.fromkeys(sorted_labels))[:5]:
            top5_classes += 1

    print(str_use_centroids)
    print(f'Top-1 accuracy: {100 * top1 / len_val_dataset : 0.2f} %')
    if not use_centroids:
        print(f'5-nn accuracy: {100 * knn / len_val_dataset : 0.2f} % (Picks the class that appears most often in the 5 closest embeddings)')
    print(f'Top-5 embeddings class match: {100 * top5_embs / len_val_dataset: 0.2f} %  (Picks any class in the 5 closest embeddings)')
    if not use_centroids:
        print(f'Top-5 unique class match: {100 * top5_classes / len_val_dataset: 0.2f} %  (Picks the 5 closest distinct classes)')
    print('\n' + '#'*32 + '\n')

## Show some examples (only for WLASL)

In [None]:
from IPython.display import Video

In [None]:
# Show up to three training videos for an example gloss. The video paths below follow
# the WLASL layout, so the cell does nothing for other datasets. It reads dfs['train']
# directly instead of relying on `df_train` leaking from the metrics loop.
EXAMPLE_LABEL = 'thursday'
if SL_DATASET == 'wlasl':
    train_df = dfs['train']
    for row in train_df[train_df.label_name == EXAMPLE_LABEL].head(3).itertuples():
        display(Video(op.join(BASE_DATA_FOLDER, f'wlasl/videos/{row.video_id}.mp4'), embed=True))