In [1]:
import data
import torch
from models import imagebind_model
from models.imagebind_model import ModalityType

# Toy sanity check: three matched (text, image, audio) triples from the demo assets.
text_list = ["A dog.", "A car", "A bird"]
image_paths = [".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
audio_paths = [".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"]

device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Pretrained ImageBind-Huge in eval mode, moved to `device`.
model = imagebind_model.imagebind_huge(pretrained=True).eval()
model.to(device)

# Preprocess each modality onto `device`, keyed by modality.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

with torch.no_grad():
    embeddings = model(inputs)

# Cross-modal similarity: row i is a softmax over the candidates of the second modality,
# so a near-identity matrix means item i is matched to its own counterpart.
similarity_pairs = (
    ("Vision x Text: ", ModalityType.VISION, ModalityType.TEXT),
    ("Audio x Text: ", ModalityType.AUDIO, ModalityType.TEXT),
    ("Vision x Audio: ", ModalityType.VISION, ModalityType.AUDIO),
)
for label, query_modality, key_modality in similarity_pairs:
    logits = embeddings[query_modality] @ embeddings[key_modality].T
    print(label, torch.softmax(logits, dim=-1))

  from .autonotebook import tqdm as notebook_tqdm


Downloading imagebind weights to .checkpoints/imagebind_huge.pth ...


100%|██████████| 4.47G/4.47G [00:28<00:00, 171MB/s] 


Vision x Text:  tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
        [3.3836e-05, 9.9994e-01, 2.4119e-05],
        [4.7996e-05, 1.3496e-02, 9.8646e-01]], device='cuda:0')
Audio x Text:  tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]], device='cuda:0')
Vision x Audio:  tensor([[0.8070, 0.1088, 0.0842],
        [0.1036, 0.7884, 0.1079],
        [0.0018, 0.0022, 0.9960]], device='cuda:0')


In [9]:
# L2 norm of each raw audio embedding from the demo cell. The values are far from 1,
# so the embeddings are not unit-normalised by the model.
# Index with ModalityType.AUDIO (not the raw 'audio' string) to match the rest of the notebook.
embeddings[ModalityType.AUDIO].norm(dim=-1)

tensor([19.8767, 19.0552, 19.0818], device='cuda:0')

# Audio-Text: Clotho

In [55]:
import pandas as pd
import random
from tqdm import tqdm, trange
import data
import torch
from models import imagebind_model
from models.imagebind_model import ModalityType
import torch.nn.functional as F

device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

CLOTHO_PATH = "/pasteur/u/yuhuiz/data/CLOTHO/"


def _load_clotho_split(csv_name, audio_subdir):
    """Read one Clotho caption CSV and build the full path of every clip it lists.

    Returns ``(dataframe, paths)`` with ``paths`` in CSV row order.
    """
    split_df = pd.read_csv(CLOTHO_PATH + csv_name)
    split_paths = [CLOTHO_PATH + audio_subdir + name for name in split_df['file_name'].tolist()]
    return split_df, split_paths


# Development split is used as train, evaluation split as test.
clotho_df_train, file_names_train = _load_clotho_split("clotho_captions_development.csv", "development/")
clotho_df_test, file_names_test = _load_clotho_split("clotho_captions_evaluation.csv", "evaluation/")

file_names = file_names_train + file_names_test

print(len(file_names), len(file_names_train), len(file_names_test))

3938 2893 1045


In [None]:
# Pretrained ImageBind-Huge in eval mode; the trailing expression moves it to `device`.
model = imagebind_model.imagebind_huge(pretrained=True).eval()
model.to(device)

In [28]:
# Embed every Clotho clip in `file_names` order, `batch_size` clips at a time.
# Each batch result is moved to CPU before the final concatenation.
batch_size = 32
audio_batches = []
with torch.no_grad():
    for start in trange(0, len(file_names), batch_size):
        inputs = {
            ModalityType.AUDIO: data.load_and_transform_audio_data(file_names[start:start + batch_size], device),
        }
        embeddings = model(inputs)
        audio_batches.append(embeddings[ModalityType.AUDIO].cpu())

audio_embeddings = torch.cat(audio_batches, dim=0)
torch.save([file_names, audio_embeddings], "audio_embeddings_clotho.pt")

