In [1]:
import numpy as np
import os, sys
import importlib
from tqdm.notebook import tqdm, trange
import pandas as pd
import json
import itertools
from IPython.display import display

### VLM and Parameter Setup

In [2]:
# The active conda environment decides which VLM wrapper is importable.
# Each branch binds `VLM` (the model class) and `vlm_params`, a mapping of
# constructor keyword -> list of values to sweep over later.
ENV = os.environ.get("CONDA_DEFAULT_ENV")

if ENV == "videoclip":
    from video_clip.video_clip import VideoClipVLM as VLM
    vlm_params = {
        "num_seconds": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "sample_strat": ["center", "start", "spread"]
    }

elif ENV == "VLM_MILES":
    from MILES.wrapper import MILES_SimilarityVLM as VLM
    vlm_params = {}

elif ENV == "VLM_CLIP":
    from CLIP.CLIPVLM import ClipVLM as VLM
    vlm_params = {
        "num_frames": [1, 2, 3, 4, 6, 8, 10, 20, 50, 100]
    }

else:
    # Fail fast: otherwise `VLM` / `vlm_params` stay undefined and later
    # cells die with a confusing NameError.
    raise ValueError(f"Unsupported conda environment: {ENV!r}")


### Data Setup

In [3]:
# Root of the 500p 10-second clip dataset. Override with the
# RETRIEVAL_DATA_FOLDER environment variable to run on another machine.
DATA_FOLDER = os.environ.get("RETRIEVAL_DATA_FOLDER", "/home/datasets/500p/10s_dataset")
MP4_FOLDER = f"{DATA_FOLDER}/10s_clips"
TXT_FOLDER = f"{DATA_FOLDER}/10s_kaldi_texts"

# Base names (extension stripped), sorted so index i refers to the same clip
# in both the video and the transcript folder.
video_names = sorted(
    os.path.splitext(name)[0] for name in os.listdir(MP4_FOLDER) if name.endswith(".mp4")
)
text_names = sorted(
    os.path.splitext(name)[0] for name in os.listdir(TXT_FOLDER) if name.endswith(".txt")
)

# Every clip needs exactly one transcript (and vice versa); report which
# names are unpaired instead of a bare AssertionError.
assert video_names == text_names, (
    f"video/text mismatch: video-only {sorted(set(video_names) - set(text_names))}, "
    f"text-only {sorted(set(text_names) - set(video_names))}"
)
pair_names = video_names

In [21]:
DATASET_ID = "500p"


def _read_transcripts(names):
    """Return the lowercased, whitespace-stripped transcript of each clip name, in order.

    Transcripts are read from ``{TXT_FOLDER}/{name}.txt``.
    """
    transcripts = []
    for name in names:
        # Explicit encoding: the platform default varies between machines.
        with open(f"{TXT_FOLDER}/{name}.txt", "r", encoding="utf-8") as fp:
            transcripts.append(fp.read().lower().strip())
    return transcripts


def _overflow_window(texts, before=1, after=1):
    """Join each text with up to ``before`` preceding and ``after`` following neighbours.

    Windows are clamped at both ends of the list (slicing clamps the upper bound).
    """
    return [" ".join(texts[max(0, i - before) : i + after + 1]) for i in range(len(texts))]


if DATASET_ID == "500p":
    # One transcript per clip.
    vid_text = np.array(_read_transcripts(pair_names))

elif DATASET_ID == "500p.text_overflow_1_1":
    # Each clip's text also includes the previous and next clip's transcript.
    raw_vid_text = _read_transcripts(pair_names)
    vid_text = np.array(_overflow_window(raw_vid_text, before=1, after=1))

else:
    raise ValueError(DATASET_ID)

N = len(pair_names)
vid_paths = np.array([f"{MP4_FOLDER}/{name}.mp4" for name in pair_names])

print(DATASET_ID)
print(f"{N} Video-Text Pairs")
for i, (path, text) in enumerate(zip(vid_paths, vid_text)):
    print(f"{i:>3}: {os.path.basename(path):>20}")
    print(text)
    print()

500p
82 Video-Text Pairs
  0:      Video 1_000.mp4
the nurse is on the phone patient is trying to get nurse's attention by touching her arm the nurse in the hall

  1:      Video 1_001.mp4
is counting up putting on blue gown the

  2:      Video 1_002.mp4
the patient appears slightly more restless as his leg moves the nurse put

  3:      Video 1_003.mp4
it's down the patient's arm and holds his hand for a moment as she walks to check on her coworker in the hall who's counting up then returns to the

  4:      Video 1_004.mp4
patient's bedside resumes looking at her phone and is joined by her clinical colleague who can

  5:      Video 1_005.mp4
in and checks the foley urinary catheter line draining any urine in the line into the receiving bag

  6:      Video 1_006.mp4
where she checks the mill leaders that has been collected hello

  7:      Video 1_007.mp4
staff clinician observes and holds the restrain or mitten protocol the points

  8:      Video 1_008.mp4
to the patient passes p

### Retrieval Tester Function

