In [1]:
import numpy as np
import os, sys
import importlib
from tqdm.notebook import tqdm, trange
import pandas as pd
import json
import itertools
from IPython.display import display

import torch
# Torch device string: "cuda" when a CUDA device is available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Import dataset types
from dataset.dataset import FewShotTaskDataset, SequentialVideoDataset, SequentialCategoryNameDataset

# Import base classes
from SimilarityVLM import SimilarityVLM
from classifier.FewShotClassifier import FewShotClassifier

## Test Parameters

### Choose Dataset to Test

In [2]:
# Dataset split file. It is handed to SequentialVideoDataset when the embedding
# cache is filled and to FewShotTaskDataset when few-shot tasks are sampled.
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
DATASET_SPLIT_PATH = "/home/datasets/kinetics_100_split/test.txt"

### Choose Few-Shot Task Parameters

In [3]:
# Task parameters for the single test run in the "Run Test" section.
N_WAY = 5                       # Number of categories to choose between in each task
N_SUPPORT = 10                  # Number of example videos per category per task (0 = no support videos are passed to the classifier)
N_QUERY = 1                     # Number of test videos per category per task
N_EPISODES = 1000               # Number of few-shot tasks sampled in one iteration of the dataset

## VLM Setup

### Load VLM and Few-Shot Classifier

Note: This notebook must be run in the conda environment that corresponds to the VLM selected in the cell below.

In [4]:
# Select exactly ONE VLM / classifier pairing by name. A single selector replaces
# hand-toggled `if True:` / `if False:` blocks, where enabling two blocks at once
# would silently let the last one win.
#
# Imports are kept inside the branches on purpose: only the package of the
# selected VLM has to be importable in the active environment.
#
# NOTE(review): metric=None presumably makes the classifier fall back to the VLM's
# default similarity metric (recorded results show "DOT" for VTTWINS and "COSINE"
# for CLIP) -- confirm against the classifier implementation.
VLM_SETUP = "vttwins"   # one of: "vttwins", "clip", "clip_weighted_text"

if VLM_SETUP == "vttwins":
    from VTTWINS.wrapper import VTTWINS_SimilarityVLM
    vlm = VTTWINS_SimilarityVLM(reset_cache=False)

    from classifier.FewShotClassifier import FewShotClassifier
    classifier = FewShotClassifier(vlm, metric=None)

elif VLM_SETUP == "clip":
    from CLIP.CLIPVLM import ClipVLM
    vlm = ClipVLM(reset_cache=False)

    from classifier.FewShotClassifier import FewShotClassifier
    classifier = FewShotClassifier(vlm, metric=None)

elif VLM_SETUP == "clip_weighted_text":
    from CLIP.CLIPVLM import ClipVLM
    vlm = ClipVLM(reset_cache=False)

    from classifier.WeightedTextFewShotClassifier import WeightedTextFewShotClassifier
    classifier = WeightedTextFewShotClassifier(vlm, metric=None, text_weight=4)

else:
    # Fail loudly instead of leaving `vlm` / `classifier` undefined or stale.
    raise ValueError(
        f"Unknown VLM_SETUP {VLM_SETUP!r}; expected 'vttwins', 'clip' or 'clip_weighted_text'"
    )

### Fill the Cache

In [5]:
# Walk over every video path of the split and request its embedding from the VLM
# unless the path is already present in the VLM's embedding cache.
video_dataset = SequentialVideoDataset(DATASET_SPLIT_PATH)

try:
    for vid_path in tqdm(video_dataset):
        # Already cached: skip this path.
        if vid_path in vlm.embed_cache:
            continue
        vlm.get_video_embeds(vid_path)
except KeyboardInterrupt:
    # Interrupting the cell is fine; fall through so the cache is still saved.
    pass
finally:
    # Runs whether the loop finished, was interrupted, or raised.
    vlm.save_cache()

  0%|          | 0/2400 [00:00<?, ?it/s]

## Test Setup

### Setup DataFrame for Saving Test Results

In [6]:
# CSV file in which results are kept between runs, and the columns of one result row.
TEST_RESULTS_PATH = "test_results.csv"
TEST_RESULTS_COLUMNS = [
    "vlm_class", "vlm_params",
    "classifier_class", "classifier_params",
    "dataset_split",
    "n_way", "n_support", "n_query", "n_episodes",
    "accuracy",
]

# Continue from previously saved results when the file exists; otherwise start
# with an empty frame that has the expected columns.
test_results = (
    pd.read_csv(TEST_RESULTS_PATH)
    if os.path.exists(TEST_RESULTS_PATH)
    else pd.DataFrame(columns=TEST_RESULTS_COLUMNS)
)

In [7]:
# Show the results as loaded above, before any new test is run.
test_results

