In [2]:
import os

# Working directory for the notebook: the relative paths used by later cells
# resolve against it, and it is also used as the Transformers cache location.
# Set the TRIMODAL_WORKDIR environment variable to relocate everything without
# editing the notebook; the default keeps the original scratch location.
WORK_DIR = os.environ.get('TRIMODAL_WORKDIR', '/scratch/sagarsj42')

os.chdir(WORK_DIR)
# Set ahead of the model imports in the next cell so the cache location is
# already in the environment when those libraries load.
os.environ['TRANSFORMERS_CACHE'] = WORK_DIR

In [3]:
import sys
import string
import random

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import numpy as np
from bpemb import BPEmb

from tri_model import TriModel
from trimodal_dataset import CosineSimDataset, collate_trimodal_cosine

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
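# One-time download of the CLAP audio checkpoint that AUDIO_MODEL_PATH points to; uncomment to fetch it.
# !wget https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt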

In [5]:
# All relative paths here resolve against the working directory set by os.chdir
# in the first cell.

# Dataset metadata directory (passed to CosineSimDataset) and the info filenames.
# NOTE(review): the two *_FILENAME constants are not read by any cell visible in
# this section — confirm they are still needed.
DATASET_INFO_DIR = './yt8m-clips-dataset-info'
CLIP_INFO_FILENAME = 'clip-info.jsonl'
VID_INFO_FILENAME = 'video-info.jsonl'
# Directories handed to CosineSimDataset for the audio and video inputs.
AUDIO_FEATURES_DIR = './yt8m-audio-features'
VIDEO_FEATURES_DIR = './yt8m-video-features'
# Encoder identifiers and the audio checkpoint path passed to TriModel.
AUDIO_MODEL_KEY = 'HTSAT-base'
AUDIO_MODEL_PATH = './music_audioset_epoch_15_esc_90.14.pt'
VIDEO_MODEL_KEY = 'MCG-NJU/videomae-base'
# Embedding width: used as the BPEmb `dim` and as TriModel's first argument.
EMB_SIZE = 300
# BPEmb vocabulary size; also passed to TriModel.
BPE_VOCAB_SIZE = 10000
# Batch sizes for the training loader and for the dev/test loaders.
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 4

In [6]:
# NOTE(review): `split` is not read by any cell visible in this section (the
# dataset cell passes its split names as literals) — confirm it is still needed.
split = 'dev'

In [7]:
# Load the English BPEmb subword model, sized by BPE_VOCAB_SIZE and EMB_SIZE.
text_bpe_model = BPEmb(lang='en', vs=BPE_VOCAB_SIZE, dim=EMB_SIZE)

# Trailing expression: display the model's repr as the cell output.
text_bpe_model

BPEmb(lang=en, vs=10000, dim=300)

In [8]:
# Build the three splits with identical settings; only the split name varies.
train_ds, dev_ds, test_ds = (
    CosineSimDataset(split_name, DATASET_INFO_DIR, text_bpe_model, AUDIO_FEATURES_DIR, VIDEO_FEATURES_DIR)
    for split_name in ('train', 'dev', 'test')
)

# Number of examples per split.
len(train_ds), len(dev_ds), len(test_ds)

(58847, 15961, 14806)

In [9]:
%%timeit
# Time a single indexed fetch from the training dataset.
train_ds[100]

1.81 ms ± 6.16 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# Fetch one training example and list the fields it exposes.
sample = train_ds[100]
sample.keys()

dict_keys(['text_inp', 'audio_inp', 'video_inp'])

In [11]:
# Text input of the sample: a short integer array (presumably BPE token ids —
# confirm in CosineSimDataset).
sample['text_inp']

array([ 216, 8401, 9918])

In [12]:
# Container type and shape of the sample's audio input.
type(sample['audio_inp']), sample['audio_inp'].shape

(numpy.ndarray, (384001,))

In [13]:
# Container type and shape of the sample's video input.
type(sample['video_inp']), sample['video_inp'].shape

(numpy.ndarray, (16, 3, 224, 224))

In [14]:
# Scratch check: concatenating with an empty array keeps the values; the result
# comes back as a float array because np.zeros defaults to float.
np.concatenate((np.array([1, 2]), np.zeros(0)))

array([1., 2.])

In [15]:
# Scratch check: encoding an empty string yields an empty id list.
text_bpe_model.encode_ids('')

[]

In [16]:
# Scratch check: decode vocabulary ids 0-10 to see what the lowest ids map to
# (the first three come back as empty strings).
text_bpe_model.decode_ids([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

['', '', '', 't', 'a', 'he', 'in', 'the', 'er', 'on', 's']

In [17]:
# One loader per split, all sharing the tri-modal collate function. Only the
# training loader shuffles; dev and test keep dataset order.
train_dl = DataLoader(
    train_ds,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_trimodal_cosine,
)
dev_dl = DataLoader(
    dev_ds,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_trimodal_cosine,
)
test_dl = DataLoader(
    test_ds,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_trimodal_cosine,
)

# Number of batches per split.
len(train_dl), len(dev_dl), len(test_dl)

(29424, 3991, 3702)

In [18]:
# Pull a single batch from the shuffled training loader and list its fields.
sample_batch = next(iter(train_dl))

sample_batch.keys()

dict_keys(['text_batch', 'audio_batch', 'video_batch'])

In [19]:
# Text part of the batch: an integer tensor with one row per example.
sample_batch['text_batch']

tensor([[2023, 7196],
        [  27, 5684]])

In [20]:
# Shape of the batched audio input.
# NOTE(review): the single sample inspected earlier had 384001 audio values,
# while the batch has 384000 per item — presumably the collate function
# trims/pads to a fixed length; confirm in collate_trimodal_cosine.
sample_batch['audio_batch'].shape

torch.Size([2, 384000])

In [21]:
# Shape of the batched video input (batch dimension prepended to the per-sample shape).
sample_batch['video_batch'].shape

torch.Size([2, 16, 3, 224, 224])

In [22]:
# Build the tri-modal model from the config constants and keep it on the CPU;
# the trailing expression displays the module tree.
trimodel = TriModel(EMB_SIZE, BPE_VOCAB_SIZE, AUDIO_MODEL_KEY, AUDIO_MODEL_PATH, VIDEO_MODEL_KEY).to('cpu')
trimodel

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


TriModel(
  (text_model): TextModel(
    (model): Sequential(
      (0): Embedding(10000, 300)
      (1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.2, inplace=False)
      (4): Linear(in_features=300, out_features=300, bias=True)
    )
  )
  (audio_model): AudioModel(
    (encoder): CLAP_Module(
      (model): CLAP(
        (audio_branch): HTSAT_Swin_Transformer(
          (spectrogram_extractor): Spectrogram(
            (stft): STFT(
              (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
              (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
            )
          )
          (logmel_extractor): LogmelFilterBank()
          (spec_augmenter): SpecAugmentation(
            (time_dropper): DropStripes()
            (freq_dropper): DropStripes()
          )
          (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [23]:
# Forward the sample batch through the model without tracking gradients.
# NOTE(review): the batch is unpacked positionally, so this relies on the dict's
# key order (text, audio, video) matching TriModel.forward's parameter order — confirm.
# NOTE(review): no trimodel.eval() call is visible before this point; if the
# model is still in training mode, the Dropout layer shown in the module
# printout is active and the embeddings are not deterministic — confirm.
with torch.no_grad():
    embeds = trimodel(*sample_batch.values())

embeds.keys()

dict_keys(['text_emb', 'audio_emb', 'video_emb'])

In [24]:
# Shapes of the three embeddings returned by the model.
embeds['text_emb'].shape, embeds['audio_emb'].shape, embeds['video_emb'].shape

(torch.Size([2, 300]), torch.Size([2, 300]), torch.Size([2, 300]))

In [25]:
def get_n_params(model):
    """Return the total number of elements across all of `model`'s parameters.

    Every tensor yielded by `model.parameters()` is counted; a model without
    parameters gives 0.
    """
    return sum(param.numel() for param in model.parameters())

# Report parameter counts for each sub-model and for the full model.
print('# params')
print('text model:', get_n_params(trimodel.text_model))
print('audio model:', get_n_params(trimodel.audio_model))
print('video model:', get_n_params(trimodel.video_model))
print('tri model:', get_n_params(trimodel))

# params
text model: 3090900
audio model: 73892229
video model: 86459436
tri model: 163442565


In [26]:
# Scratch: a random 2x3 tensor from torch.rand, displayed for inspection.
t = torch.rand(2,3)
t

tensor([[0.5122, 0.2314, 0.9860],
        [0.6360, 0.8131, 0.1592]])

In [27]:
# Scratch: torch.arange(n) yields 0..n-1, the form contrastive_loss uses for
# its in-batch labels.
torch.arange(4)

tensor([0, 1, 2, 3])

In [28]:
# Scratch: cross-entropy on logits that strongly favour the correct class on
# each row gives a loss close to zero.
nn.CrossEntropyLoss()(
    torch.tensor([[5.0, 0.0],
                  [0.0, 5.0]]),
    torch.tensor([0, 1])
)

tensor(0.0067)

In [29]:
# Scratch: NLLLoss does not apply a softmax; it averages the negated entries
# picked by the targets, so all-zero inputs give a loss of 0.
nn.NLLLoss()(
    torch.tensor([[0.0, 0.0],
                  [0.0, 0.0]]),
    torch.tensor([0, 1])
)

tensor(0.)

In [36]:
def l2_normalize(t, p=2, dim=1, eps=1e-12):
    """Scale `t` so that every slice along `dim` has unit p-norm.

    Args:
        t: Tensor to normalize.
        p: Order of the norm (default 2, the Euclidean norm).
        dim: Dimension along which the norm is computed.
        eps: Lower bound applied to the norm before dividing, so an all-zero
            slice comes back as zeros instead of NaN.

    Returns:
        A tensor with the same shape as `t`.
    """
    # keepdim=True leaves the reduced axis in place with size 1, so the division
    # broadcasts correctly for any `dim` (an unsqueeze(1) is only right for dim=1).
    norms = t.norm(p=p, dim=dim, keepdim=True).clamp_min(eps)

    return t / norms


def contrastive_loss(emb1, emb2, temperature=1.0, normalize=False):
    """Symmetric in-batch contrastive loss between two sets of embeddings.

    Row i of `emb1` and row i of `emb2` form the positive pair; every other
    row of the batch acts as a negative. Cross-entropy is applied to the
    similarity matrix in both directions and the two terms are summed.

    Args:
        emb1: 2-D tensor of shape (batch, dim).
        emb2: 2-D tensor of shape (batch, dim), aligned row-wise with `emb1`.
        temperature: Positive scalar dividing the similarity logits. The
            default of 1.0 leaves the raw dot products unchanged.
        normalize: When True, rows are L2-normalized first so the logits are
            cosine similarities. The default keeps raw dot products.

    Returns:
        Scalar tensor: loss(emb1 -> emb2) + loss(emb2 -> emb1).

    Raises:
        ValueError: If `temperature` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f'temperature must be positive, got {temperature}')

    if normalize:
        emb1 = nn.functional.normalize(emb1, p=2, dim=1)
        emb2 = nn.functional.normalize(emb2, p=2, dim=1)

    logits = emb1.matmul(emb2.T) / temperature
    # Create the targets directly on the embeddings' device instead of building
    # them on the CPU and copying them over.
    targets = torch.arange(emb1.shape[0], device=emb1.device)

    cross_entropy = nn.functional.cross_entropy
    loss = cross_entropy(logits, targets) + cross_entropy(logits.T, targets)

    return loss


def trimodal_contrastive_objective(embeds):
    """Sum the pairwise contrastive losses over the three modalities.

    `embeds` must provide 'text_emb', 'audio_emb' and 'video_emb'. The loss is
    evaluated for the text-audio, text-video and audio-video pairs, in that
    order, and the three terms are added together.
    """
    text, audio, video = (embeds[key] for key in ('text_emb', 'audio_emb', 'video_emb'))

    pairs = ((text, audio), (text, video), (audio, video))
    text_audio, text_video, audio_video = (
        contrastive_loss(first, second) for first, second in pairs
    )

    return text_audio + text_video + audio_video

In [37]:
# Evaluate the tri-modal loss on the embeddings of the sample batch.
trimodal_contrastive_objective(embeds)

tensor(13.2661)