In [9]:
import numpy as np
import os, sys
import importlib
from tqdm.notebook import tqdm, trange

import torch
# Use the GPU when torch reports one as available; otherwise fall back to the CPU.
# NOTE(review): DEVICE is not referenced by any cell visible in this notebook —
# confirm whether it is still needed.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

from SimilarityVLM import SimilarityVLM
from dataset.dataset import FewShotTaskDataset, SequentialVideoDataset, SequentialCategoryNameDataset
from FewShotClassifier import FewShotClassifier

### Choose VLM to Test

Note: this notebook must be run in the conda environment that corresponds to the VLM chosen below.

In [10]:
# "VT-TWINS" contains a hyphen, so it is not a valid identifier for a plain
# `import` statement; importlib.import_module accepts the dotted name as a string.
VLM = importlib.import_module("VT-TWINS.wrapper").VTTWINS_SimilarityVLM
# NOTE(review): reset_cache=False presumably keeps any previously saved embedding
# cache instead of discarding it — confirm against the wrapper implementation.
vlm = VLM(reset_cache=False)

### Choose Dataset to Test

In [11]:
# Split file handed to SequentialVideoDataset (cache filling) and to
# few_shot_accuracy (evaluation) in the cells below.
# NOTE(review): hard-coded absolute path — adjust for the machine this runs on.
DATASET_SPLIT_PATH = "/home/datasets/kinetics_100_split/test.txt"

### Filling the Cache

In [12]:
# Iterate over every item yielded by the split and request its embedding once.
video_dataset = SequentialVideoDataset(DATASET_SPLIT_PATH)
for vid_path in tqdm(video_dataset):
    # The return value is discarded: the call is made only for its side effect
    # (presumably populating the VLM's embedding cache — confirm in the wrapper).
    vlm.get_video_embeds(vid_path)
    
# save_cache() runs only after the loop finishes, so it is skipped if the loop
# is interrupted or raises.
vlm.save_cache()

  0%|          | 0/2400 [00:00<?, ?it/s]

### Testing Function

In [4]:
def few_shot_accuracy(classifier: FewShotClassifier, dataset_split_path: str, n_way: int, n_support: int, n_query: int = 1, n_episodes: int = 1000) -> float:
    """Measure the fraction of query predictions that match their row label.

    A FewShotTaskDataset is built from the given split and episode parameters.
    Each item it yields is unpacked as (vid_paths, category_names). The first
    n_support columns of vid_paths are passed to the classifier as support
    examples and the remaining columns as queries.

    Args:
        classifier: Object whose predict(category_names, support_vid_paths,
            query_vid_paths) method produces the query predictions.
        dataset_split_path: Path forwarded to FewShotTaskDataset.
        n_way: Forwarded to FewShotTaskDataset; also the number of row labels
            (0 .. n_way - 1) the predictions are scored against.
        n_support: Number of leading columns of vid_paths used as support
            examples. When 0, None is passed as the support set.
        n_query: Forwarded to FewShotTaskDataset; each episode contributes
            n_way * n_query to the denominator.
        n_episodes: Forwarded to FewShotTaskDataset.

    Returns:
        Number of matching predictions divided by the number of counted
        queries. If no queries are counted, the final division is by zero.
    """
    episodes = FewShotTaskDataset(dataset_split_path, n_episodes, n_way, n_support, n_query)

    n_correct, n_seen = 0, 0
    for vid_paths, category_names in tqdm(episodes):
        # Leading n_support columns are support examples; the rest are queries.
        support_vid_paths = vid_paths[:, :n_support] if n_support > 0 else None
        query_vid_paths = vid_paths[:, n_support:]

        predicted = classifier.predict(category_names, support_vid_paths, query_vid_paths)

        # Row i is scored against label i via a column vector of labels.
        # Assumes predictions are shaped (n_way, n_query) — TODO confirm against
        # FewShotClassifier.predict.
        row_labels = np.arange(n_way)[:, None]
        n_correct += np.sum(predicted == row_labels)
        n_seen += n_way * n_query

    return n_correct / n_seen

### Run the Test

In [5]:
# Wrap the chosen VLM in the few-shot classifier evaluated below.
# NOTE(review): metric=None presumably selects the classifier's default
# behaviour — confirm in FewShotClassifier.
classifier = FewShotClassifier(vlm, metric=None)

In [6]:
# Small smoke-test configuration; keyword arguments make each setting explicit.
few_shot_accuracy(
    classifier,
    DATASET_SPLIT_PATH,
    n_way=5,
    n_support=5,
    n_query=2,
    n_episodes=20,
)

  0%|          | 0/20 [00:00<?, ?it/s]

  video = th.from_numpy(video)


KeyboardInterrupt: 