In [None]:
from transformers import ClapProcessor, ClapModel
# Load the CLAP processor for the "laion/clap-htsat-unfused" checkpoint.
# It is handed to CLEODataset when the training dataset is built in a later cell.
processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")

In [None]:
from datasets import load_dataset, Audio

# Fetch the "train.clean.100" split and re-declare its "audio" column with a
# 48 kHz sampling rate, as one chained expression.
dataset = load_dataset(
    "patrickvonplaten/librispeech_asr_self_contained",
    split="train.clean.100",
).cast_column("audio", Audio(sampling_rate=48000))

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
class CLEODataset(Dataset):
    """Map-style dataset yielding ``(instruction, audio_array, label)`` triples.

    Every item pairs one fixed instruction prompt with the audio samples of a
    row and that row's transcript, lower-cased.

    Args:
        dataset: Indexable, sized collection whose rows expose ``row["text"]``
            (a string) and ``row["audio"]["array"]``.
        instruction: Prompt returned unchanged with every item.
        processor: Stored on the instance; not used by the methods of this class.
        sampling_rate: Stored on the instance; not used by the methods of this
            class. Defaults to 48000.
    """

    def __init__(self, dataset, instruction, processor, sampling_rate=48000):
        self.dataset = dataset
        self.instruction = instruction
        self.processor = processor
        self.sampling_rate = sampling_rate

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(instruction, audio_array, label)`` for row ``idx``."""
        # Look the row up exactly once. Indexing the wrapped dataset separately
        # for the text and for the audio materialises the row twice, which is
        # wasteful when row access is expensive (e.g. Hugging Face datasets
        # with an ``Audio`` column decode the audio on access).
        row = self.dataset[idx]

        # The label is the lower-cased transcript.
        label = row["text"].lower()

        audio_array = row["audio"]["array"]
        return self.instruction, audio_array, label

def custom_collate_fn(original_batch):
    """Regroup a batch of 3-item samples into three parallel lists.

    Args:
        original_batch: Sequence of samples, each indexable at positions
            0 (instruction), 1 (audio) and 2 (label).

    Returns:
        A tuple ``(instructions, audios, labels)`` of lists, each in the same
        order as the samples in ``original_batch``.
    """
    def column(position):
        # Collect the item at ``position`` from every sample, in batch order.
        return [sample[position] for sample in original_batch]

    return column(0), column(1), column(2)

In [None]:
# Prompt attached to every sample; the text is kept verbatim.
instruction = """Repeat back the information that you see below:
<wav>

Information:
"""

# Wrap the corpus and batch it: 8 shuffled samples per step, regrouped into
# parallel lists by custom_collate_fn.
cleoDataset = CLEODataset(
    dataset=dataset,
    instruction=instruction,
    processor=processor,
)
train_dataloader = DataLoader(
    dataset=cleoDataset,
    batch_size=8,
    shuffle=True,
    collate_fn=custom_collate_fn,
)


In [None]:
# Pull a single batch for inspection; being the first one drawn, its index is 0.
preBatch = next(iter(train_dataloader))
idx = 0

In [None]:
# Give the three collated lists names; this dict is what the model is called with.
batch = dict(
    instructions=preBatch[0],
    audio_array=preBatch[1],
    labels=preBatch[2],
)

In [None]:
from cleo.cleoCLAP import CLEOClap
# CLAP checkpoint name handed to the wrapper.
clapModelVr = "laion/clap-htsat-unfused"

# NOTE(review): llm_model_path is an absolute, machine-specific path — adjust
# it for the environment this notebook runs in.
cleo_model = CLEOClap(
    llm_model_path="/home/models/Llama-2-7b-hf",
    audio_features=512,  # would be 1024 with ImageBind
    host_llm_on_cuda=True,
    audio_gpu="cuda:1",
    clapModelVr=clapModelVr,
)

In [None]:
# Run the model on the prepared batch; the returned object exposes `.loss`.
output = cleo_model(batch)
loss = output.loss
# .item() extracts a plain Python number — presumably `loss` is a scalar tensor (TODO confirm).
loss_val = loss.item()