In [1]:
import json
import torch
import pandas as pd
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torchaudio
from Cave_model import CAVMAEFTAudio
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
file = "./data/toy_dataset/openaqa_toy.json"

with open(file, "r") as jsonFile:
    data = json.load(jsonFile)

In [4]:
print(data[0]['audio_id'])

./data/toy_dataset/audio/4tnW9atZKo0.flac


In [5]:
for i in range(len(data)):
    path = data[i]['audio_id']
    exten = path[len(path)-4:]
    if exten == "flac" or exten == ".wav":
        mini_path = ""
        for j in range(len(path)-1, -1, -1):
            if path[j] == "/":
                break
            mini_path += path[j]
        data[i]['audio_id'] = "./data/toy_dataset/audio/" + mini_path[::-1]
        
with open(file, "w") as jsonFile:
    json.dump(data, jsonFile)

In [5]:
data = load_dataset("json", data_files=file, split='train')

In [6]:
dataset = data.train_test_split(test_size=0.2)
train_dataset = dataset['train']
test_dataset = dataset['test']

In [7]:
train_dataset.shape, test_dataset.shape

((5044, 6), (1262, 6))

In [8]:
train_dataset[0]

{'instruction': 'Closed-ended question: Included sounds in clip are? Examine sound traits prior to making a decision.',
 'input': '',
 'audio_id': './data/toy_dataset/audio/aah1FLl5EjU.flac',
 'dataset': 'as_strong_train',
 'task': 'cla_label_des',
 'output': 'Labels with acoustic features: Dynamic and full-bodied -> Music; Vibrant and chaotic -> Crowd; Loud and high-pitched -> Cheering; Bright, warm, and smooth -> Female singing; Sharp and explosive -> Firecracker'}

In [9]:
# Filterbank
def load_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    audio_info = 'Original input audio length {:.2f} seconds, number of channels: {:d}, sampling rate: {:d}.'.format(waveform.shape[1]/sample_rate, waveform.shape[0], sample_rate)
    if waveform.shape[0] != 1:
        waveform = waveform[0].unsqueeze(0)
        audio_info += ' Only the first channel is used.'
    if sample_rate == 16000:
        pass
    else:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)
        sample_rate = 16000
        audio_info += ' Resample to 16000Hz.'
    waveform = waveform - waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sample_rate,
                                              use_energy=False, window_type='hanning',
                                              num_mel_bins=128, dither=0.0, frame_shift=10)
    target_length = 1024
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]
    # normalize the fbank
    fbank = (fbank + 5.081) / 4.4849
    return fbank, audio_info
    

In [10]:
audio_encoder = CAVMAEFTAudio()

def return_audio(path):

    cur_audio_input, audio_info = load_audio(path)
    cur_audio_input = cur_audio_input.unsqueeze(0)
    
    # projecting to 1024 input embedding dimension for T5
    audio_proj = nn.Sequential(nn.LayerNorm(768, elementwise_affine=False), nn.Linear(768, 1024))
    audio_input = audio_encoder(cur_audio_input)  # [B, 512, 768]
    audio_input = audio_input.reshape(audio_input.shape[0], 8, 64, audio_input.shape[-1])
    audio_input = torch.mean(audio_input, dim=1)  # mean pool over the frequency dimension # [B, 64, 768]
    audio_input = torch.nn.functional.avg_pool2d(audio_input, (2, 1)) #[B, 32, 768]
    # hard norm to 50
    audio_input = audio_input / 50
    audio_input = audio_proj(audio_input) #[B, 32, 1024]
    
    return audio_input

In [11]:

device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
import itertools
from torch.nn.utils.rnn import pad_sequence

class CombinedEmbeddingsDataset(Dataset):
    def __init__(self, data, tokenizer, model, device, max_length=512):
        self.data = data
        self.device = device
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        instruction = self.data[idx]['instruction']
        audio = return_audio(self.data[idx]['audio_id'])
        label = self.data[idx]['output']
        
        input_ids = self.tokenizer(instruction, return_tensors="pt").input_ids.to(self.device)
        decoder_input_ids = self.tokenizer(label, return_tensors="pt").input_ids.to(self.device)

        with torch.no_grad():
            prompt_embeddings = self.model.shared(input_ids).squeeze(0)  # Shape: (sequence_length, 1024)

        audio_embeddings = audio.to(self.device).squeeze(0)  # Ensure audio is already a tensor with shape (32, 1024)

        combined_embeddings = torch.cat((prompt_embeddings, audio_embeddings), dim=0)  # Shape: (sequence_length + 32, 1024)

        if combined_embeddings.size(0) > self.max_length:
            combined_embeddings = combined_embeddings[:self.max_length, :]
        

        attention_mask = torch.ones(combined_embeddings.size(0)).to(self.device)


        padding_length = self.max_length - combined_embeddings.size(0)
        if padding_length > 0:
            padding_tensor = torch.zeros((padding_length, combined_embeddings.size(1))).to(self.device)
            combined_embeddings = torch.cat((combined_embeddings, padding_tensor), dim=0)
            
            padding_attention_mask = torch.zeros(padding_length).to(self.device)
            attention_mask = torch.cat((attention_mask, padding_attention_mask))

        return combined_embeddings, decoder_input_ids.squeeze(0), attention_mask