100%|██████████| 124/124 [04:41<00:00,  2.27s/it]


In [42]:
# All Clotho captions in a fixed order: the whole train caption_1 column, then caption_2,
# ..., caption_5, followed by the same five columns of the test split.
captions = [
    caption
    for split_df in (clotho_df_train, clotho_df_test)
    for column in ('caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5')
    for caption in split_df[column].tolist()
]

# Embed the captions in `captions` order, `batch_size` at a time, collecting results on CPU.
batch_size = 32
text_batches = []
with torch.no_grad():
    for start in trange(0, len(captions), batch_size):
        inputs = {
            ModalityType.TEXT: data.load_and_transform_text(captions[start:start + batch_size], device),
        }
        embeddings = model(inputs)
        text_batches.append(embeddings[ModalityType.TEXT].cpu())

text_embeddings = torch.cat(text_batches, dim=0)
torch.save([captions, text_embeddings], "text_embeddings_clotho.pt")

100%|██████████| 616/616 [02:15<00:00,  4.55it/s]


In [46]:
print(len(file_names), len(audio_embeddings), len(captions), len(text_embeddings))
# Lookup tables: clip path -> raw audio embedding row, caption text -> raw text embedding row.
# A caption that appears more than once collapses to a single key (the last occurrence wins).
audio_to_embeddings = dict(zip(file_names, audio_embeddings))
text_to_embeddings = dict(zip(captions, text_embeddings))

3938 3938 19690 19690


In [64]:
display(clotho_df_train.head())


def _clotho_records(split_df, audio_subdir, split):
    """Flatten one Clotho split into (clip, caption) records with L2-normalised embeddings.

    Each CSV row yields 5 records, one per ``caption_1``..``caption_5`` column, in that order.
    ``x_embed`` / ``y_embed`` are NumPy arrays holding the L2-normalised raw embeddings looked up
    in ``audio_to_embeddings`` / ``text_to_embeddings``; ``split`` is stored verbatim.
    """
    records = []
    for _, row in tqdm(split_df.iterrows(), total=len(split_df)):
        file_name = CLOTHO_PATH + audio_subdir + row['file_name']
        # The audio embedding is identical for all 5 captions: normalise it once, then give each
        # record its own copy so no two records alias the same array.
        x_embed = F.normalize(audio_to_embeddings[file_name], dim=0).numpy()
        for i in range(5):  # 5 captions
            caption = row['caption_' + str(i + 1)]
            records.append({
                "x": file_name,
                "y": caption,
                "x_embed": x_embed.copy(),
                "y_embed": F.normalize(text_to_embeddings[caption], dim=0).numpy(),
                "split": split,
            })
    return records


# transform to x_embed (audio embedding), y_embed (caption embedding), y (caption)
train_data = _clotho_records(clotho_df_train, "development/", "train")
test_data = _clotho_records(clotho_df_test, "evaluation/", "test")

# NOTE(review): this rebinds `data`, shadowing the `data` module imported at the start of this
# section. Re-running the embedding cells after this one fails until the imports are re-run.
# The name is kept because the following preview cell reads `data`.
data = train_data + test_data

import pickle
with open('data_audio_clotho_imagebind.pkl', 'wb') as f:
    pickle.dump(data, f)

Unnamed: 0,file_name,caption_1,caption_2,caption_3,caption_4,caption_5
0,Distorted AM Radio noise.wav,A muddled noise of broken channel of the TV,A television blares the rhythm of a static TV.,Loud television static dips in and out of focus,The loud buzz of static constantly changes pit...,heavy static and the beginnings of a signal on...
1,Paper_Parchment_Rustling.wav,A person is turning a map over and over.,A person is very carefully rapping a gift for ...,A person is very carefully wrapping a gift for...,"He sighed as he turned the pages of the book, ...","papers are being turned, stopped, then turned ..."
2,03 Whales Slowing Down.wav,Several barnyard animals mooing in a barn whil...,"The vocalization of several whales, along with...","Underwater, large numbers of shrimp clicking a...",Whales sing to one another over the flowing wa...,wales sing to one another with water flowing i...
3,Rope tied to boat in port.wav,An office chair is squeaking as someone bends ...,Popping and squeaking gradually tapers off to ...,Someone is opening a creaky door slowly while ...,Squeaking and popping followed by gradual popp...,an office chair is squeaking as someone leans ...
4,carpenter bee.wav,A flying bee is buzzing loudly around an objec...,An annoying fly is buzzing loudly and consiste...,An insect buzzing in the foreground as birds c...,"An insect trapped in a spider web struggles, b...","Outdoors, insect trapped in a spider web and t..."