Unnamed: 0,vlm_class,vlm_params,classifier_class,classifier_params,dataset_split,n_way,n_support,n_query,n_episodes,accuracy
0,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",FewShotClassifier,"{""metric"": ""COSINE""}",/home/datasets/kinetics_100_split/test.txt,5,10,1,1000,0.83000
1,VTTWINS_SimilarityVLM,{},FewShotClassifier,"{""metric"": ""DOT""}",/home/datasets/kinetics_100_split/test.txt,5,10,1,1000,0.75900
2,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 4.0}",/home/datasets/kinetics_100_split/test.txt,5,10,1,1000,0.85580
3,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 1.0}",/home/datasets/kinetics_100_split/test.txt,5,0,1,1000,0.86740
4,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 0.1}",/home/datasets/kinetics_100_split/test.txt,5,1,1,1000,0.72120
...,...,...,...,...,...,...,...,...,...,...
123,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 1.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.67465
124,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 2.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.68030
125,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 5.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.70970
126,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 10.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.71055


### Testing Function

In [8]:
def few_shot_test(classifier: FewShotClassifier, dataset_split_path: str,
                  n_way: int, n_support: int, n_query: int = 1, n_episodes: int = 1000) -> float:
    '''
    Runs a few shot test using the given classifier, dataset, and task parameters.
    Returns the average accuracy over all sampled query videos in all sampled tasks.

    Args:
        classifier: Object whose predict(category_names, support_vid_paths, query_vid_paths)
            is called once per sampled task.
        dataset_split_path: Split file passed to FewShotTaskDataset.
        n_way: Number of categories per task.
        n_support: Number of support videos per category. When 0, None is passed
            to the classifier in place of the support paths.
        n_query: Number of query videos per category.
        n_episodes: Number of tasks to sample.

    Returns:
        Fraction of correct query predictions over all tasks, as a built-in float.

    Raises:
        ValueError: If no query video was evaluated (e.g. n_episodes, n_way or
            n_query is 0), since the accuracy would be undefined.
    '''
    # Load dataset to generate tasks with the desired params
    dataset = FewShotTaskDataset(dataset_split_path, n_episodes, n_way, n_support, n_query)

    correct_predictions = 0
    total_queries = 0
    for vid_paths, category_names in tqdm(dataset, leave=False):

        # vid_paths is indexed as a 2-D array with one row per category:
        # the first n_support columns are support videos, the rest are queries.
        query_vid_paths = vid_paths[:, n_support:]
        support_vid_paths = vid_paths[:, :n_support] if n_support > 0 else None

        query_predictions = classifier.predict(category_names, support_vid_paths, query_vid_paths)

        # Row i of the predictions is compared against label i.
        # NOTE(review): assumes predictions are category indices shaped like
        # query_vid_paths, i.e. (n_way, n_query) -- confirm with classifier.predict.
        correct_predictions += np.sum(query_predictions == np.arange(n_way)[:, None])
        total_queries += n_way * n_query

    if total_queries == 0:
        # Without this guard the division below would either raise
        # ZeroDivisionError or quietly yield NaN, depending on operand types.
        raise ValueError("No query videos were evaluated; accuracy is undefined "
                         "(check n_way, n_query and n_episodes)")

    # Convert the NumPy scalar to a built-in float to match the annotation.
    return float(correct_predictions / total_queries)

def collect_few_shot_test_results(test_results_df: pd.DataFrame,
                                  classifier: FewShotClassifier, dataset_split_path: str,
                                  n_way: int, n_support: int, n_query: int = 1, n_episodes: int = 1000,
                                  ) -> None:
    '''
    Runs the given few-shot test if it has not already been performed,
    saving the result into a dataframe

    Args:
        test_results_df: Results table. It is modified IN PLACE: one row is
            appended when the test is run. Must contain a column for every
            key of the parameter dict built below.
        classifier: Classifier under test. Its class name and params(), and
            those of classifier.vlm, are recorded as part of the test's identity.
        dataset_split_path: Split file the tasks are sampled from.
        n_way, n_support, n_query, n_episodes: Task parameters, forwarded to
            few_shot_test.

    Returns:
        None. Returns early, without running anything, when a row with
        identical parameters already exists.
    '''
    # Everything that identifies a test run. Parameter dicts are stored as JSON
    # strings so they can be compared against the values held in the table.
    test_params = {
        "vlm_class": classifier.vlm.__class__.__name__,
        "vlm_params": json.dumps(classifier.vlm.params()),
        "classifier_class": classifier.__class__.__name__,
        "classifier_params": json.dumps(classifier.params()),
        "dataset_split": dataset_split_path,
        "n_way": n_way,
        "n_support": n_support,
        "n_query": n_query,
        "n_episodes": n_episodes
    }

    # Abort if test has already been recorded.
    # A single boolean mask is narrowed column by column, instead of building a
    # new filtered DataFrame for every parameter.
    already_recorded = pd.Series(True, index=test_results_df.index, dtype=bool)
    for key, val in test_params.items():
        already_recorded &= (test_results_df[key] == val)
    if already_recorded.any():
        return

    # Run Test
    accuracy = few_shot_test(classifier=classifier, dataset_split_path=dataset_split_path,
                             n_way=n_way, n_support=n_support, n_query=n_query, n_episodes=n_episodes)

    # Save results.
    # Using len() as the new row label relies on the frame having the default
    # 0..n-1 integer index, so that the label is not already taken.
    df_row = dict(test_params, accuracy=accuracy)
    test_results_df.loc[len(test_results_df)] = df_row

