In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import typing
from typing import List
import logging
import torch.nn as nn
from cleo.cleoCLAP import CLEOClap
from datasets import load_from_disk, load_dataset
import tqdm
import numpy as np
import os
from transformers import ClapProcessor, ClapModel
from torch.utils.data import Dataset, DataLoader
import os
from cleo.QFormer import BertConfig, BertLMHeadModel
from torch.nn import functional as F

BATCH_SIZE = 8
clapModelVr = "laion/clap-htsat-unfused"
dataset = load_dataset("patrickvonplaten/librispeech_asr_self_contained", split="train.clean.100")
audio_gpu = "cpu"
clapModelProcessor = ClapProcessor.from_pretrained(clapModelVr)
clapModel = ClapModel.from_pretrained(clapModelVr)
clapModel = clapModel.to(audio_gpu)

class CLEODataset(Dataset):
    def __init__(self, dataset, instruction, processor, sampling_rate = 48000):
        self.dataset = dataset
        self.instruction = instruction
        self.processor = processor
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        ## Create the label
        label = self.dataset[idx]["text"].lower()
        
        ## Save the audio
        audio_array = self.dataset[idx]["audio"]["array"]
        return self.instruction, audio_array, label

def custom_collate_fn(original_batch):
    instructions = [each[0] for each in original_batch]
    audios = [each[1] for each in original_batch]
    labels = [each[2] for each in original_batch]
    return instructions, audios, labels

instruction = """Repeat back the information that you see below:
<wav>

Information:
"""
cleoDataset = CLEODataset(dataset, instruction, clapModelProcessor)
train_dataloader = DataLoader(cleoDataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn)
batch_idx, (instructions, audios, labels) = next(enumerate(train_dataloader))
print("Dataset Loaded...")

  warn(f"Failed to load image Python extension: {e}")


Dataset Loaded...


In [2]:
## Initialize QFormer
def init_Qformer(num_query_token, audio_width, freeze):
    encoder_config = BertConfig.from_pretrained("bert-base-uncased")
    encoder_config.encoder_width = audio_width
    # insert cross-attention layer every other block
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = 2
    encoder_config.query_length = num_query_token
    Qformer = BertLMHeadModel(config=encoder_config)
    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

#    Qformer.cls = None
#    Qformer.bert.embeddings.word_embeddings = None
#    Qformer.bert.embeddings.position_embeddings = None
#    for layer in Qformer.bert.encoder.layer:
#        layer.output = None
#        layer.intermediate = None

    if freeze:
        for name, param in Qformer.named_parameters():
            param.requires_grad = False
        Qformer = Qformer.eval()
        query_tokens.requires_grad = False
        logging.info("freeze Qformer")
    return Qformer, query_tokens

def __load_llm__(llm_model, freeze_llm, pad_token_id=None, device="cpu"):
    ## Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(llm_model)
    if pad_token_id is not None:
        tokenizer.pad_token_id = pad_token_id
    else:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(llm_model, device_map=device)
    logging.info("Loaded LLAMA model")
    if freeze_llm:
        for param in model.parameters():
            param.requires_grad = False
        logging.info("Model parameters frozen")
    return tokenizer, model

## Load the qformer model
num_query_tokens = 32
audio_width = 512
freeze = False
Qformer, query_tokens = init_Qformer(num_query_tokens, audio_width, freeze)
Qformer = Qformer.to(audio_gpu)
query_tokens = query_tokens.to(audio_gpu)

## Load the LLM model
tokenizer, llm = __load_llm__("/home/models/Llama-2-7b-hf", True, device="cpu")

## Create projection layer
proj = nn.Linear(Qformer.config.hidden_size, llm.config.hidden_size)
proj = proj.to(audio_gpu)
print("All models loaded")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

All models loaded


In [5]:
## Def get audio embeddings
def get_audio_embeddings(audios):
    inputs = clapModelProcessor(audios=audios, sampling_rate=48000, return_tensors="pt")
    if audio_gpu != "cpu":
        inputs = inputs.to(audio_gpu)
    with torch.no_grad():
        embeddings = clapModel.get_audio_features(**inputs, return_dict=True)
    return embeddings


