In [1]:
import os
os.chdir('/scratch/sagarsj42')
os.environ['TRANSFORMERS_CACHE'] = '/scratch/sagarsj42'

In [2]:
import random

import torch
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm
from bpemb import BPEmb

from tri_model import TriModel
from trimodal_dataset import CosineSimDatasetWithMD, collate_trimodal_with_metadata

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
SEED = 15
EXP_NAME = 'weighted-contrastive'
DATASET_INFO_DIR = './yt8m-clips-dataset-info'
AUDIO_FEATURES_DIR = './yt8m-audio-features'
VIDEO_FEATURES_DIR = './yt8m-video-features'
EMB_SIZE = 300
BPE_VOCAB_SIZE = 10000
BATCH_SIZE = 8

EMBEDS_DIR = f'{EXP_NAME}-embeds'
# EMBEDS_DIR = 'zeroshot-embeds'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

split = 'test'

EMBEDS_DIR, DEVICE

('weighted-contrastive-embeds', device(type='cuda'))

In [4]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f9b15459c10>

In [5]:
text_bpe_model = BPEmb(lang='en', vs=BPE_VOCAB_SIZE, dim=EMB_SIZE)
ds = CosineSimDatasetWithMD(split, DATASET_INFO_DIR, text_bpe_model, 
            AUDIO_FEATURES_DIR, VIDEO_FEATURES_DIR)

len(ds)

14806

In [6]:
dl = DataLoader(ds, collate_fn=collate_trimodal_with_metadata, batch_size=BATCH_SIZE, shuffle=False)

len(dl)

1851

In [7]:
sample_batch = next(iter(dl))

sample_batch.keys()

dict_keys(['text_batch', 'audio_batch', 'video_batch', 'vids', 'clip_nos'])

In [8]:
sample_batch['text_batch'].shape, sample_batch['audio_batch'].shape, sample_batch['video_batch'].shape, \
sample_batch['vids'], sample_batch['clip_nos']

(torch.Size([8, 1]),
 torch.Size([8, 384000]),
 torch.Size([8, 16, 3, 224, 224]),
 ['ZKBM2XCWfo8',
  'ZKBM2XCWfo8',
  'ZKBM2XCWfo8',
  'ZKBM2XCWfo8',
  'ZKBM2XCWfo8',
  'ZKBM2XCWfo8',
  'ZKBM2XCWfo8',
  'ZKBM2XCWfo8'],
 [21, 20, 22, 23, 27, 26, 18, 24])

In [9]:
ckpt = torch.load(os.path.join(EXP_NAME, 'best.pth'))
model_args = ckpt['model_args']

model_args

{'emb_size': 300,
 'bpe_vocab_size': 10000,
 'audio_model_key': 'HTSAT-base',
 'audio_model_path': 'music_audioset_epoch_15_esc_90.14.pt',
 'video_model_key': 'MCG-NJU/videomae-base'}

In [10]:
model = TriModel(**model_args)
# model.load_state_dict(ckpt['model_state_dict'])
model.to(DEVICE)

model

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


