In [None]:
import numpy as np
import os, sys
import importlib
from tqdm.notebook import tqdm, trange
import pandas as pd
import json
import itertools
from IPython.display import display

### VLM and Parameter Setup

In [None]:
# Select the VLM wrapper class (bound to the name `VLM`) and the hyper-parameter
# grid to sweep (`vlm_params`: constructor keyword -> list of values to try),
# keyed on the name of the active conda environment.
# The wrapper imports are done lazily inside each branch, so only the package
# for the selected environment is imported.
# Raises KeyError if CONDA_DEFAULT_ENV is not set.
ENV = os.environ["CONDA_DEFAULT_ENV"]

if ENV == "videoclip":
    from video_clip.video_clip import VideoClipVLM as VLM
    vlm_params = {
        "num_seconds": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "sample_strat": ["center", "start", "spread"]
    }

elif ENV == "VLM_MILES":
    from MILES.wrapper import MILES_SimilarityVLM as VLM
    # No tunable constructor parameters: the sweep runs a single default model.
    vlm_params = {}

elif ENV == "VLM_CLIP":
    from CLIP.CLIPVLM import ClipVLM as VLM
    vlm_params = {
        "num_frames": [1, 2, 3, 4, 6, 8, 10, 20, 50, 100]
    }

else:
    # Fail fast: without this `raise`, an unrecognised environment would fall
    # through silently and only surface later as a NameError on `VLM`.
    raise ValueError(f"Unsupported conda environment: {ENV!r}")


### Data Setup

In [None]:
# Locations of the 10-second clips and their transcript text files.
DATA_FOLDER = "/home/datasets/500p/10s_dataset"
MP4_FOLDER = f"{DATA_FOLDER}/10s_clips"
TXT_FOLDER = f"{DATA_FOLDER}/10s_kaldi_texts"


def _sorted_stems(folder, extension):
    """Return the sorted base names (extension stripped) of files in `folder` ending with `extension`."""
    cut = len(extension)
    return sorted(entry[:-cut] for entry in os.listdir(folder) if entry.endswith(extension))


video_names = _sorted_stems(MP4_FOLDER, ".mp4")
text_names = _sorted_stems(TXT_FOLDER, ".txt")

# Every clip must have a transcript with the same base name, and vice versa.
assert video_names == text_names
pair_names = video_names

In [None]:
# Which video/text pairing to evaluate. Each branch must define:
#   N         - number of video-text pairs
#   vid_paths - indexable collection of N video file paths
#   vid_text  - indexable collection of N texts (vid_text[i] belongs to vid_paths[i])
DATASET_ID = "kinetics_100"


def _read_transcript(name):
    """Return the lower-cased, whitespace-stripped transcript text for clip `name`."""
    with open(f"{TXT_FOLDER}/{name}.txt", "r") as fp:
        return fp.read().lower().strip()


if DATASET_ID == "500p":
    # Each clip is paired with exactly its own transcript.
    N = len(pair_names)
    vid_paths = np.array([f"{MP4_FOLDER}/{name}.mp4" for name in pair_names])
    vid_text = np.array([_read_transcript(name) for name in pair_names])

elif DATASET_ID == "500p.text_overflow_1_1":
    # Each clip is paired with its own transcript plus the transcripts of the
    # immediately preceding and following clips (in sorted name order).
    N = len(pair_names)
    vid_paths = np.array([f"{MP4_FOLDER}/{name}.mp4" for name in pair_names])
    raw_vid_text = [_read_transcript(name) for name in pair_names]
    # Slicing clamps the upper bound, so only the lower bound needs guarding.
    vid_text = np.array([
        " ".join(raw_vid_text[max(0, pos - 1) : pos + 2])
        for pos in range(len(raw_vid_text))
    ])

elif DATASET_ID == "kinetics_100":
    from dataset import DatasetHandler
    dataset = DatasetHandler("kinetics_100", split="all")
    N = dataset.category_count()
    # Take the first item yielded by few_shot(1, N, 0, 1) and keep column 0 of
    # its path array.
    # NOTE(review): presumably one example video per category - confirm against DatasetHandler.
    vid_paths, vid_text = next(iter(dataset.few_shot(1, N, 0, 1)))
    vid_paths = vid_paths[:, 0]