def encode_audio(audios):
    ## Get the embeddings first
    wav_embs = get_audio_embeddings(audios)
    wav_embs = wav_embs.unsqueeze(1)
    if audio_gpu != "cpu":
        wav_embs = wav_embs.to(audio_gpu)

    ## Create the attention mask for the wav
    wav_attn = torch.ones(wav_embs.size()[:-1], dtype=torch.long).to(audio_gpu)

    ## Expand the query tokens
    wav_query_tokens = query_tokens.expand(wav_embs.shape[0], -1, -1)

    ## Create Qformer output
    query_output = Qformer.bert(
        query_embeds = wav_query_tokens,
        encoder_hidden_states = wav_embs,
        encoder_attention_mask = wav_attn,
        return_dict = True
    )

    return query_output

def project_query(query_output):
    wav_input = proj(query_output["last_hidden_state"])
    wav_attn = torch.ones(wav_input.size()[:-1], dtype=torch.long).to(audio_gpu)
    return wav_input, wav_attn

def encode_text(text, device="cpu"):
    output_dict = llm(tokenizer.encode(text, return_tensors="pt").to(device), return_dict=True, output_hidden_states=True)
    return output_dict

def ATC(audios, labels):
    ## Get the wav_input and wav_attn
    wav_rep, _ = project_query(encode_audio(audios))
    wav_rep = wav_rep[:,-1,:]

    ## Get the text_input
    text_rep = []
    for label in labels:
        text_rep.append(encode_text(label, device="cuda:1").hidden_states[-1][:,-1,:])
    text_rep = torch.cat(text_rep, dim=0)    

    temp = .5
    similarity = torch.matmul(wav_rep, text_rep.T) * temp
    labels = torch.arange(similarity.shape[0], device=similarity.device, dtype=torch.long)
    loss = (
        F.cross_entropy(similarity, labels, reduction="mean")
        + F.cross_entropy(similarity.T, labels, reduction="mean")
    ) / 2
    return loss


In [8]:
## Get the wav_input and wav_attn
wav_embs = get_audio_embeddings(audios)
wav_embs = wav_embs.unsqueeze(1)

## Create the attention mask for the wav
wav_attn = torch.ones(wav_embs.size()[:-1], dtype=torch.long).to(audio_gpu)

In [126]:
import torch.nn.functional as F
from info_nce import InfoNCE, info_nce
## Expand the query tokens
wav_query_tokens = query_tokens.expand(wav_embs.shape[0], -1, -1)

## Create Qformer output
wav_output = Qformer.bert(
    query_embeds = wav_query_tokens,
    encoder_hidden_states = wav_embs,
    encoder_attention_mask = wav_attn,
    return_dict = True
)

Qtokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
stuff = Qtokenizer(labels, return_tensors="pt", padding=True, truncation=True, max_length=256)
text_output = Qformer.bert(
    input_ids = stuff["input_ids"],
    attention_mask = stuff["attention_mask"],
    return_dict = True
)

best_match = torch.argmax(F.cosine_similarity(wav_output.last_hidden_state[:,:,:], text_output.last_hidden_state[:,0,:].unsqueeze(1), dim=2), dim=1)
wav_rep = torch.gather(wav_output.last_hidden_state, 1, best_match.view(-1, 1, 1).expand(-1, 1, 768)).squeeze(1)
cls_rep = text_output.last_hidden_state[:,0,:]
loss = InfoNCE()
output = loss(wav_rep, cls_rep)

In [117]:
batch_size, embedding_size = 32, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)

In [121]:
wav_rep = torch.gather(wav_output.last_hidden_state, 1, best_match.view(-1, 1, 1).expand(-1, 1, 768)).squeeze(1)

In [129]:
output.backward()

In [124]:
output = loss(wav_rep, cls_rep)

In [125]:
output

tensor(1.8851, grad_fn=<NllLossBackward0>)