TriModel(
  (text_model): TextModel(
    (model): Sequential(
      (0): Embedding(10000, 300)
      (1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.2, inplace=False)
      (4): Linear(in_features=300, out_features=300, bias=True)
    )
  )
  (audio_model): AudioModel(
    (encoder): CLAP_Module(
      (model): CLAP(
        (audio_branch): HTSAT_Swin_Transformer(
          (spectrogram_extractor): Spectrogram(
            (stft): STFT(
              (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
              (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
            )
          )
          (logmel_extractor): LogmelFilterBank()
          (spec_augmenter): SpecAugmentation(
            (time_dropper): DropStripes()
            (freq_dropper): DropStripes()
          )
          (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [11]:
inp_batch = {
    'text_batch': sample_batch['text_batch'].to(DEVICE),
    'audio_batch': sample_batch['audio_batch'].to(DEVICE),
    'video_batch': sample_batch['video_batch'].to(DEVICE)
}
with torch.no_grad():
    embeds = model(*inp_batch.values())

embeds

{'text_emb': tensor([[ 0.1906,  0.4494,  0.0010,  ...,  0.2890, -0.5092,  0.2342],
         [ 0.0437,  0.6291, -0.0514,  ...,  0.8207, -0.5110,  0.3962],
         [ 0.1707,  0.3814, -0.4413,  ...,  0.6442, -0.2961,  0.3346],
         ...,
         [ 0.2324,  0.5423,  0.0429,  ...,  0.1731, -0.4147,  0.3479],
         [ 0.4373,  0.5902, -0.1453,  ...,  0.5791, -0.4079,  0.2834],
         [ 0.3658,  0.6429, -0.0472,  ...,  0.5376, -0.5122,  0.6119]],
        device='cuda:0'),
 'audio_emb': tensor([[ 0.5166, -0.3874,  0.1859,  ...,  0.0628,  0.0161, -0.3880],
         [ 0.3991, -0.5823, -0.1141,  ...,  0.0145,  0.1358, -0.4432],
         [ 0.5374, -0.1998,  0.1334,  ..., -0.0272, -0.1605, -0.1940],
         ...,
         [ 0.5221, -0.6458,  0.1965,  ..., -0.2781, -0.2170, -0.3621],
         [ 0.5072, -0.0896, -0.0030,  ..., -0.1893, -0.1466,  0.1588],
         [ 0.7991, -0.8976, -0.1543,  ..., -0.4348, -0.1894, -0.3201]],
        device='cuda:0'),
 'video_emb': tensor([[ 0.0120,  0.5583, 

In [None]:
for split in ['test', 'dev', 'train']:
    ds = CosineSimDatasetWithMD(split, DATASET_INFO_DIR, text_bpe_model, 
                AUDIO_FEATURES_DIR, VIDEO_FEATURES_DIR)
    dl = DataLoader(ds, collate_fn=collate_trimodal_with_metadata, batch_size=BATCH_SIZE, shuffle=False)

    os.makedirs(os.path.join(EMBEDS_DIR, split, 'text'), exist_ok=True)
    os.makedirs(os.path.join(EMBEDS_DIR, split, 'audio'), exist_ok=True)
    os.makedirs(os.path.join(EMBEDS_DIR, split, 'video'), exist_ok=True)

    for batch in tqdm(dl):
        vids = batch['vids']
        clip_nos = batch['clip_nos']
        inp_batch = {
            'text_batch': batch['text_batch'].to(DEVICE),
            'audio_batch': batch['audio_batch'].to(DEVICE),
            'video_batch': batch['video_batch'].to(DEVICE)
        }

        with torch.no_grad():
            embeds = model(*inp_batch.values())

        for i in range(len(vids)):
            vid = vids[i]
            clip_no = clip_nos[i]
            te = embeds['text_emb'][i, :].cpu().numpy()
            ae = embeds['audio_emb'][i, :].cpu().numpy()
            ve = embeds['video_emb'][i, :].cpu().numpy()

            np.save(os.path.join(EMBEDS_DIR, split, 'text', f'{vid}-{clip_no}-text-emb.npy'), te)
            np.save(os.path.join(EMBEDS_DIR, split, 'audio', f'{vid}-{clip_no}-audio-emb.npy'), ae)
            np.save(os.path.join(EMBEDS_DIR, split, 'video', f'{vid}-{clip_no}-video-emb.npy'), ve)

100%|██████████████████████████████████████████████████| 1851/1851 [40:04<00:00,  1.30s/it]
100%|██████████████████████████████████████████████████| 1996/1996 [42:50<00:00,  1.29s/it]
 67%|████████████████████████████████▎               | 4953/7356 [1:34:34<48:28,  1.21s/it]