In [5]:
# Resume from previously saved retrieval results when available; otherwise
# start an empty table with the schema that retrieval_test() fills in.
results = (
    pd.read_csv("retrieval_results.csv")
    if os.path.exists("retrieval_results.csv")
    else pd.DataFrame(
        columns=[
            "vlm_class", "vlm_params", "dataset",
            *[f"R@{k} (of 82)" for k in (1, 5, 10, 20, 50)],
            "Mean R", "Med R",
            "Ranked Text Indices", "Vid-Text Choice Ranks",
            "Correct Vid-Text Pair Ranks", "Vid-Text Pairs (Easy->Hard)",
        ]
    )
)

In [6]:
def retrieval_test(vlm):    
    pbar = trange(2 * N, leave=False)
    vid_embeds = []
    text_embeds = []
    for path in vid_paths:
        vid_embeds.append(vlm.get_video_embeds(path))
        pbar.update(1)
    for text in vid_text:
        text_embeds.append(vlm.get_text_embeds(text))
        pbar.update(1)
    vid_embeds = np.array(vid_embeds)
    text_embeds = np.array(text_embeds)
    
    similarity = vlm.default_similarity_metric()(vid_embeds, text_embeds)
    
    sorted_text_choice_indices = np.argsort(-similarity, axis=1) # i, j = text index which is the jth best match to video index i
    pair_ranks = np.argsort(sorted_text_choice_indices, axis=1) # i, j = rank position of text j for vid i out of all text options (0 = best choice, 81 = worst)
    correct_pair_ranks = pair_ranks[np.arange(N), np.arange(N)] # i = rank position of correct pair (vid i - text i) out of all text options (0 = best, 81 = worst)
    sorted_pair_indices = np.argsort(correct_pair_ranks) # vid-text pair index with best rank to vid-text pair index with worst rank
    
    
    # R@1
    R1_sum = np.sum(correct_pair_ranks < 1)
    R1 = R1_sum / N
    
    # R@5
    R5_sum = np.sum(correct_pair_ranks < 5)
    R5 = R5_sum / N
    
    # R@10
    R10_sum = np.sum(correct_pair_ranks < 10)
    R10 = R10_sum / N
    
    # R@20
    R20_sum = np.sum(correct_pair_ranks < 20)
    R20 = R20_sum / N
    
    # R@50
    R50_sum = np.sum(correct_pair_ranks < 50)
    R50 = R50_sum / N
    
    # Mean/Med Rank
    mean_R = np.mean(correct_pair_ranks) + 1
    med_R = np.median(correct_pair_ranks) + 1
    
    print(f"{vlm.__class__.__name__}   {json.dumps(vlm.params())}")
    print(f"{'R1':>10}{'R5':>10}{'R10':>10}{'R20':>10}{'R50':>10}{'MeanR':>10}{'MedR':>10}")
    print(f"{R1:10.3f}{R5:10.3f}{R10:10.3f}{R20:10.3f}{R50:10.3f}{mean_R:10.3f}{med_R:10.3f}")
    print()
    
    
    # Save results
    # Remove any previously-saved versions of this same test
    global results
    test_spec = {
        "vlm_class": vlm.__class__.__name__,
        "vlm_params": vlm.params(),
        "dataset": DATASET_ID
    }
    prev_matching_tests = (results[list(test_spec.keys())] == pd.Series(test_spec)).all(axis=1)
    if np.any(prev_matching_tests):
        results = results[~prev_matching_tests].reset_index(drop=True)
        
    results.loc[len(results)] = [
        test_spec["vlm_class"],
        test_spec["vlm_params"],
        test_spec["dataset"],
        R1_sum, R5_sum, R10_sum, R20_sum, R50_sum,
        mean_R, med_R,
        sorted_text_choice_indices, pair_ranks,
        correct_pair_ranks, sorted_pair_indices
    ]
    results.to_csv("retrieval_results.csv", index=False)

### Test over all VLM parameter combinations

In [7]:
# Dynamically display the most recent test results, refreshed after each run.
disp = display(display_id=True)
disp.update(results.tail(5))

vlm = None
# itertools.product() over zero iterables yields one empty tuple, so a VLM
# with no tunable params is still evaluated exactly once with no kwargs.
with tqdm(list(itertools.product(*vlm_params.values()))) as param_list:
    for param_values in param_list:
        params = dict(zip(vlm_params.keys(), param_values))
        param_list.set_postfix(params)

        # Drop the reference to the previous model before building the next,
        # so it can be garbage-collected instead of coexisting with the new one.
        vlm = None
        vlm = VLM(**params)

        retrieval_test(vlm)
        disp.update(results.tail(5))