else:
    raise ValueError(DATASET_ID)

# Echo the selected pairs so they can be inspected by eye.
print(DATASET_ID)
print(f"{N} Video-Text Pairs")
for idx in range(N):
    print(f"{idx:>3}: {vid_paths[idx].split('/')[-1]:>20}")
    print(vid_text[idx])
    print()

### Retrieval Tester Function

In [None]:
# Column layout of the results table: test identification, recall columns,
# rank statistics, then the raw ranking arrays.
_RESULT_COLUMNS = [
    "vlm_class", "vlm_params", "dataset",
    "R@1 (of 82)", "R@5 (of 82)", "R@10 (of 82)", "R@20 (of 82)", "R@50 (of 82)",
    "Mean R", "Med R",
    "Ranked Text Indices", "Vid-Text Choice Ranks",
    "Correct Vid-Text Pair Ranks", "Vid-Text Pairs (Easy->Hard)",
]

# Resume from previously saved results when available; otherwise start an empty table.
results = (
    pd.read_csv("retrieval_results.csv")
    if os.path.exists("retrieval_results.csv")
    else pd.DataFrame(columns=_RESULT_COLUMNS)
)

In [None]:
def retrieval_test(vlm):
    """Run a video->text retrieval test for `vlm` over the module-level dataset.

    Embeds every path in `vid_paths` and every text in `vid_text`, scores all
    video/text combinations with the VLM's default similarity metric, and ranks
    the texts for each video by descending similarity. Pair i is counted as
    correct when video i retrieves text i.

    Side effects:
      * prints a one-line summary (R@1/5/10/20/50 as fractions of N, mean and
        median rank, both 1-based);
      * replaces any row of the global `results` table with the same
        (vlm_class, vlm_params, dataset) and appends a new row;
      * rewrites "retrieval_results.csv".

    Note that the "R@k (of 82)" columns receive the raw hit *counts*; the
    denominator actually in effect is the current `N`, whatever the column
    label says.

    Args:
        vlm: object providing get_video_embeds(path), get_text_embeds(text),
            default_similarity_metric() and params().

    Returns:
        None.
    """
    # The context manager closes the (leave=False) progress bar even if an
    # embedding call raises, so no stale bars are left in the notebook output.
    with trange(2 * N, leave=False) as pbar:
        vid_embeds = []
        for path in vid_paths:
            vid_embeds.append(vlm.get_video_embeds(path))
            pbar.update(1)
        text_embeds = []
        for text in vid_text:
            text_embeds.append(vlm.get_text_embeds(text))
            pbar.update(1)
    vid_embeds = np.array(vid_embeds)
    text_embeds = np.array(text_embeds)

    # NOTE(review): assumes similarity[i, j] scores video i against text j with
    # larger meaning more similar - confirm against the VLM wrappers.
    similarity = vlm.default_similarity_metric()(vid_embeds, text_embeds)

    # [i, j] = index of the text that is the j-th best match for video i
    sorted_text_choice_indices = np.argsort(-similarity, axis=1)
    # [i, j] = rank of text j among all texts for video i (0 = best, N - 1 = worst)
    pair_ranks = np.argsort(sorted_text_choice_indices, axis=1)
    # [i] = rank of the correct text (text i) for video i (0 = best, N - 1 = worst)
    correct_pair_ranks = pair_ranks[np.arange(N), np.arange(N)]
    # Pair indices ordered from best-ranked (easy) to worst-ranked (hard)
    sorted_pair_indices = np.argsort(correct_pair_ranks)

    # Recall@k: number (and fraction) of videos whose correct text is in the top k.
    R1_sum = np.sum(correct_pair_ranks < 1)
    R5_sum = np.sum(correct_pair_ranks < 5)
    R10_sum = np.sum(correct_pair_ranks < 10)
    R20_sum = np.sum(correct_pair_ranks < 20)
    R50_sum = np.sum(correct_pair_ranks < 50)
    R1 = R1_sum / N
    R5 = R5_sum / N
    R10 = R10_sum / N
    R20 = R20_sum / N
    R50 = R50_sum / N

    # Mean/median rank, shifted to be 1-based.
    mean_R = np.mean(correct_pair_ranks) + 1
    med_R = np.median(correct_pair_ranks) + 1

    params = vlm.params()
    print(f"{vlm.__class__.__name__}   {json.dumps(params)}")
    print(f"{'R1':>10}{'R5':>10}{'R10':>10}{'R20':>10}{'R50':>10}{'MeanR':>10}{'MedR':>10}")
    print(f"{R1:10.3f}{R5:10.3f}{R10:10.3f}{R20:10.3f}{R50:10.3f}{mean_R:10.3f}{med_R:10.3f}")
    print()

    # Save results, first removing any previously-saved version of this same test.
    global results
    test_spec = {
        "vlm_class": vlm.__class__.__name__,
        "vlm_params": params,
        "dataset": DATASET_ID
    }
    # Rows loaded from the CSV hold `vlm_params` as the string form of the dict,
    # while rows added in this session hold the dict itself. Comparing string
    # forms matches both; comparing against the raw dict would never match a
    # reloaded row and duplicates would pile up across kernel restarts.
    prev_matching_tests = (
        (results["vlm_class"] == test_spec["vlm_class"])
        & (results["vlm_params"].map(str) == str(test_spec["vlm_params"]))
        & (results["dataset"] == test_spec["dataset"])
    )
    if np.any(prev_matching_tests):
        results = results[~prev_matching_tests].reset_index(drop=True)

    # NOTE(review): the array-valued cells are written to CSV via their string
    # form; numpy abbreviates large arrays with "..." by default, so the N x N
    # arrays may not round-trip through the file - confirm if they are needed.
    results.loc[len(results)] = [
        test_spec["vlm_class"],
        test_spec["vlm_params"],
        test_spec["dataset"],
        R1_sum, R5_sum, R10_sum, R20_sum, R50_sum,
        mean_R, med_R,
        sorted_text_choice_indices, pair_ranks,
        correct_pair_ranks, sorted_pair_indices
    ]
    results.to_csv("retrieval_results.csv", index=False)

