In [1]:
import os
os.chdir('/scratch/sagarsj42')
os.environ['TRANSFORMERS_CACHE'] = '/scratch/sagarsj42'

In [2]:
import faiss
import numpy as np
import pandas as pd

In [3]:
DATASET_INFO_DIR = './yt8m-clips-dataset-info'
EMBEDS_DIR = 'zeroshot-embeds'
EMB_SIZE = 300
RET_SIZE = 20000

In [4]:
split = 'test'
media = ['text', 'audio', 'video']

In [5]:
clip_df = pd.read_json(os.path.join(DATASET_INFO_DIR, split, 'clip-info.jsonl'), lines=True)

print(clip_df.info())

clip_df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14806 entries, 0 to 14805
Data columns (total 6 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   vid              14806 non-null  object 
 1   clip_no          14806 non-null  int64  
 2   audio_clip_name  14806 non-null  object 
 3   audio_clip_dur   14806 non-null  float64
 4   video_clip_name  14806 non-null  object 
 5   video_clip_dur   14806 non-null  float64
dtypes: float64(2), int64(1), object(3)
memory usage: 694.2+ KB
None


Unnamed: 0,vid,clip_no,audio_clip_name,audio_clip_dur,video_clip_name,video_clip_dur
0,ZKBM2XCWfo8,21,ZKBM2XCWfo8-audio-21.mp3,8.0,ZKBM2XCWfo8-video-21.mp4,8.01
1,ZKBM2XCWfo8,20,ZKBM2XCWfo8-audio-20.mp3,8.0,ZKBM2XCWfo8-video-20.mp4,8.01
2,ZKBM2XCWfo8,22,ZKBM2XCWfo8-audio-22.mp3,8.0,ZKBM2XCWfo8-video-22.mp4,8.01
3,ZKBM2XCWfo8,23,ZKBM2XCWfo8-audio-23.mp3,8.0,ZKBM2XCWfo8-video-23.mp4,8.01
4,ZKBM2XCWfo8,27,ZKBM2XCWfo8-audio-27.mp3,8.0,ZKBM2XCWfo8-video-27.mp4,8.01


In [6]:
m_embeds = dict()
m_indices = dict()
for m in media:
    embeds = list()
    for _, row in clip_df.iterrows():
        vid = row['vid']
        clip_no = row['clip_no']
        file_name = f'{vid}-{clip_no}-{m}-emb.npy'
        sample_embed = np.load(os.path.join(EMBEDS_DIR, split, m, file_name))
        embeds.append(sample_embed)
    m_embeds[m] = np.array(embeds)
    
    print(m, m_embeds[m].shape)
    
    index = faiss.IndexFlatIP(EMB_SIZE)
    index.add(m_embeds[m])
    m_indices[m] = index
    
    print('Index constructed')

m_indices

text (14806, 300)
Index constructed
audio (14806, 300)
Index constructed
video (14806, 300)
Index constructed


{'text': <faiss.swigfaiss_avx2.IndexFlatIP; proxy of <Swig Object of type 'faiss::IndexFlatIP *' at 0x7f9e902558a0> >,
 'audio': <faiss.swigfaiss_avx2.IndexFlatIP; proxy of <Swig Object of type 'faiss::IndexFlatIP *' at 0x7f9e8aef0f00> >,
 'video': <faiss.swigfaiss_avx2.IndexFlatIP; proxy of <Swig Object of type 'faiss::IndexFlatIP *' at 0x7f9e8b036cf0> >}

In [7]:
for i in range(len(media)):
    for j in range(i+1, len(media)):
        m_1 = media[i]
        m_2 = media[j]
        print(f'Retrieval: {m_1} to {m_2}')
        _, res = m_indices[media[j]].search(m_embeds[media[i]], RET_SIZE)
        
        n_rows = res.shape[0]
        r_1 = 0
        r_5 = 0
        r_10 = 0
        ranks = list()
        for k in range(n_rows):
            search = res[k, :]
            try:
                pos = np.where(search == k)[0][0] + 1
            except IndexError:
                pos = RET_SIZE + 1
            if pos <= 1:
                r_1 += 1
            elif pos <= 5:
                r_5 =+ 1
            elif pos <= 10:
                r_10 += 1
            ranks.append(pos)
        ranks = np.array(ranks)
        mean_r = ranks.mean()
        median_r = np.median(ranks)
        r_1 = r_1 / n_rows * 100.0
        r_5 = r_5 / n_rows * 100.0
        r_10 = r_10 / n_rows * 100.0
        print(f'Recall @ 1: {r_1}, @ 5: {r_5}, @ 10: {r_10}')
        print(f'Mean rank: {mean_r}, median rank: {median_r}')
        print()

Retrieval: text to audio
Recall @ 1: 0.013508037282182897, @ 5: 0.006754018641091449, @ 10: 0.03377009320545725
Mean rank: 7651.719910846954, median rank: 7747.5

Retrieval: text to video
Recall @ 1: 0.006754018641091449, @ 5: 0.006754018641091449, @ 10: 0.040524111846548694
Mean rank: 7359.062407132244, median rank: 7337.5

Retrieval: audio to video
Recall @ 1: 0.013508037282182897, @ 5: 0.006754018641091449, @ 10: 0.047278130487640145
Mean rank: 7229.739362420641, median rank: 7056.0