Unnamed: 0,vlm_class,vlm_params,dataset,R@1 (of 82),R@5 (of 82),R@10 (of 82),R@20 (of 82),R@50 (of 82),Mean R,Med R,Ranked Text Indices,Vid-Text Choice Ranks,Correct Vid-Text Pair Ranks,Vid-Text Pairs (Easy->Hard)
90,ClipVLM,"{'path': 'openai/clip-vit-base-patch32', 'num_...",500p.text_overflow_1_1,1,6,9,23,46,41.463415,41.5,[[68 70 62 ... 50 80 32]\n [68 70 62 ... 50 80...,[[28 16 54 ... 75 80 72]\n [32 17 60 ... 70 80...,[28 17 54 15 39 15 15 3 11 14 24 44 68 76 74 ...,[68 69 58 62 7 70 77 72 63 8 56 74 61 67 9 ...
91,ClipVLM,"{'path': 'openai/clip-vit-base-patch32', 'num_...",500p.text_overflow_1_1,1,6,9,22,46,41.378049,40.0,[[68 70 62 ... 50 80 32]\n [68 70 69 ... 50 80...,[[28 16 55 ... 76 80 72]\n [30 18 58 ... 70 80...,[28 18 52 15 39 14 13 3 11 13 26 44 68 76 73 ...,[68 69 58 62 7 70 63 72 77 56 61 8 9 67 6 ...
92,ClipVLM,"{'path': 'openai/clip-vit-base-patch32', 'num_...",500p.text_overflow_1_1,1,6,9,22,48,41.353659,41.5,[[68 70 62 ... 50 80 32]\n [68 70 69 ... 50 80...,[[28 16 54 ... 77 80 72]\n [30 17 58 ... 70 80...,[28 17 52 16 37 14 14 3 11 13 26 44 68 76 73 ...,[68 69 58 7 62 70 63 72 77 56 8 61 9 67 5 ...
93,ClipVLM,"{'path': 'openai/clip-vit-base-patch32', 'num_...",500p.text_overflow_1_1,1,6,9,22,48,41.426829,41.5,[[68 70 62 ... 50 80 32]\n [68 70 69 ... 50 80...,[[28 16 55 ... 77 80 72]\n [29 18 58 ... 70 80...,[28 18 52 15 37 14 14 3 11 13 26 44 67 76 73 ...,[68 69 58 7 70 62 63 72 77 56 8 61 9 67 5 ...
94,MILES_SimilarityVLM,{},500p.text_overflow_1_1,1,6,11,23,51,40.512195,42.5,"[[62, 68, 63, 69, 67, 61, 18, 3, 75, 23, 2, 19...","[[33, 22, 10, 7, 18, 24, 61, 35, 69, 58, 46, 3...","[33, 28, 12, 7, 18, 19, 60, 32, 67, 57, 47, 26...","[63, 68, 62, 61, 18, 69, 75, 23, 74, 3, 28, 22..."


  0%|          | 0/1 [00:00<?, ?it/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


######USING ATTENTION STYLE:  frozen-in-time


  0%|          | 0/164 [00:00<?, ?it/s]

MILES_SimilarityVLM   {}
        R1        R5       R10       R20       R50     MeanR      MedR
     0.012     0.073     0.134     0.280     0.622    40.512    42.500



In [20]:
# Average the correct-text rank of every video-text pair over a few chosen
# runs, to see which pairs are consistently easy or hard to retrieve.
# NOTE(review): these are index labels of `results`; they shift whenever
# retrieval_test() removes a duplicate row — confirm they still select the intended runs.
ANALYSIS_ROWS = [3, 30, 40]


def _parse_ranks(value):
    """Return a 1-D int array from a stored "Correct Vid-Text Pair Ranks" value.

    Accepts numpy-style text ("[28 17\\n 54]"), list-style text ("[28, 17, 54]")
    and in-memory arrays/lists (rows added in this session).
    """
    if isinstance(value, str):
        tokens = value.replace("[", " ").replace("]", " ").replace(",", " ").split()
        return np.array([int(token) for token in tokens], dtype=int)
    return np.asarray(value, dtype=int).ravel()


test = np.array([_parse_ranks(x) for x in results.loc[ANALYSIS_ROWS, "Correct Vid-Text Pair Ranks"]])
test = np.mean(test, axis=0)
print(test)
print(np.argsort(test))

[15.66666667 46.66666667  8.66666667 15.33333333  7.66666667 32.
 63.         37.33333333 20.33333333 31.         25.33333333 19.66666667
 70.         59.66666667 24.33333333 23.         20.66666667 62.66666667
 30.33333333 20.33333333 38.33333333 67.         24.66666667 21.
 15.         37.66666667 63.         17.         35.66666667 31.66666667
 37.66666667 58.66666667 63.66666667 65.66666667 23.33333333 55.
 61.66666667 62.66666667 14.33333333 15.33333333 59.         14.66666667
 48.66666667 64.33333333 69.66666667 42.33333333 80.         56.
 41.66666667 60.33333333 62.66666667 56.         30.33333333 30.
 63.66666667 59.33333333 59.66666667  3.66666667 31.66666667 24.33333333
  8.66666667 44.33333333  6.         27.         53.66666667 17.33333333
 51.         41.33333333  1.66666667 23.         25.66666667 64.33333333
 47.         25.         52.         24.66666667 21.         33.33333333
 40.         62.66666667 57.33333333 39.        ]
[68 57 62  4  2 60 38 41 24  3 39  0 27 6