### Test Over All VLM Parameter Combinations

In [None]:
# Live table of the most recent results, refreshed in place after every test.
disp = display(display_id=True)
disp.update(results.tail(5))

vlm = None

# Cartesian product of every parameter's candidate values; when the model has
# no tunable parameters, run a single test with default construction.
combos = list(itertools.product(*vlm_params.values())) if len(vlm_params) else [[]]
param_list = tqdm(combos)

for combo in param_list:
    # Pair each value with its constructor keyword name.
    params = dict(zip(vlm_params.keys(), combo))
    param_list.set_postfix(params)

    vlm = VLM(**params)

    retrieval_test(vlm)
    disp.update(results.tail(5))

In [None]:
# Row labels of `results` whose per-pair ranks are averaged below.
ROWS_TO_AVERAGE = [3, 30, 40]


def _as_rank_array(cell):
    """Convert one "Correct Vid-Text Pair Ranks" cell into an integer array.

    Cells loaded from the CSV are the string form of an array ("[0 3 1 ...]"),
    whereas cells appended during the current session are still arrays.
    Handling both avoids an AttributeError from calling `.replace` on an array.
    """
    if isinstance(cell, str):
        tokens = cell.replace("[", "").replace("]", "").replace("\n", "").split()
        return np.array([int(rank) for rank in tokens])
    return np.asarray(cell, dtype=int)


# Average rank of each video-text pair across the selected test runs, then list
# the pair indices from lowest (easiest) to highest (hardest) average rank.
test = np.array([
    _as_rank_array(x)
    for x in results["Correct Vid-Text Pair Ranks"][ROWS_TO_AVERAGE].values
])
test = np.mean(test, axis=0)
print(test)
print(np.argsort(test))