def collate_fn(batch):
    combined_embeddings = [item[0] for item in batch]
    decoder_input_ids = [item[1] for item in batch]
    attention_masks = [item[2] for item in batch]

    combined_embeddings_padded = torch.stack(combined_embeddings)
    decoder_input_ids_padded = pad_sequence(decoder_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks_padded = torch.stack(attention_masks)

    return combined_embeddings_padded, decoder_input_ids_padded, attention_masks_padded


dataset = CombinedEmbeddingsDataset(data, tokenizer, model, device)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

In [13]:
# Example labels (target sequence for training)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Training loop
num_epochs = 15
for epoch in range(num_epochs):

    # Iterate through the DataLoader
    for i, (combined_embeddings, decoder_input_ids_padded, attention_masks_padded) in enumerate(itertools.islice(dataloader, 2)):
        
        optimizer.zero_grad()
        
        outputs = model(inputs_embeds=combined_embeddings, attention_mask=attention_masks_padded,labels=decoder_input_ids_padded)
        loss = outputs.loss
        print(loss.item())
        #predicted_ids = torch.argmax(logits, dim=-1)
        loss.backward()
        optimizer.step()
        

16.471179962158203
4.1364216804504395
14.224686622619629
3.105179786682129
12.362961769104004
2.648189067840576
10.938106536865234
2.23211669921875
9.685318946838379
1.8894554376602173
8.905342102050781
1.6156588792800903
7.267980098724365
1.6330735683441162
6.187309265136719
1.2089747190475464
3.951199531555176
0.9764653444290161
3.5838725566864014
0.8278824687004089
2.4870128631591797
0.7491388916969299
2.3570995330810547
0.5651708245277405
2.278388023376465
0.4262784421443939
2.1048436164855957
0.6759395599365234
2.0151476860046387
0.4251023232936859


In [14]:
torch.save({
            'epoch': num_epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, "first_model.pt")

In [13]:


checkpoint = torch.load("first_model.pt")
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=1024, bias=False)
              (k): Linear(in_features=1024, out_features=1024, bias=False)
              (v): Linear(in_features=1024, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1024, out_features=2816, bias=False)
              (wi_1): Linear(in_features=1024, out_features=2816, bias=False)
       

### Inference

In [15]:

prompt_text = "what can be infered from this audio following"
input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids.to(device)
audio_input = return_audio("./data/toy_dataset/audio/_4X8RNeWeDI.flac")

with torch.no_grad():
    prompt_embeddings = model.shared(input_ids)  # Shape: (1, sequence_length, 1024)
    prompt_embeddings = prompt_embeddings.to(device)


target_text = "It can be inferred that the audio is a recording of a musical performance or a rehearsal.The combination of music, speech,"\
"and sound effects suggests that the audio is a representation of a musical performance or a rehearsal, where the music is being played"\
"and the performers are practicing their performance or rehearsing."

target_ids = tokenizer(target_text, return_tensors='pt').input_ids.to(device)

decoder_input_ids = model._shift_right(target_ids)

audio_embeddings = audio_input.to(device)  # Shape: (1, 32, 1024)

# Concatenate prompt and audio embeddings
combined_embeddings = torch.cat((prompt_embeddings, audio_embeddings), dim=1)  # Shape: (1, sequence_length + 32, 1024)

max_length = 512


if combined_embeddings.size(1) > max_length:
    combined_embeddings = combined_embeddings[:, :max_length, :]

padding_length = max_length - combined_embeddings.size(1)
if padding_length > 0:
    padding_tensor = torch.zeros((combined_embeddings.size(0), padding_length, combined_embeddings.size(2))).to(device)
    combined_embeddings = torch.cat((combined_embeddings, padding_tensor), dim=1)


attention_mask = torch.ones(combined_embeddings.size(0), combined_embeddings.size(1)).to(device)
if padding_length > 0:
    attention_mask[:, -padding_length:] = 0


outputs = model(inputs_embeds=combined_embeddings, attention_mask=attention_mask,labels=decoder_input_ids)
logits = outputs.logits  

predicted_ids = torch.argmax(logits, dim=-1)

decoded_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
decoded_text

'The Theune be inferred from the following file a recording of thea conversation performance. a recording of audio of the and sound and and instrumental effects are that the performers is a recording of thea musical performance. a rehearsal. and the performers is played played. the sound are  their musical skills ahearsing.'