In [1]:
# `os` is needed immediately for the working-directory / environment setup below.
import os
# Make the scratch area the working directory; the relative './...' paths defined
# in the configuration cell are resolved against it.
# NOTE(review): absolute, user-specific path -- adjust before running on another machine.
os.chdir('/scratch/sagarsj42')
# Set the TRANSFORMERS_CACHE environment variable to the same scratch directory
# (presumably so Hugging Face model downloads are stored there -- confirm for the installed version).
os.environ['TRANSFORMERS_CACHE'] = '/scratch/sagarsj42'

In [2]:
import sys
import string
import random

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import transformers
from transformers import VideoMAEModel

import laion_clap
import pandas as pd
import numpy as np
from bpemb import BPEmb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# !wget https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt

In [4]:
# --- Notebook configuration ---------------------------------------------------
# Directory holding one sub-folder per split with the two JSON-lines metadata files below.
DATASET_INFO_DIR = './yt8m-clips-dataset-info'
CLIP_INFO_FILENAME = 'clip-info.jsonl'   # one record per clip
VID_INFO_FILENAME = 'video-info.jsonl'   # one record per video
# Roots of the pre-extracted .npy inputs loaded by the dataset class.
AUDIO_FEATURES_DIR = './yt8m-audio-features'
VIDEO_FEATURES_DIR = './yt8m-video-features'
# Audio encoder: value passed as `amodel` to laion_clap.CLAP_Module, plus the checkpoint file to load.
AUDIO_MODEL_KEY = 'HTSAT-base'
AUDIO_MODEL_PATH = './music_audioset_epoch_15_esc_90.14.pt'
# Video encoder: identifier passed to VideoMAEModel.from_pretrained.
VIDEO_MODEL_KEY = 'MCG-NJU/videomae-base'
# Size of the shared embedding space; also used as the BPEmb vector dimension.
EMB_SIZE = 300
# BPEmb vocabulary size (`vs` argument).
BPE_VOCAB_SIZE = 10000
# DataLoader batch sizes for the train split and for the dev/test splits.
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 4

In [5]:
split = 'dev'

In [6]:
# Per-video metadata for the selected split (one JSON object per line).
vid_df = pd.read_json(os.path.join(DATASET_INFO_DIR, split, VID_INFO_FILENAME), lines=True)

# DataFrame.info() prints its summary itself and returns None, so it must not be
# wrapped in print() -- that only appends a stray "None" to the cell output.
vid_df.info()

