In [1]:
import numpy as np
import os, sys
import importlib
from tqdm.notebook import tqdm, trange
import pandas as pd
import json
import itertools
from IPython.display import display

# Import test handler
from FewShotTestHandler import FewShotTestHandler

# Dataset Handler
from dataset import DatasetHandler

# Most common classifier
from classifier import WeightedTextFewShotClassifier

### Load VLM

Note: This notebook must be run in the conda environment that corresponds to the selected VLM.

In [2]:
# Select the video-language model (VLM) to evaluate with a single switch.
# Each wrapper is imported only inside its own branch, so only the selected
# VLM's package has to be importable in the active environment.
VLM_NAME = "MILES"        # one of: "VTTWINS", "CLIP", "UNIVL", "MILES"
VLM_RESET_CACHE = False   # forwarded as reset_cache= to the chosen wrapper

if VLM_NAME == "VTTWINS":
    from VTTWINS.wrapper import VTTWINS_SimilarityVLM
    vlm = VTTWINS_SimilarityVLM(reset_cache=VLM_RESET_CACHE)
elif VLM_NAME == "CLIP":
    from CLIP.CLIPVLM import ClipVLM
    vlm = ClipVLM(reset_cache=VLM_RESET_CACHE)
elif VLM_NAME == "UNIVL":
    from UNIVL.wrapper import UniVL_SimilarityVLM
    vlm = UniVL_SimilarityVLM(reset_cache=VLM_RESET_CACHE)
elif VLM_NAME == "MILES":
    from MILES.wrapper import MILES_SimilarityVLM
    vlm = MILES_SimilarityVLM(reset_cache=VLM_RESET_CACHE)
else:
    # Fail loudly rather than leaving `vlm` undefined or stale from a previous run.
    raise ValueError(f"Unknown VLM_NAME: {VLM_NAME!r}")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


######USING ATTENTION STYLE:  frozen-in-time


### Test Handler

`FewShotTestHandler` runs few-shot tests and permanently saves their results.

In [3]:
# Runs few-shot tests and permanently saves their results; the accumulated
# results are read back later through `test_handler.results` (a table with a `.tail()` method).
test_handler = FewShotTestHandler()

### Run Repeated Tests

In [4]:
# Dataset to evaluate on; e.g. "kinetics_100" can be substituted for DATASET_NAME.
DATASET_NAME = "moma_sact"
DATASET_SPLIT = "val"

DATASET = DatasetHandler(DATASET_NAME, split=DATASET_SPLIT)

# Summarize the loaded dataset so the run configuration is visible in the output.
print(f"Dataset:\t{DATASET.id()}")
print(f"Categories:\t{DATASET.category_count()}")
print(f"Videos:\t\t{DATASET.video_count()}")

momaapi.lookup._read_anns() took 1.10209059715271 sec
momaapi.statistics._read_statistics() took 0.0015096664428710938 sec
Dataset:	moma_sact.val
Categories:	91
Videos:		2657


In [5]:
# Swept parameters: the test loop runs over combinations of these values.
N_WAY_LIST = [5, 10]                               # n_way values to test
N_SUPPORT_LIST = [0, 1, 2, 5, 10]                  # n_support values; 0 means zero-shot
TEXT_WEIGHT_LIST = [0, 0.1, 0.5, 1, 2, 5, 10, 20]  # text_weight values for the classifier

# Held fixed for every test configuration.
N_QUERY = 1         # n_query passed to each few-shot test
N_EPISODES = 1_000  # episodes run per test configuration

In [6]:
# Fill the VLM cache for the chosen dataset before testing, so the few-shot
# episodes below can reuse cached VLM outputs (presumably category-text and
# video embeddings — confirm in FewShotTestHandler.fill_cache).
test_handler.fill_cache(vlm, DATASET)

  0%|          | 0/91 [00:00<?, ?it/s]

  0%|          | 0/2657 [00:00<?, ?it/s]

In [7]:
# Run one few-shot test per (n_way, n_support, text_weight) configuration,
# live-updating a view of the most recent rows of test_handler.results.
N_RECENT_RESULTS = 5

# Zero-shot configurations (n_support == 0) are run once per n_way with this
# text_weight; sweeping text_weight for them would only repeat the same test.
ZERO_SHOT_TEXT_WEIGHT = 1

# Dynamically display most recent test results
disp = display(display_id=True)
disp.update(test_handler.results.tail(N_RECENT_RESULTS))

# Build the grid explicitly (same order as itertools.product over the three lists).
# Zero-shot runs are always included, even if ZERO_SHOT_TEXT_WEIGHT is not in
# TEXT_WEIGHT_LIST — previously, filtering on text_weight == 1 would silently drop them.
param_list = []
for n_way, n_support in itertools.product(N_WAY_LIST, N_SUPPORT_LIST):
    if n_support == 0:
        param_list.append((n_way, n_support, ZERO_SHOT_TEXT_WEIGHT))
    else:
        param_list.extend((n_way, n_support, text_weight) for text_weight in TEXT_WEIGHT_LIST)

for n_way, n_support, text_weight in tqdm(param_list, desc="Test configurations"):
    classifier = WeightedTextFewShotClassifier(vlm, metric=None, text_weight=text_weight)
    
    test_handler.run_few_shot_test(classifier, DATASET,
                                   n_way=n_way, n_support=n_support, n_query=N_QUERY, n_episodes=N_EPISODES)
    
    disp.update(test_handler.results.tail(N_RECENT_RESULTS))

Unnamed: 0,vlm_class,vlm.num_frames,vlm.path,vlm.sample_strat,classifier_class,classifier.metric,classifier.text_weight,dataset,n_way,n_support,n_query,n_episodes,accuracy
791,MILES_SimilarityVLM,,,,WeightedTextFewShotClassifier,COSINE,1.0,moma_sact.val,10,10,1,1000,0.8359
792,MILES_SimilarityVLM,,,,WeightedTextFewShotClassifier,COSINE,2.0,moma_sact.val,10,10,1,1000,0.828
793,MILES_SimilarityVLM,,,,WeightedTextFewShotClassifier,COSINE,5.0,moma_sact.val,10,10,1,1000,0.8346
794,MILES_SimilarityVLM,,,,WeightedTextFewShotClassifier,COSINE,10.0,moma_sact.val,10,10,1,1000,0.8354
795,MILES_SimilarityVLM,,,,WeightedTextFewShotClassifier,COSINE,20.0,moma_sact.val,10,10,1,1000,0.8421


  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]