In [None]:
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from datasets import load_dataset

# Load the self-contained LibriSpeech ASR dataset (a DatasetDict keyed by split name);
# the "train.clean.100" split is the one used throughout this notebook.
dataset = load_dataset("patrickvonplaten/librispeech_asr_self_contained")

In [None]:
from torch.utils.data import Dataset, DataLoader
import uuid
import numpy as np
from scipy.io.wavfile import write as write_wav

class CLEODataset(Dataset):
    """Map-style dataset yielding ``(instruction, wav_path, label)`` tuples.

    Every ``__getitem__`` call writes the sample's audio to a fresh, uniquely
    named ``.wav`` file under ``audio_dir`` so that path-based audio loaders
    can consume it. These files are never deleted by this class; the caller
    is responsible for removing them once they are no longer needed.

    Args:
        dataset: Indexable collection of rows. Each row must provide
            ``row["text"]`` (str) and ``row["audio"]["array"]`` (1-D samples),
            and may provide ``row["audio"]["sampling_rate"]``.
        instruction: Prompt string returned unchanged with every item.
        audio_dir: Directory the wav files are written to; created on demand.
        default_sample_rate: Rate written to the wav header when a row does
            not carry its own ``sampling_rate``.
    """

    def __init__(self, dataset, instruction,
                 audio_dir="/home/CS546-CLEO/wav_samples",
                 default_sample_rate=16000):
        self.dataset = dataset
        self.instruction = instruction
        self.audio_dir = audio_dir
        self.default_sample_rate = default_sample_rate

    def __len__(self):
        """Number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(instruction, wav_path, lower-cased transcript)`` for row ``idx``.

        Side effect: writes one new float32 wav file under ``self.audio_dir``.
        """
        import os

        # Index the wrapped dataset once; with Hugging Face datasets each row
        # access may decode the audio again, so a double lookup is wasteful.
        row = self.dataset[idx]
        label = row["text"].lower()

        audio = row["audio"]
        # Use the clip's own rate so the wav header matches the samples; fall back
        # to the configured default (16 kHz, LibriSpeech's rate) when it is absent.
        sample_rate = audio.get("sampling_rate", self.default_sample_rate)

        os.makedirs(self.audio_dir, exist_ok=True)
        file_name = os.path.join(self.audio_dir, f"{uuid.uuid4()}.wav")
        write_wav(file_name, sample_rate, np.asarray(audio["array"], dtype=np.float32))

        return self.instruction, file_name, label

import torch

# Prompt template returned with every sample; "<wav>" presumably marks where the
# audio is spliced into the prompt — TODO confirm against the model code.
instruction = """Convert the following information to a graph of triplets:
<wav>

Triples:
"""

# Seed the shuffle so the batch drawn for inspection is reproducible across
# Restart & Run All.
SHUFFLE_SEED = 42

cleoDataset = CLEODataset(dataset["train.clean.100"], instruction)
train_dataloader = DataLoader(
    cleoDataset,
    batch_size=8,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)


In [None]:
# Draw a single shuffled batch to inspect; idx is the enumerate index of the first batch, i.e. always 0.
idx, preBatch = 0, next(iter(train_dataloader))

In [None]:
# Re-shape the collated batch (instructions, wav paths, labels) into the dict layout,
# wrapping each wav path in its own single-element list.
batch = dict(
    instructions=list(preBatch[0]),
    audio_paths=[[wav_path] for wav_path in preBatch[1]],
    labels=list(preBatch[2]),
)

In [None]:
import os

# Delete the wav files CLEODataset wrote for this batch. The loop variable `each`
# stays bound to the last path afterwards (the next cell displays it).
for each in list(preBatch[1]):
    os.unlink(each)

In [None]:
# Inspection leftover: displays the last wav path removed by the cleanup loop
# (the `each` loop variable from the previous cell).
each

In [None]:
from IPython.display import Audio 
from scipy.io.wavfile import write as write_wav
import numpy as np

import os

# Dump the first five training utterances to wav files so ImageBind's path-based
# audio loader can read them; the relative directory is created if missing.
os.makedirs("wav_samples", exist_ok=True)
for i in range(0, 5):
    # Index the row once; each row access may re-decode the audio.
    sample_audio = dataset["train.clean.100"][i]["audio"]
    file_name = "wav_samples/test_" + str(i) + ".wav"
    audio_file = np.array(sample_audio["array"], dtype=np.float32)
    # Write the clip's own rate so the header matches the samples (16 kHz fallback).
    write_wav(file_name, sample_audio.get("sampling_rate", 16000), audio_file)

In [None]:
import torch
from imagebind import data
# Use the first GPU when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# nn.Module.eval() and .to() both return the module itself, so chaining keeps `model`
# bound to the same object while avoiding echoing the full (very long) module repr
# as cell output. eval() switches the model to inference mode.
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)

In [None]:
# Lower-cased transcripts of the first five training rows (index-aligned with the exported wav files).
text_list = [
    sample["text"].lower()
    for sample in (dataset["train.clean.100"][n] for n in range(5))
]
# Paths of the five exported test clips, in row order.
audio_list = [f"wav_samples/test_{n}.wav" for n in range(5)]

In [None]:
# Show the transcripts that are embedded as the text modality below.
text_list

In [None]:
# Load data
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_list, device),
}

with torch.no_grad():
    embeddings = model(inputs)

print(
    "Audio x Text: ",
    embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T,
)

In [None]:
import seaborn as sns

# Row-wise softmax turns audio clip i's scores into a distribution over the transcripts.
# Clip i and transcript i come from the same dataset row, so a strong diagonal means
# each clip is matched to its own transcript.
ax = sns.heatmap(
    torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1).cpu().numpy(),
    annot=True,
    fmt=".2f",
)
ax.set(
    xlabel="transcript index",
    ylabel="audio clip index",
    title="ImageBind audio-to-text match probabilities",
);