100%|██████████| 2893/2893 [00:00<00:00, 3565.06it/s]
100%|██████████| 1045/1045 [00:00<00:00, 3717.84it/s]


In [66]:
# Preview the first 10 records: the first two train clips, 5 captions each.
data[0:10]

[{'x': '/pasteur/u/yuhuiz/data/CLOTHO/development/Distorted AM Radio noise.wav',
  'y': 'A muddled noise of broken channel of the TV',
  'x_embed': array([-0.0238224 , -0.05003118,  0.02089429, ..., -0.01378978,
         -0.02703197,  0.02927944], dtype=float32),
  'y_embed': array([-0.00901202,  0.00539664,  0.00060415, ...,  0.03331101,
         -0.00200131,  0.01315751], dtype=float32),
  'split': 'train'},
 {'x': '/pasteur/u/yuhuiz/data/CLOTHO/development/Distorted AM Radio noise.wav',
  'y': 'A television blares the rhythm of a static TV.',
  'x_embed': array([-0.0238224 , -0.05003118,  0.02089429, ..., -0.01378978,
         -0.02703197,  0.02927944], dtype=float32),
  'y_embed': array([-0.02141449,  0.02404512,  0.0194047 , ...,  0.02225818,
         -0.01603848,  0.00652554], dtype=float32),
  'split': 'train'},
 {'x': '/pasteur/u/yuhuiz/data/CLOTHO/development/Distorted AM Radio noise.wav',
  'y': 'Loud television static dips in and out of focus',
  'x_embed': array([-0.0238224

# Video-Text: MSR-VTT

In [1]:
import pandas as pd
import random
from tqdm import tqdm, trange
import data
import torch
from models import imagebind_model
from models.imagebind_model import ModalityType
import torch.nn.functional as F
import json

device = "cuda:0" if torch.cuda.is_available() else "cpu"

MSRVTT_PATH = "/pasteur/u/yuhuiz/data/MSRVTT/"

# Load the MSR-VTT annotation files.
# `with` closes each handle promptly; JSON is UTF-8 by spec, so don't rely on the locale encoding.
with open(MSRVTT_PATH + "train_val_videodatainfo.json", encoding="utf-8") as f:
    data_train = json.load(f)
with open(MSRVTT_PATH + "test_videodatainfo.json", encoding="utf-8") as f:
    data_test = json.load(f)

# One .mp4 path per entry of each file's "videos" list, in file order.
file_names_train = [MSRVTT_PATH + "TrainValVideo/" + video['video_id'] + ".mp4" for video in data_train['videos']]
file_names_test = [MSRVTT_PATH + "TestVideo/" + video['video_id'] + ".mp4" for video in data_test['videos']]

file_names = file_names_train + file_names_test

print(len(file_names), len(file_names_train), len(file_names_test))



10000 7010 2990


In [2]:
# Pretrained ImageBind-Huge in eval mode, moved to `device`.
model = imagebind_model.imagebind_huge(pretrained=True).eval()
model.to(device)
print("Model loaded")

Model loaded


In [3]:
# Embed every MSR-VTT video in `file_names` order, `batch_size` videos at a time.
# Each batch result is moved to CPU before the final concatenation.
batch_size = 8
video_batches = []
with torch.no_grad():
    for start in trange(0, len(file_names), batch_size):
        inputs = {
            ModalityType.VISION: data.load_and_transform_video_data(file_names[start:start + batch_size], device),
        }
        embeddings = model(inputs)
        video_batches.append(embeddings[ModalityType.VISION].cpu())

video_embeddings = torch.cat(video_batches, dim=0)
torch.save([file_names, video_embeddings], "video_embeddings_msrvtt.pt")

100%|██████████| 1250/1250 [1:46:11<00:00,  5.10s/it]


In [8]:
# Caption strings in annotation order: every train/val sentence, then every test sentence.
captions_train = [sentence["caption"] for sentence in data_train["sentences"]]
captions_test = [sentence["caption"] for sentence in data_test["sentences"]]
captions = [*captions_train, *captions_test]

In [12]:
# Embed all MSR-VTT captions in `captions` order, `batch_size` at a time, collecting results on CPU.
batch_size = 256
text_batches = []
with torch.no_grad():
    for start in trange(0, len(captions), batch_size):
        inputs = {
            ModalityType.TEXT: data.load_and_transform_text(captions[start:start + batch_size], device),
        }
        embeddings = model(inputs)
        text_batches.append(embeddings[ModalityType.TEXT].cpu())

text_embeddings = torch.cat(text_batches, dim=0)
torch.save([captions, text_embeddings], "text_embeddings_msrvtt.pt")

100%|██████████| 782/782 [15:33<00:00,  1.19s/it]


In [13]:
print(len(file_names), len(video_embeddings), len(captions), len(text_embeddings))
# Lookup tables: video path -> raw video embedding row, caption text -> raw text embedding row.
# A caption that appears more than once collapses to a single key (the last occurrence wins).
video_to_embeddings = dict(zip(file_names, video_embeddings))
text_to_embeddings = dict(zip(captions, text_embeddings))

10000 10000 200000 200000


In [17]:
def _msrvtt_records(split_data, video_subdir, split):
    """Turn one MSR-VTT split's ``"sentences"`` list into (video, caption) records.

    Each sentence yields one record in annotation order. ``x_embed`` / ``y_embed`` are NumPy arrays
    holding the L2-normalised raw embeddings looked up in ``video_to_embeddings`` /
    ``text_to_embeddings``; ``split`` is stored verbatim.
    """
    records = []
    for item in split_data["sentences"]:
        caption = item["caption"]
        video_path = MSRVTT_PATH + video_subdir + item["video_id"] + ".mp4"
        records.append({
            "x": video_path,
            "y": caption,
            "x_embed": F.normalize(video_to_embeddings[video_path], dim=0).numpy(),
            "y_embed": F.normalize(text_to_embeddings[caption], dim=0).numpy(),
            "split": split,
        })
    return records


train_data_processed = _msrvtt_records(data_train, "TrainValVideo/", "train")
test_data_processed = _msrvtt_records(data_test, "TestVideo/", "test")

# NOTE(review): this rebinds `data`, shadowing the `data` module imported at the start of this
# section. Re-running the embedding cells after this one fails until the imports are re-run.
# The name is kept because the following preview cell reads `data`.
data = train_data_processed + test_data_processed

import pickle
with open('data_video_msrvtt_imagebind.pkl', 'wb') as f:
    pickle.dump(data, f)

In [18]:
# Preview the first 10 records (all from the train/val split, which comes first).
data[0:10]

[{'x': '/pasteur/u/yuhuiz/data/MSRVTT/TrainValVideo/video2960.mp4',
  'y': 'a cartoon animals runs through an ice cave in a video game',
  'x_embed': array([-0.01323111,  0.01547959, -0.01875113, ..., -0.02266031,
          0.00901074,  0.02306143], dtype=float32),
  'y_embed': array([-0.0148942 ,  0.0271524 , -0.04431538, ..., -0.00151742,
         -0.0303146 ,  0.03866813], dtype=float32),
  'split': 'train'},
 {'x': '/pasteur/u/yuhuiz/data/MSRVTT/TrainValVideo/video2960.mp4',
  'y': 'a cartoon character runs around inside of a video game',
  'x_embed': array([-0.01323111,  0.01547959, -0.01875113, ..., -0.02266031,
          0.00901074,  0.02306143], dtype=float32),
  'y_embed': array([ 0.00500106,  0.02801896, -0.02620777, ...,  0.01640674,
         -0.00740285,  0.01365293], dtype=float32),
  'split': 'train'},
 {'x': '/pasteur/u/yuhuiz/data/MSRVTT/TrainValVideo/video2960.mp4',
  'y': 'a character is running in the snow',
  'x_embed': array([-0.01323111,  0.01547959, -0.01875113, 