In [None]:
import numpy as np
import os, sys
import importlib
from tqdm.notebook import tqdm, trange
import pandas as pd
import json
import itertools
from IPython.display import display

# Import test handler
from FewShotTestHandler import FewShotTestHandler

# Dataset Handler
from dataset import DatasetHandler

# Most common classifier
from classifier.WeightedTextFewShotClassifier import WeightedTextFewShotClassifier

### Load VLM

Note: This notebook must be run in the conda environment that corresponds to the VLM selected below.

In [None]:
# Select which vision-language model (VLM) to evaluate. Each backend is imported only when
# it is selected, because the notebook has to run inside that backend's own conda environment.
VLM_NAME = "CLIP"  # one of: "VTTWINS", "CLIP", "UNIVL"

if VLM_NAME == "VTTWINS":
    from VTTWINS.wrapper import VTTWINS_SimilarityVLM
    vlm = VTTWINS_SimilarityVLM(reset_cache=False)
elif VLM_NAME == "CLIP":
    from CLIP.CLIPVLM import ClipVLM
    vlm = ClipVLM(reset_cache=False)
elif VLM_NAME == "UNIVL":
    from UNIVL.wrapper import UniVL_SimilarityVLM
    vlm = UniVL_SimilarityVLM(reset_cache=False)
else:
    # Fail here, rather than with a NameError on `vlm` several cells later.
    raise ValueError(f"Unknown VLM_NAME: {VLM_NAME!r}")

### Test Handler

Runs few-shot testing and permanently saves results

In [None]:
# Single handler shared by every later cell. It runs the few-shot tests and permanently saves
# their results. Later cells call `fill_cache` and `run_few_shot_test` on it and display
# `results.tail(5)`; `results` is presumably a pandas DataFrame (TODO confirm).
test_handler = FewShotTestHandler()

### Run Repeated Tests

In [None]:
DATASET = DatasetHandler("moma_sact", split="val")

# Print a short summary of the loaded split. Each getter is called only when its own line
# is printed, so the calls happen in the same order as the output.
_dataset_summary = (
    ("Dataset:\t", DATASET.id),
    ("Categories:\t", DATASET.category_count),
    ("Videos:\t\t", DATASET.video_count),
)
for _label, _getter in _dataset_summary:
    print(f"{_label}{_getter()}")

In [None]:
# Sweep grid: the test loop evaluates every (n_way, n_support, text_weight) combination.
N_WAY_LIST = [5, 10]                            # values passed as n_way (the N in N-way)
N_SUPPORT_LIST = [0, 1, 2, 5, 10]               # values passed as n_support; 0 means zero-shot
TEXT_WEIGHT_LIST = [0.1, 0.5, 1, 2, 5, 10, 20]  # text_weight given to the classifier

# Per-episode settings held fixed across the whole sweep.
N_QUERY, N_EPISODES = 1, 1_000

In [None]:
# Fill the VLM's cache for the chosen dataset before any tests run. This presumably
# precomputes the VLM outputs the sweep below needs, so they are not recomputed per
# episode. TODO: confirm against FewShotTestHandler.fill_cache.
test_handler.fill_cache(vlm, DATASET)

In [None]:
# Dynamically display the most recent test results, refreshing the same output area
# after every configuration finishes.
N_RECENT_RESULTS = 5
disp = display(display_id=True)
disp.update(test_handler.results.tail(N_RECENT_RESULTS))

# Sweep grid of (n_way, n_support, text_weight) tuples, in itertools.product order.
# Zero-shot tests (n_support == 0) with different text_weights are repeats, so only one
# text_weight is kept for them: 1 when TEXT_WEIGHT_LIST contains it, otherwise the first
# listed weight. Without this fallback, removing 1 from TEXT_WEIGHT_LIST would silently
# drop every zero-shot test.
zero_shot_text_weights = [w for w in TEXT_WEIGHT_LIST if w == 1] or TEXT_WEIGHT_LIST[:1]
param_list = [
    (n_way, n_support, text_weight)
    for n_way, n_support, text_weight in itertools.product(N_WAY_LIST, N_SUPPORT_LIST, TEXT_WEIGHT_LIST)
    if n_support != 0 or text_weight in zero_shot_text_weights
]

for n_way, n_support, text_weight in tqdm(param_list):
    # Build a new classifier per configuration, because text_weight is a constructor argument.
    classifier = WeightedTextFewShotClassifier(vlm, metric=None, text_weight=text_weight)

    test_handler.run_few_shot_test(
        classifier, DATASET,
        n_way=n_way, n_support=n_support, n_query=N_QUERY, n_episodes=N_EPISODES,
    )
    disp.update(test_handler.results.tail(N_RECENT_RESULTS))