vid_df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 499 entries, 0 to 498
Data columns (total 9 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   vid          499 non-null    object 
 1   n_clips      499 non-null    int64  
 2   audio_dur    499 non-null    float64
 3   video_dur    499 non-null    float64
 4   split        499 non-null    object 
 5   labels       499 non-null    object 
 6   title        499 non-null    object 
 7   description  499 non-null    object 
 8   tags         499 non-null    object 
dtypes: float64(2), int64(1), object(6)
memory usage: 35.2+ KB
None


Unnamed: 0,vid,n_clips,audio_dur,video_dur,split,labels,title,description,tags
0,yieL_efMuE0,28,223.376,223.22,dev,"[Concert, Musical ensemble]",Goencho Avaz - CIELDA PEREIRA,Herald's Goan Voice - Cielda Pereira\nMore her...,"Joegoauk,goa,cielda,konkani,Lorna,Chris perry,..."
1,Yn_mhi1dDAA,47,371.473,371.29,dev,"[Musician, Guitar, String instrument, Acoustic...",Blues Guitar Lick in Minor Pentatonic Scale - ...,"Please watch: ""Beginner Acoustic guitar lesson...","Pentatonic Scale,Blues (Musical Genre),Lick,Gu..."
2,06N2Msd1qUU,26,205.636,205.51,dev,[Music video],AMV - K-Pop Culture,Music by TAK (https://youtu.be/pftsmKHvlvY)\n\...,"Culture (Website Category),amv,kpop,k-pop,K-po..."
3,SI-HfG-y4dU,43,346.105,346.07,dev,[Music video],DIAURA Lily-sub español,"Aqui les dejo un Bonus, Disfrutenlo! ;)\n\nTr...","Diaura (Musical Group),Visual Kei (Musical Genre)"
4,oMaFCb12BMA,21,169.924,169.76,dev,[Music video],Youtube Rewind INDONESIA 2014,Youtube Rewind Indonesia 2015 https://www.yout...,"youtube rewind 2014,youtube rewind indonesia 2..."


In [7]:
# Per-clip metadata for the selected split (one JSON object per line).
clip_df = pd.read_json(os.path.join(DATASET_INFO_DIR, split, CLIP_INFO_FILENAME), lines=True)

# DataFrame.info() prints its summary itself and returns None, so it must not be
# wrapped in print() -- that only appends a stray "None" to the cell output.
clip_df.info()

clip_df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15961 entries, 0 to 15960
Data columns (total 6 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   vid              15961 non-null  object 
 1   clip_no          15961 non-null  int64  
 2   audio_clip_name  15961 non-null  object 
 3   audio_clip_dur   15961 non-null  float64
 4   video_clip_name  15961 non-null  object 
 5   video_clip_dur   15961 non-null  float64
dtypes: float64(2), int64(1), object(3)
memory usage: 748.3+ KB
None


Unnamed: 0,vid,clip_no,audio_clip_name,audio_clip_dur,video_clip_name,video_clip_dur
0,yieL_efMuE0,18,yieL_efMuE0-audio-18.mp3,8.0,yieL_efMuE0-video-18.mp4,8.01
1,yieL_efMuE0,24,yieL_efMuE0-audio-24.mp3,8.0,yieL_efMuE0-video-24.mp4,8.01
2,yieL_efMuE0,25,yieL_efMuE0-audio-25.mp3,8.0,yieL_efMuE0-video-25.mp4,8.01
3,yieL_efMuE0,19,yieL_efMuE0-audio-19.mp3,8.0,yieL_efMuE0-video-19.mp4,8.01
4,yieL_efMuE0,27,yieL_efMuE0-audio-27.mp3,8.0,yieL_efMuE0-video-27.mp4,8.01


In [8]:
text_bpe_model = BPEmb(lang='en', vs=BPE_VOCAB_SIZE, dim=EMB_SIZE)

text_bpe_model

BPEmb(lang=en, vs=10000, dim=300)

In [9]:
text_bpe_model.vectors.shape

(10000, 300)

In [10]:
class CosineSimDataset(Dataset):
    """Clip-level dataset yielding a (text, audio, video) triple for every clip.

    Two JSON-lines files are read from ``<clips_info_path>/<split>/``:
    ``video-info.jsonl`` (one record per video) and ``clip-info.jsonl`` (one
    record per clip). For each video the candidate texts are its labels and its
    comma-separated tags, lower-cased and stripped, keeping only entries longer
    than two characters that consist solely of ASCII lowercase letters (so
    multi-word or punctuated tags are dropped).

    Args:
        split: name of the split sub-directory holding the two metadata files.
        clips_info_path: root directory of the metadata files.
        text_bpe_model: object exposing ``encode_ids(str) -> list[int]``.
        audio_features_path: root directory of the audio ``.npy`` files.
        video_features_path: root directory of the video ``.npy`` files.

    Each item is a dict with keys ``'text_inp'`` (1-D array of token ids for one
    randomly chosen text, or ``[0]`` when the video has no usable text),
    ``'audio_inp'`` and ``'video_inp'`` (arrays loaded from
    ``<features_path>/<video split>/<vid>/<clip>.npy``).
    """

    def __init__(self, split, clips_info_path, text_bpe_model, audio_features_path, video_features_path):
        self.split = split
        self.clips_info_path = clips_info_path
        self.text_bpe_model = text_bpe_model
        self.audio_features_path = audio_features_path
        self.video_features_path = video_features_path
        
        info_dir = os.path.join(self.clips_info_path, self.split)
        self.vid_df = pd.read_json(os.path.join(info_dir, 'video-info.jsonl'), lines=True)
        self.clips_df = pd.read_json(os.path.join(info_dir, 'clip-info.jsonl'), lines=True)
        
        allowed_chars = set(string.ascii_lowercase)
        all_texts = list()
        # vid -> (split directory, candidate texts). Built once so __getitem__ does
        # not have to scan the whole video frame for every clip. setdefault keeps
        # the first record of a duplicated vid, matching a "first match" lookup.
        self._vid_lookup = dict()
        for _, row in self.vid_df.iterrows():
            candidates = [t.lower().strip() for t in row['labels'] + row['tags'].split(',')]
            texts = [t for t in candidates if len(t) > 2 and set(t) <= allowed_chars]
            all_texts.append(texts)
            self._vid_lookup.setdefault(row['vid'], (row['split'], texts))
        self.vid_df['texts'] = all_texts
        
        
    def __len__(self):
        """Number of clips in the split."""
        return self.clips_df.shape[0]
    
    
    def __getitem__(self, idx):
        """Return the text ids and the stored audio/video arrays of clip ``idx``.

        Raises:
            IndexError: ``idx`` is outside the clip table.
            KeyError: the clip refers to a video absent from ``video-info.jsonl``.
        """
        clips_row = self.clips_df.iloc[idx]
        vid = clips_row['vid']
        try:
            split_dir, texts = self._vid_lookup[vid]
        except KeyError:
            # Deliberately not an IndexError: that would silently terminate a plain
            # ``for sample in dataset`` loop instead of surfacing the bad record.
            raise KeyError(f'clip {idx} refers to video {vid!r}, which is missing from video-info.jsonl') from None
        
        if texts:
            text = random.choice(texts)
            text_ids = np.array(self.text_bpe_model.encode_ids(text))
        else:
            # No usable label/tag: emit a single id 0, which the text model masks out.
            text_ids = np.array([0])
        
        # '<vid>-audio-<n>.mp3' -> '<vid>-audfeat-<n>.npy' (same scheme for video).
        audio_clip_filename = clips_row['audio_clip_name']
        audio_feat_filename = audio_clip_filename[:-4].replace('-audio-', '-audfeat-') + '.npy'
        video_clip_filename = clips_row['video_clip_name']
        video_feat_filename = video_clip_filename[:-4].replace('-video-', '-vidfeat-') + '.npy'
        audio_feat = np.load(os.path.join(self.audio_features_path, split_dir, vid, audio_feat_filename))
        video_feat = np.load(os.path.join(self.video_features_path, split_dir, vid, video_feat_filename))
    
        return {'text_inp': text_ids, 'audio_inp': audio_feat, 'video_inp': video_feat}

In [11]:
# One dataset per split; all three share the same BPE text model and feature roots.
train_ds = CosineSimDataset('train', DATASET_INFO_DIR, text_bpe_model, AUDIO_FEATURES_DIR, VIDEO_FEATURES_DIR)
dev_ds = CosineSimDataset('dev', DATASET_INFO_DIR, text_bpe_model, AUDIO_FEATURES_DIR, VIDEO_FEATURES_DIR)
test_ds = CosineSimDataset('test', DATASET_INFO_DIR, text_bpe_model, AUDIO_FEATURES_DIR, VIDEO_FEATURES_DIR)

# Number of clips in each split.
len(train_ds), len(dev_ds), len(test_ds)

(58847, 15961, 14806)

In [12]:
%%timeit
train_ds[100]

1.88 ms ± 5.15 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
sample = train_ds[100]
sample.keys()

dict_keys(['text_inp', 'audio_inp', 'video_inp'])

In [14]:
sample['text_inp']

array([112, 126,   6])

In [15]:
type(sample['audio_inp']), sample['audio_inp'].shape

(numpy.ndarray, (384001,))

In [16]:
type(sample['video_inp']), sample['video_inp'].shape

(numpy.ndarray, (16, 3, 224, 224))

In [17]:
np.concatenate((np.array([1, 2]), np.zeros(0)))

array([1., 2.])

In [18]:
text_bpe_model.encode_ids('')

[]

In [19]:
text_bpe_model.decode_ids([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

['', '', '', 't', 'a', 'he', 'in', 'the', 'er', 'on', 's']

In [34]:
def collate_trimodal_cosine(batch, audio_feats_len=384000):
    """Collate dataset samples into padded text/audio tensors and a stacked video tensor.

    Args:
        batch: list of dicts with keys ``'text_inp'`` (1-D integer ids),
            ``'audio_inp'`` (1-D samples) and ``'video_inp'`` (arrays that all
            share one shape).
        audio_feats_len: fixed length of every audio row; longer inputs are
            truncated and shorter ones are zero-padded on the right. The default
            keeps the value that used to be hard-coded.

    Returns:
        dict with ``'text_batch'`` (long, ``[B, longest text]``, right-padded with
        id 0), ``'audio_batch'`` (float32, ``[B, audio_feats_len]``) and
        ``'video_batch'`` (float32, ``[B, *video shape]``).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    n_samples = len(batch)
    max_t = max(len(s['text_inp']) for s in batch)
    
    # Preallocate zero-filled outputs and copy each sample in; the zeros left over
    # are the padding, so no intermediate float64 array is built per sample.
    texts = torch.zeros((n_samples, max_t), dtype=torch.long)
    audios = torch.zeros((n_samples, audio_feats_len), dtype=torch.float32)
    videos = list()
    for row, sample in enumerate(batch):
        text_ids = torch.as_tensor(np.asarray(sample['text_inp']), dtype=torch.long)
        texts[row, :text_ids.shape[0]] = text_ids
        
        audio_feats = np.asarray(sample['audio_inp'])[:audio_feats_len]
        audio_feats = torch.as_tensor(audio_feats, dtype=torch.float32)
        audios[row, :audio_feats.shape[0]] = audio_feats
        
        videos.append(torch.as_tensor(sample['video_inp'], dtype=torch.float32))
    videos = torch.stack(videos, dim=0)
    
    return {'text_batch': texts, 'audio_batch': audios, 'video_batch': videos}

In [35]:
# Wrap each split in a DataLoader that batches through collate_trimodal_cosine.
# Only the training loader shuffles; dev/test keep the dataset order.
train_dl = DataLoader(train_ds, collate_fn=collate_trimodal_cosine, 
                      batch_size=TRAIN_BATCH_SIZE, shuffle=True)
dev_dl = DataLoader(dev_ds, collate_fn=collate_trimodal_cosine, 
                      batch_size=EVAL_BATCH_SIZE, shuffle=False)
test_dl = DataLoader(test_ds, collate_fn=collate_trimodal_cosine, 
                      batch_size=EVAL_BATCH_SIZE, shuffle=False)

# Number of batches per split. drop_last is not set, so the final batch of a
# split may hold fewer samples than the configured batch size.
len(train_dl), len(dev_dl), len(test_dl)

(29424, 3991, 3702)

In [36]:
sample_batch = next(iter(train_dl))

sample_batch.keys()

dict_keys(['text_batch', 'audio_batch', 'video_batch'])

In [37]:
sample_batch['text_batch']

tensor([[3585],
        [1357]])

In [38]:
sample_batch['audio_batch'].shape

torch.Size([2, 384000])

In [41]:
sample_batch['video_batch'].shape

torch.Size([2, 16, 3, 224, 224])

In [26]:
class TextModel(nn.Module):
    """Text tower: pretrained BPEmb vectors, a small projection head, masked mean pooling.

    Args:
        emb_size: BPEmb vector dimension, also the output embedding size.
        bpe_vocab_size: BPEmb vocabulary size.

    ``forward`` takes a ``[batch, tokens]`` tensor of token ids in which id 0 is
    treated as padding, and returns a ``[batch, emb_size]`` tensor holding the
    mean of the projected embeddings over the non-zero positions.
    """

    def __init__(self, emb_size, bpe_vocab_size):
        super().__init__()
        self.emb_size = emb_size
        self.bpe_vocab_size = bpe_vocab_size
        
        self.emb_model = BPEmb(lang='en', vs=self.bpe_vocab_size, dim=self.emb_size)
        pretrained_vectors = torch.tensor(self.emb_model.vectors)
        layers = [
            nn.Embedding.from_pretrained(pretrained_vectors),
            nn.LayerNorm(self.emb_size),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(self.emb_size, self.emb_size),
        ]
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        # 1 where a real token sits, 0 on padding (id 0).
        keep = (x != 0) * 1
        # Per-row token count, floored at 1 so all-padding rows divide by one.
        token_counts = keep.sum(dim=1).clamp(min=1).unsqueeze(-1)
        keep_per_feature = keep.unsqueeze(-1).expand(-1, -1, self.emb_size)
        
        projected = self.model(x) * keep_per_feature
        pooled = projected.sum(dim=1)
        
        return pooled / token_counts


text_model = TextModel(EMB_SIZE, BPE_VOCAB_SIZE)
text_model

TextModel(
  (model): Sequential(
    (0): Embedding(10000, 300)
    (1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=300, out_features=300, bias=True)
  )
)

In [27]:
text_emb = text_model(sample_batch['text_batch'])

text_emb.shape, text_emb

(torch.Size([2, 300]),
 tensor([[ 0.1430, -0.3933,  0.3621,  0.2887,  0.6107, -0.1417,  0.1118, -0.1488,
           0.2626,  0.7315,  0.2968,  0.1911,  0.6233,  0.1070, -0.6465,  0.3091,
          -0.6058,  0.0350, -0.0645, -0.0433, -0.0307,  0.3927,  0.5315,  0.2415,
          -0.5172, -0.0404, -0.2899,  0.4093,  0.2182,  0.4165,  0.4148, -0.0157,
           0.3405, -0.2870,  0.2293,  0.2375, -0.0207,  0.0636, -0.0106, -0.0826,
          -0.0438, -0.5452,  0.3489, -0.1245, -0.0946,  0.3676,  0.1745, -0.0524,
           0.3004, -0.3933, -0.3097, -0.0125,  0.2758, -0.2303,  0.5781,  0.1127,
          -0.2225,  0.4518,  0.5754, -0.6813, -0.5262, -0.0671,  0.3752, -0.5867,
          -0.1893, -0.6048, -0.7808,  0.0541,  0.2966, -0.2297, -0.1309, -0.0392,
           0.5144, -0.2665, -0.4916, -0.0498, -0.2526,  0.0591,  0.0211,  0.1176,
           0.0948, -0.5843,  0.0443,  0.3372, -0.2101, -0.0736, -0.1996,  0.2176,
          -0.4355,  0.3781,  0.2783, -0.1091,  0.1149,  0.1552, -0.4517, -0

In [28]:
class AudioModel(nn.Module):
    """Audio tower: a LAION-CLAP audio encoder followed by a projection head.

    Args:
        emb_size: size of the output embedding.
        audio_model_key: value passed as ``amodel`` to ``laion_clap.CLAP_Module``.
        audio_model_path: checkpoint passed to ``load_ckpt``; an empty string
            skips checkpoint loading.
        quantize_input: when true, ``forward`` first round-trips the waveform
            through the int16 value range (clamp to [-1, 1], scale, truncate,
            rescale).

    ``forward`` feeds the waveform batch to
    ``get_audio_embedding_from_data(..., use_tensor=True)`` and projects the
    result to ``emb_size``.
    """

    def __init__(self, emb_size, audio_model_key, audio_model_path='', quantize_input=True):
        super(AudioModel, self).__init__()
        self.emb_size = emb_size
        self.audio_model_key = audio_model_key
        self.audio_model_path = audio_model_path
        self.quantize = quantize_input
        
        # Silence transformers logging and stdout while the encoder is built.
        # Both are restored in ``finally`` so a failing load cannot leave the
        # notebook with a muted stdout, a leaked file handle, or a changed
        # verbosity; the verbosity goes back to whatever the caller had set.
        previous_verbosity = transformers.logging.get_verbosity()
        transformers.logging.set_verbosity_error()
        self._original_stdout = sys.stdout
        try:
            with open(os.devnull, 'w') as devnull:
                sys.stdout = devnull
                self.encoder = laion_clap.CLAP_Module(amodel=self.audio_model_key, enable_fusion=False)
                if self.audio_model_path:
                    self.encoder.load_ckpt(self.audio_model_path)
        finally:
            sys.stdout = self._original_stdout
            transformers.logging.set_verbosity(previous_verbosity)
        
        # Only the audio branch is used, so drop the text-side sub-modules.
        del self.encoder.model.text_branch
        del self.encoder.model.text_transform
        del self.encoder.model.text_projection
        
        # Width of the encoder output, read from the third layer of its audio projection.
        self.encoder_out_size = self.encoder.model.audio_projection[2].out_features
        self.projector = nn.Sequential(
            nn.LayerNorm(self.encoder_out_size),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(self.encoder_out_size, self.emb_size)
        )
    
    def forward(self, x):
        if self.quantize:
            x = self.int16_to_float32(self.float32_to_int16(x))
        x = self.encoder.get_audio_embedding_from_data(x, use_tensor=True)
        x = self.projector(x)
        
        return x
    
    def int16_to_float32(self, x):
        """Rescale int16-range values to float32 by dividing by 32767."""
        return (x / 32767.0).type(torch.float32)

    def float32_to_int16(self, x):
        """Clamp to [-1, 1], scale by 32767 and cast to int16 (truncating)."""
        x = torch.clamp(x, min=-1., max=1.)
        return (x * 32767.).type(torch.int16)


audio_model = AudioModel(EMB_SIZE, AUDIO_MODEL_KEY, AUDIO_MODEL_PATH).to('cpu')
audio_model

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


AudioModel(
  (encoder): CLAP_Module(
    (model): CLAP(
      (audio_branch): HTSAT_Swin_Transformer(
        (spectrogram_extractor): Spectrogram(
          (stft): STFT(
            (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
            (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
          )
        )
        (logmel_extractor): LogmelFilterBank()
        (spec_augmenter): SpecAugmentation(
          (time_dropper): DropStripes()
          (freq_dropper): DropStripes()
        )
        (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (patch_embed): PatchEmbed(
          (proj): Conv2d(1, 128, kernel_size=(4, 4), stride=(4, 4))
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0): BasicLayer(
            dim=128, input_resolution=(64, 64), depth=2
     

In [29]:
audio_emb = audio_model(sample_batch['audio_batch'])

audio_emb.shape, audio_emb

(torch.Size([2, 300]),
 tensor([[-1.8073e-01, -4.1579e-02, -3.7276e-01, -5.0331e-01, -8.0826e-02,
          -5.9152e-01, -3.6822e-01,  3.7489e-01,  5.2272e-01,  2.1830e-01,
          -7.8895e-01, -1.5372e-01,  2.6710e-01,  4.1248e-01,  1.2446e-01,
          -1.1751e-01,  4.4489e-01, -1.7110e-02,  4.8069e-01, -3.7553e-01,
          -4.8587e-01, -2.4678e-01, -2.0443e-02,  7.0252e-02,  1.3200e+00,
          -4.3580e-01, -2.4522e-01, -8.6814e-02, -3.1731e-01,  4.1141e-02,
           1.6837e-01,  2.9671e-01, -6.6236e-01, -6.5714e-02, -5.2141e-01,
          -9.9603e-02, -3.9455e-01,  1.9275e-02, -1.1607e-01, -3.8710e-01,
           4.9667e-01,  8.0394e-01,  5.3016e-01,  2.1021e-01, -1.6244e-01,
           3.9313e-01, -4.5103e-01,  1.8595e-01,  1.6604e-01, -1.2172e-01,
           7.9483e-01,  4.0929e-01,  7.3080e-02, -2.3662e-01, -5.9093e-01,
           2.5641e-01, -9.8692e-02,  1.0108e-02, -9.6935e-01, -2.6181e-01,
          -7.1860e-01, -4.3967e-01,  1.1687e+00,  6.2672e-01,  5.8343e-02,
  

In [30]:
class VideoModel(nn.Module):
    """Video tower: a pretrained VideoMAE encoder, mean pooling and a projection head.

    Args:
        emb_size: size of the output embedding.
        video_model_key: identifier passed to ``VideoMAEModel.from_pretrained``.

    ``forward`` passes its input as ``pixel_values`` to the encoder, averages
    ``last_hidden_state`` over dimension 1 and projects the result to ``emb_size``.
    """

    def __init__(self, emb_size, video_model_key):
        super(VideoModel, self).__init__()
        self.emb_size = emb_size
        self.video_model_key = video_model_key
        
        # Quieten transformers only while the weights load, then restore the
        # caller's verbosity (also when loading fails) instead of forcing WARNING.
        previous_verbosity = transformers.logging.get_verbosity()
        transformers.logging.set_verbosity_error()
        try:
            self.encoder = VideoMAEModel.from_pretrained(self.video_model_key)
        finally:
            transformers.logging.set_verbosity(previous_verbosity)
        
        self.encoder_out_size = self.encoder.config.hidden_size
        self.projector = nn.Sequential(
            nn.LayerNorm(self.encoder_out_size),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(self.encoder_out_size, self.emb_size)
        )
    
    def forward(self, x):
        x = self.encoder(pixel_values=x).last_hidden_state
        # Average over dimension 1 of the encoder output to get one vector per sample.
        x = x.mean(dim=1)
        x = self.projector(x)
        
        return x

video_model = VideoModel(EMB_SIZE, VIDEO_MODEL_KEY)
video_model

VideoModel(
  (encoder): VideoMAEModel(
    (embeddings): VideoMAEEmbeddings(
      (patch_embeddings): VideoMAEPatchEmbeddings(
        (projection): Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
    )
    (encoder): VideoMAEEncoder(
      (layer): ModuleList(
        (0): VideoMAELayer(
          (attention): VideoMAEAttention(
            (attention): VideoMAESelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): VideoMAESelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): VideoMAEIntermediate(
            (dense): Linear(in_features=768, out_feat

In [31]:
with torch.no_grad():
    video_emb = video_model(sample_batch['video_batch'])

video_emb.shape, video_emb

(torch.Size([2, 300]),
 tensor([[-0.5772, -0.3672, -0.9825,  0.0232, -0.2266,  0.0496, -1.2599,  0.6325,
           0.2771,  0.2600, -0.2905, -0.2857, -1.2756,  0.9483, -0.8149, -0.0745,
           0.6951,  0.6323, -0.2415, -0.4669, -1.0633,  0.2361, -0.4361,  1.0861,
          -0.3240, -0.3145, -0.2219,  0.7223, -0.2788,  1.2020, -0.1038,  1.2844,
          -0.6843,  0.7090, -0.2504, -0.5071, -1.0594, -0.6429, -0.9707, -0.5802,
           0.0932, -0.7462,  0.9548, -0.4646,  0.1008, -0.4022, -0.2463, -0.5239,
          -0.7948, -1.2660,  0.5816, -0.1840, -0.2004, -0.1838, -0.2320, -0.4598,
          -1.2962, -0.9520, -0.9570,  0.6878, -0.6147, -0.5386,  0.8518, -0.3208,
          -0.1727, -0.1498,  0.1091,  0.4571,  0.0142, -0.4953,  0.8885, -0.1491,
          -0.9241, -0.0977,  0.3356, -1.0355, -0.0832,  0.0390,  1.3734, -0.4376,
           0.9956,  0.5463,  0.9985, -1.2804,  0.8385,  0.3050, -0.0207,  0.1187,
           0.1019,  0.0126,  0.6233,  0.9568, -0.6971, -0.0622,  1.2899,  0

In [93]:
class TriModel(nn.Module):
    """Bundle of the text, audio and video towers sharing one embedding size.

    Args:
        emb_size: embedding size handed to all three towers.
        bpe_vocab_size: vocabulary size for the text tower.
        audio_model_key: encoder key for the audio tower.
        audio_model_path: checkpoint path for the audio tower.
        video_model_key: encoder key for the video tower.

    ``forward(t_x, a_x, v_x)`` runs each input through its tower and returns a
    dict with keys ``'text_emb'``, ``'audio_emb'`` and ``'video_emb'``.
    """

    def __init__(self, emb_size, bpe_vocab_size, audio_model_key, 
                 audio_model_path, video_model_key):
        super().__init__()
        self.emb_size, self.bpe_vocab_size = emb_size, bpe_vocab_size
        self.audio_model_key, self.audio_model_path = audio_model_key, audio_model_path
        self.video_model_key = video_model_key
        
        # Towers are registered in text, audio, video order.
        self.text_model = TextModel(self.emb_size, self.bpe_vocab_size)
        audio_tower = AudioModel(self.emb_size, self.audio_model_key, self.audio_model_path, 
                              quantize_input=True)
        self.audio_model = audio_tower.to('cpu')
        self.video_model = VideoModel(self.emb_size, self.video_model_key)
    
    def forward(self, t_x, a_x, v_x):
        return {
            'text_emb': self.text_model(t_x),
            'audio_emb': self.audio_model(a_x),
            'video_emb': self.video_model(v_x),
        }


trimodel = TriModel(EMB_SIZE, BPE_VOCAB_SIZE, AUDIO_MODEL_KEY, AUDIO_MODEL_PATH, VIDEO_MODEL_KEY).to('cpu')
trimodel

TriModel(
  (text_model): TextModel(
    (model): Sequential(
      (0): Embedding(10000, 300)
      (1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.2, inplace=False)
      (4): Linear(in_features=300, out_features=300, bias=True)
    )
  )
  (audio_model): AudioModel(
    (encoder): CLAP_Module(
      (model): CLAP(
        (audio_branch): HTSAT_Swin_Transformer(
          (spectrogram_extractor): Spectrogram(
            (stft): STFT(
              (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
              (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
            )
          )
          (logmel_extractor): LogmelFilterBank()
          (spec_augmenter): SpecAugmentation(
            (time_dropper): DropStripes()
            (freq_dropper): DropStripes()
          )
          (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [45]:
with torch.no_grad():
    embeds = trimodel(*sample_batch.values())

embeds.keys()

dict_keys(['text_emb', 'audio_emb', 'video_emb'])

In [46]:
embeds['text_emb'].shape, embeds['audio_emb'].shape, embeds['video_emb'].shape

(torch.Size([2, 300]), torch.Size([2, 300]), torch.Size([2, 300]))

In [47]:
def get_n_params(model):
    """Return the total number of elements across all of ``model.parameters()``."""
    return sum(param.numel() for param in model.parameters())

print('# params')
print('text model:', get_n_params(text_model))
print('audio model:', get_n_params(audio_model))
print('video model:', get_n_params(video_model))
print('tri model:', get_n_params(trimodel))

# params
text model: 3090900
audio model: 73892229
video model: 86459436
tri model: 163442565


In [54]:
criterion = nn.CosineEmbeddingLoss()

criterion

CosineEmbeddingLoss()

In [95]:
def cosine_sim_dissim_objective(embeds, margin=0.0):
    """Cosine-embedding loss that aligns matching modalities and separates mismatched ones.

    Positive pairs are the three modality pairs (text-audio, text-video,
    audio-video) of the same sample. Negative pairs are built inside the batch
    by pairing the first ``B // 2`` samples of one modality with the following
    ``B // 2`` samples of another, in both directions.

    Args:
        embeds: dict with ``'text_emb'``, ``'audio_emb'`` and ``'video_emb'``,
            each a ``[B, D]`` tensor.
        margin: margin of ``nn.CosineEmbeddingLoss`` applied to negative pairs.

    Returns:
        Scalar tensor: the sum of the three positive terms and, when the batch
        has at least two samples, the six negative terms (each mean-reduced).
    """
    te = embeds['text_emb']
    ae = embeds['audio_emb']
    ve = embeds['video_emb']
    bs = te.shape[0]
    bs_h = bs // 2
    criterion = nn.CosineEmbeddingLoss(margin=margin, reduction='mean')
    
    # Targets must live on the same device as the embeddings.
    pos_targets = torch.ones(bs, device=te.device)
    cos_loss = 0
    cos_loss += criterion(te, ae, pos_targets) + criterion(te, ve, pos_targets) + criterion(ae, ve, pos_targets)
    
    # A single-sample batch (e.g. the leftover batch of a split) has nothing to
    # contrast against, so only the positive terms apply.
    if bs_h == 0:
        return cos_loss
    
    neg_targets = -torch.ones(bs_h, device=te.device)
    # Two equally sized halves; with an odd batch the last sample is left out of
    # the negatives so both sides of every pair keep the same length.
    first, second = slice(0, bs_h), slice(bs_h, 2 * bs_h)
    for left, right in ((te, ae), (te, ve), (ae, ve)):
        cos_loss += criterion(left[first], right[second], neg_targets) + \
            criterion(left[second], right[first], neg_targets)
    
    return cos_loss

In [96]:
cosine_sim_dissim_objective(embeds)

tensor(3.2567)