## Run Test

In [9]:
# Run the single test configured by the N_* constants, unless an identical
# configuration is already recorded. A new row is appended to `test_results` in place.
collect_few_shot_test_results(test_results,
                              classifier, DATASET_SPLIT_PATH,
                              n_way=N_WAY, n_support=N_SUPPORT, n_query=N_QUERY, n_episodes=N_EPISODES)

### Save Updated Test Results

In [10]:
# Persist all results; index=False leaves the DataFrame's row index out of the CSV.
test_results.to_csv(TEST_RESULTS_PATH, index=False)

In [11]:
# Re-display the results; a row was added only if the configuration run above
# had not been recorded before.
test_results

Unnamed: 0,vlm_class,vlm_params,classifier_class,classifier_params,dataset_split,n_way,n_support,n_query,n_episodes,accuracy
0,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",FewShotClassifier,"{""metric"": ""COSINE""}",/home/datasets/kinetics_100_split/test.txt,5,10,1,1000,0.83000
1,VTTWINS_SimilarityVLM,{},FewShotClassifier,"{""metric"": ""DOT""}",/home/datasets/kinetics_100_split/test.txt,5,10,1,1000,0.75900
2,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 4.0}",/home/datasets/kinetics_100_split/test.txt,5,10,1,1000,0.85580
3,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 1.0}",/home/datasets/kinetics_100_split/test.txt,5,0,1,1000,0.86740
4,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 0.1}",/home/datasets/kinetics_100_split/test.txt,5,1,1,1000,0.72120
...,...,...,...,...,...,...,...,...,...,...
123,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 1.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.67465
124,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 2.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.68030
125,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 5.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.70970
126,ClipVLM,"{""path"": ""openai/clip-vit-base-patch32"", ""num_...",WeightedTextFewShotClassifier,"{""metric"": ""COSINE"", ""text_weight"": 10.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.71055


### Repeated Test

In [12]:
from classifier.WeightedTextFewShotClassifier import WeightedTextFewShotClassifier

# Values swept over: task sizes and text weights.
N_WAY_LIST = [5, 10, 20]
N_SUPPORT_LIST = [0, 1, 2, 5, 10]
TEXT_WEIGHT_LIST = [0.1, 0.5, 1, 2, 5, 10, 20]

# Live display handle showing the most recent rows; refreshed after every test.
disp = display(display_id=True)
disp.update(test_results.tail(10))

# All (n_way, n_support, text_weight) combinations, except that zero-shot runs
# (n_support == 0) are kept only for text_weight == 1, so the same zero-shot
# test is not repeated for every text weight.
param_list = [
    params
    for params in itertools.product(N_WAY_LIST, N_SUPPORT_LIST, TEXT_WEIGHT_LIST)
    if params[1] != 0 or params[2] == 1
]

for n_way, n_support, text_weight in tqdm(param_list):
    # A fresh classifier per text weight, sharing the already loaded VLM.
    classifier = WeightedTextFewShotClassifier(vlm, metric=None, text_weight=text_weight)
    collect_few_shot_test_results(
        test_results,
        classifier, DATASET_SPLIT_PATH,
        n_way=n_way, n_support=n_support, n_query=N_QUERY, n_episodes=N_EPISODES,
    )
    # Write to disk after every configuration so an interrupted sweep keeps its progress.
    test_results.to_csv(TEST_RESULTS_PATH, index=False)
    disp.update(test_results.tail(10))

Unnamed: 0,vlm_class,vlm_params,classifier_class,classifier_params,dataset_split,n_way,n_support,n_query,n_episodes,accuracy
205,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 5.0}",/home/datasets/kinetics_100_split/test.txt,20,5,1,1000,0.54705
206,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 10.0}",/home/datasets/kinetics_100_split/test.txt,20,5,1,1000,0.5484
207,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 20.0}",/home/datasets/kinetics_100_split/test.txt,20,5,1,1000,0.5492
208,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 0.1}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.6505
209,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 0.5}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.5844
210,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 1.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.5745
211,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 2.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.5632
212,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 5.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.5489
213,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 10.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.5496
214,VTTWINS_SimilarityVLM,{},WeightedTextFewShotClassifier,"{""metric"": ""DOT"", ""text_weight"": 20.0}",/home/datasets/kinetics_100_split/test.txt,20,10,1,1000,0.54555


  0%|          | 0/87 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]