In [1]:
import json
import torch
import pandas as pd
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
import torchaudio
from Cave_model import CAVMAEFTAudio
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import Seq2SeqLMOutput

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch.nn.init as init
from transformers.modeling_outputs import BaseModelOutput

In [2]:
import transformers
print(transformers.__file__)

/home/sayyss/.conda/envs/LTU-Replication/lib/python3.12/site-packages/transformers/__init__.py


In [3]:

device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

In [5]:
class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        
        self.audio_encoder = CAVMAEFTAudio()
        self.audio_proj = nn.Sequential(nn.LayerNorm(768, elementwise_affine=False), nn.Linear(768, 1024)) 
        
        self.post_init()
        
    def process_audio(self,audio_input):
        
        audio_input = self.audio_encoder(audio_input.to(device))  # [B, 512, 768]
        audio_input = audio_input.reshape(audio_input.shape[0], 8, 64, audio_input.shape[-1])
        audio_input = torch.mean(audio_input, dim=1)  # mean pool over the frequency dimension # [B, 64, 768]
        audio_input = torch.nn.functional.avg_pool2d(audio_input, (2, 1)) #[B, 32, 768]
        # hard norm to 50
        audio_input = audio_input / 50
        audio_input = self.audio_proj(audio_input) #[B, 32, 1024]
        audio_length = audio_input.shape[1]
        audio_embeds = audio_input.to(device) # [B,32,1024]
        
        return audio_embeds

    def prepare_inputs_for_generation(
        self,
        input_ids,
        audio_input,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        audio_embeds = self.process_audio(audio_input)
        # Combine decoder input embeddings with audio embeddings
        decoder_inputs_embeds = self.shared(input_ids)
        decoder_inputs_embeds = torch.cat((decoder_inputs_embeds, audio_embeds), dim=1) # [B, seq_length+32, 1024]

        return {
            "decoder_inputs_embeds": decoder_inputs_embeds,
            "audio_input": audio_input,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }
        
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audio_input = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, T5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # ******** Custom modifications start *********
        # Get audio embeddings
        if audio_input == None:
            raise ValueError("audio input cannot be empty")

        audio_embeds = self.process_audio(audio_input) # [B, 32, 1024]
        
        """
        audio_input = self.audio_encoder(audio_input)  # [B, 512, 768]
        audio_input = audio_input.reshape(audio_input.shape[0], 8, 64, audio_input.shape[-1])
        audio_input = torch.mean(audio_input, dim=1)  # mean pool over the frequency dimension # [B, 64, 768]
        audio_input = torch.nn.functional.avg_pool2d(audio_input, (2, 1)) #[B, 32, 768]
        # hard norm to 50
        audio_input = audio_input / 50
        audio_input = self.audio_proj(audio_input) #[B, 32, 1024]
        audio_length = audio_input.shape[1]
        audio_embeds = audio_input.to(device) # [B,32,1024] 
        """
        # Custom: get embeddings
        # Only runs during training
        if inputs_embeds is None and decoder_input_ids is None and input_ids is not None:
            
            inputs_embeds = self.shared(input_ids).to(device) # [b, seq_length, 768]
            inputs_embeds = torch.cat((inputs_embeds, audio_embeds), dim=1)  # Shape: (b,sequence_length + 32, 1024)

            max_length = 512
            seq_length = inputs_embeds.size(1)
            
            # Truncate if the sequence length exceeds max_length
            if seq_length > max_length:
                inputs_embeds = inputs_embeds[:, :max_length, :]
                seq_length = max_length
            
            # Apply padding if the sequence is shorter than max_length
            padding_length = max_length - seq_length
            if padding_length > 0:
                padding_tensor = torch.zeros((inputs_embeds.size(0), padding_length, inputs_embeds.size(2))).to(device)
                inputs_embeds = torch.cat((inputs_embeds, padding_tensor), dim=1)
            
            # Create attention mask
            attention_mask = torch.ones((inputs_embeds.size(0), max_length)).to(device)
            if padding_length > 0:
                attention_mask[:, seq_length:] = 0
            
            
        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=None, # Custom: change to none because we already defined embeddings
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # modifications for inference
        # only runs when inference, skips to using labels as decoder input when training instead
        if decoder_inputs_embeds is not None:
            decoder_input_ids = None
            
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

In [6]:
# Load the tokenizer and config
#tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
config = T5Config.from_pretrained("MBZUAI/LaMini-T5-738M")
tokenizer = T5Tokenizer.from_pretrained("MBZUAI/LaMini-T5-738M")

# Initialize your custom model
customT5 = CustomT5ForConditionalGeneration(config)
#model = customT5.from_pretrained("google/flan-t5-large", device_map="auto")\
model = customT5.from_pretrained("MBZUAI/LaMini-T5-738M", device_map="auto")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of CustomT5ForConditionalGeneration were not initialized from the model checkpoint at MBZUAI/LaMini-T5-738M and are newly initialized: ['audio_encoder.blocks_a.0.attn.proj.bias', 'audio_encoder.blocks_a.0.attn.proj.weight', 'audio_encoder.blocks_a.0.attn.qkv.bias', 'audio_encoder.blocks_a.0.attn.qkv.weight', 'audio_encoder.blocks_a.0.mlp.fc1.bias', 'audio_encoder.blocks_a.0.mlp.fc1.weight', 'audio_encode

In [7]:
model.post_init()
model.audio_encoder.initialize_weights()

In [8]:
model.audio_proj[0].weight

In [9]:
import inspect
def get_methods_and_params(cls):
    methods_and_params = []
    for name, member in inspect.getmembers(cls):
        if inspect.ismethod(member) or inspect.isfunction(member):
            parameters = inspect.signature(member).parameters
            param_names = [param for param in parameters.keys() if param != 'self']
            methods_and_params.append((name, tuple(param_names)))
    return methods_and_params

In [10]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_params

824176640

In [11]:
file = "./data/toy_dataset/openaqa_toy.json"

with open(file, "r") as jsonFile:
    data = json.load(jsonFile)

In [11]:
for i in range(len(data)):
    path = data[i]['audio_id']
    exten = path[len(path)-4:]
    if exten == "flac" or exten == ".wav":
        mini_path = ""
        for j in range(len(path)-1, -1, -1):
            if path[j] == "/":
                break
            mini_path += path[j]
        data[i]['audio_id'] = "./data/toy_dataset/audio/" + mini_path[::-1]
        
with open(file, "w") as jsonFile:
    json.dump(data, jsonFile)

In [12]:
data = load_dataset("json", data_files=file, split='train')

In [13]:
dataset = data.train_test_split(test_size=0.2)
train_dataset = dataset['train']
test_dataset = dataset['test']

In [14]:
train_dataset.shape, test_dataset.shape

((5044, 6), (1262, 6))

### Only train on classifying audio first

In [15]:
count = 0
for i in range(train_dataset.shape[0]):
    if train_dataset[i]['task'].startswith("cla"):
        #print(train_dataset[i])
        count += 1
print(count)

1099


In [16]:
def is_classification_task(input):
    return input['task'].startswith("cla")

train_cla_dataset = train_dataset.filter(is_classification_task)
test_cla_dataset = test_dataset.filter(is_classification_task)

Filter:   0%|          | 0/5044 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1262 [00:00<?, ? examples/s]

In [17]:
train_cla_dataset.shape, test_cla_dataset.shape

((1099, 6), (296, 6))

In [18]:
train_cla_dataset[19]

{'instruction': 'Closed-ended question: Perform audio event classification on audio clip, produce tags solely for audio.',
 'input': '',
 'audio_id': './data/toy_dataset/audio/rKJYI_rn_sg_000001.flac',
 'dataset': 'vggsound_train',
 'task': 'cla_label',
 'output': 'Labels: Gibbon howling'}

In [19]:
# Filterbank
def load_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    audio_info = 'Original input audio length {:.2f} seconds, number of channels: {:d}, sampling rate: {:d}.'.format(waveform.shape[1]/sample_rate, waveform.shape[0], sample_rate)
    if waveform.shape[0] != 1:
        waveform = waveform[0].unsqueeze(0)
        audio_info += ' Only the first channel is used.'
    if sample_rate == 16000:
        pass
    else:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)
        sample_rate = 16000
        audio_info += ' Resample to 16000Hz.'
    waveform = waveform - waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sample_rate,
                                              use_energy=False, window_type='hanning',
                                              num_mel_bins=128, dither=0.0, frame_shift=10)
    target_length = 1024
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]
    # normalize the fbank
    fbank = (fbank + 5.081) / 4.4849
    return fbank, audio_info
    

In [20]:


def return_audio(path):

    cur_audio_input, audio_info = load_audio(path)
    cur_audio_input = cur_audio_input.unsqueeze(0)
    
    # projecting to 1024 input embedding dimension for T5
    audio_proj = nn.Sequential(nn.LayerNorm(768, elementwise_affine=False), nn.Linear(768, 1024))
    audio_input = audio_encoder(cur_audio_input)  # [B, 512, 768]
    audio_input = audio_input.reshape(audio_input.shape[0], 8, 64, audio_input.shape[-1])
    audio_input = torch.mean(audio_input, dim=1)  # mean pool over the frequency dimension # [B, 64, 768]
    audio_input = torch.nn.functional.avg_pool2d(audio_input, (2, 1)) #[B, 32, 768]
    # hard norm to 50
    audio_input = audio_input / 50
    audio_input = audio_proj(audio_input) #[B, 32, 1024]
    
    return audio_input

In [21]:
device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
device

device(type='cuda')

In [21]:
import itertools
from torch.nn.utils.rnn import pad_sequence

class CombinedEmbeddingsDataset(Dataset):
    def __init__(self, data, tokenizer, model, device, max_length=512):
        self.data = data
        self.device = device
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        instruction = self.data[idx]['instruction']
        audio_path = self.data[idx]['audio_id']
        label = self.data[idx]['output']
        
        input_ids = self.tokenizer(instruction, return_tensors="pt").input_ids.to(self.device)
        decoder_input_ids = self.tokenizer(label, return_tensors="pt").input_ids.to(self.device)
        audio_bank, audio_info = load_audio(audio_path)
        audio_bank = audio_bank.to(self.device)
        
        return input_ids, decoder_input_ids.squeeze(0), audio_bank

def collate_fn(batch):

    input_ids = [item[0].squeeze(0) for item in batch]  # Remove unnecessary dim
    decoder_input_ids = [item[1] for item in batch]
    audio_input = [item[2] for item in batch]

    # Pad input_ids and decoder_input_ids to the maximum length in the batch
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    decoder_input_ids_padded = pad_sequence(decoder_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    
    # Stack audio inputs directly if they have the same size
    audio_input_stacked = torch.stack(audio_input)
    return input_ids_padded, audio_input_stacked, decoder_input_ids_padded


#dataset = CombinedEmbeddingsDataset(train_dataset, tokenizer, model, device)
#dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

In [22]:
dataset = CombinedEmbeddingsDataset(train_cla_dataset, tokenizer, model, device)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

In [23]:


# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)


# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    total_loss = 0
    for i, (input_ids, audio_bank, decoder_input_ids) in enumerate(dataloader):
        if torch.isnan(input_ids).any() or torch.isnan(audio_bank).any() or torch.isnan(decoder_input_ids).any():
            print("NaN detected in input tensors")
            continue
        
        outputs = model(input_ids=input_ids, audio_input=audio_bank, labels=decoder_input_ids)
        loss = outputs.loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    avg_loss = total_loss / 10
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

KeyboardInterrupt: 

#### Freeze Audio encoder and LLM, only train projection layer

In [26]:
for param in model.parameters():
    param.requires_grad = False
for param in model.audio_encoder.parameters():
    param.requires_grad = False
for param in model.audio_proj.parameters():
    param.requires_grad = True

optimizer = torch.optim.AdamW(model.audio_proj.parameters(), lr=1e-6)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0

    for i, (input_ids, audio_bank, decoder_input_ids) in enumerate(dataloader):
        if torch.isnan(input_ids).any() or torch.isnan(audio_bank).any() or torch.isnan(decoder_input_ids).any():
            print("NaN detected in input tensors")
            continue
        
        outputs = model(input_ids=input_ids, audio_input=audio_bank, labels=decoder_input_ids)
        loss = outputs.loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    avg_loss = total_loss / 10
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
    

Epoch 1/5, Average Loss: 644.1280
Epoch 2/5, Average Loss: 644.1280
Epoch 3/5, Average Loss: 644.1280
Epoch 4/5, Average Loss: 644.1280
Epoch 5/5, Average Loss: 644.1280


In [18]:
torch.save({
            'epoch': num_epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, "working_model_2.pt")

In [33]:
file = "./data/toy_dataset/audio/74S7Gw80ZOo_000040.flac"
cur_audio_input, audio_info = load_audio(file)
cur_audio_input = cur_audio_input.unsqueeze(0).to(device)
prompt_text = "This clip contains these sounds? Generate audio labels instantly:"
input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids.to(device)

with torch.no_grad():
    outputs = model.generate(input_ids=input_ids, audio_input=cur_audio_input)
    
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

factura factura factura factura factura factura factura factura factura factura factura factura factura factura factura factura factura factura factura


In [34]:
outputs

tensor([[    0, 27188, 27188, 27188, 27188, 27188, 27188, 27188, 27188, 27188,
         27188, 27188, 27188, 27188, 27188, 27188, 27188, 27188, 27188, 27188]])

In [25]:
print("Model parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, ":", param.data.mean(), param.data.std())

Model parameters:
shared.weight : tensor(-0.1905, device='cuda:0') tensor(9.9409, device='cuda:0')
encoder.block.0.layer.0.SelfAttention.q.weight : tensor(7.0893e-06, device='cuda:0') tensor(0.0321, device='cuda:0')
encoder.block.0.layer.0.SelfAttention.k.weight : tensor(0.0006, device='cuda:0') tensor(0.2651, device='cuda:0')
encoder.block.0.layer.0.SelfAttention.v.weight : tensor(-0.0002, device='cuda:0') tensor(0.1377, device='cuda:0')
encoder.block.0.layer.0.SelfAttention.o.weight : tensor(0.0006, device='cuda:0') tensor(0.2651, device='cuda:0')
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight : tensor(2.3460, device='cuda:0') tensor(4.1137, device='cuda:0')
encoder.block.0.layer.0.layer_norm.weight : tensor(0.1013, device='cuda:0') tensor(0.0228, device='cuda:0')
encoder.block.0.layer.1.DenseReluDense.wi_0.weight : tensor(-0.0036, device='cuda:0') tensor(0.2012, device='cuda:0')
encoder.block.0.layer.1.DenseReluDense.wi_1.weight : tensor(-0.0005, device='cuda:0

In [24]:
print("Token 30207:", tokenizer.convert_ids_to_tokens([30207]))

Token 30207: ['zugreifen']


### Experiments

In [None]:
checkpoint = torch.load("working_model_2.pt")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [16]:
from Cave_model import CAVMAEFTAudio
audio_encoder = CAVMAEFTAudio()

In [17]:
file = "./data/toy_dataset/audio/577026.flac"
cur_audio_input, audio_info = load_audio(file)
cur_audio_input = cur_audio_input.unsqueeze(0)
audio_input = audio_encoder(cur_audio_input) 
audio_input

from CAVE inside: tensor([[[-2.2843, -2.4218, -2.0455,  ..., -0.9920, -1.5722, -2.3284],
         [-1.8898, -2.3895, -1.6217,  ..., -0.2275, -1.0407, -2.1476],
         [-1.6393, -2.1380, -1.3702,  ...,  0.0030, -0.7107, -1.5814],
         ...,
         [ 1.1329,  1.1329,  1.1329,  ...,  1.1329,  1.1329,  1.1329],
         [ 1.1329,  1.1329,  1.1329,  ...,  1.1329,  1.1329,  1.1329],
         [ 1.1329,  1.1329,  1.1329,  ...,  1.1329,  1.1329,  1.1329]]])
shape before CAVE: torch.Size([1, 1024, 128])
from CAVE before patch: tensor([[[[-2.2843, -1.8898, -1.6393,  ...,  1.1329,  1.1329,  1.1329],
          [-2.4218, -2.3895, -2.1380,  ...,  1.1329,  1.1329,  1.1329],
          [-2.0455, -1.6217, -1.3702,  ...,  1.1329,  1.1329,  1.1329],
          ...,
          [-0.9920, -0.2275,  0.0030,  ...,  1.1329,  1.1329,  1.1329],
          [-1.5722, -1.0407, -0.7107,  ...,  1.1329,  1.1329,  1.1329],
          [-2.3284, -2.1476, -1.5814,  ...,  1.1329,  1.1329,  1.1329]]]])
patch embed: torch.S

tensor([[[-1.7023, -0.7344, -0.0461,  ..., -0.0025, -0.1100, -0.4455],
         [-1.6768, -0.7517, -0.0163,  ...,  0.0435, -0.1022, -0.4384],
         [-1.6403, -0.7311,  0.0251,  ...,  0.0277, -0.0684, -0.4439],
         ...,
         [-2.0549, -0.5034, -0.2656,  ..., -0.2660,  0.1061, -0.4010],
         [-2.0651, -0.5383, -0.2489,  ..., -0.2203,  0.1020, -0.4096],
         [-2.0548, -0.5933, -0.2267,  ..., -0.1624,  0.0865, -0.4231]]],
       grad_fn=<NativeLayerNormBackward0>)

In [18]:
file = "./data/toy_dataset/audio/577026.flac"
cur_audio_input, audio_info = load_audio(file)
cur_audio_input = cur_audio_input.unsqueeze(0).to(device)
prompt_text = "what can be infered from this audio following"
input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids.to(device)
target_text = "It can be inferred that the audio is a recording of a musical performance or a rehearsal.The combination of music, speech,"\
"and sound effects suggests that the audio is a representation of a musical performance or a rehearsal, where the music is being played"\
"and the performers are practicing their performance or rehearsing."
target_ids = tokenizer(target_text, return_tensors='pt').input_ids.to(device)
decoder_input_ids = model._shift_right(target_ids).to(device)

outputs = model(input_ids=input_ids, audio_input=cur_audio_input,labels=decoder_input_ids)

from CAVE inside: tensor([[[-2.2843, -2.4218, -2.0455,  ..., -0.9920, -1.5722, -2.3284],
         [-1.8898, -2.3895, -1.6217,  ..., -0.2275, -1.0407, -2.1476],
         [-1.6393, -2.1380, -1.3702,  ...,  0.0030, -0.7107, -1.5814],
         ...,
         [ 1.1329,  1.1329,  1.1329,  ...,  1.1329,  1.1329,  1.1329],
         [ 1.1329,  1.1329,  1.1329,  ...,  1.1329,  1.1329,  1.1329],
         [ 1.1329,  1.1329,  1.1329,  ...,  1.1329,  1.1329,  1.1329]]],
       device='cuda:0')
shape before CAVE: torch.Size([1, 1024, 128])
from CAVE before patch: tensor([[[[-2.2843, -1.8898, -1.6393,  ...,  1.1329,  1.1329,  1.1329],
          [-2.4218, -2.3895, -2.1380,  ...,  1.1329,  1.1329,  1.1329],
          [-2.0455, -1.6217, -1.3702,  ...,  1.1329,  1.1329,  1.1329],
          ...,
          [-0.9920, -0.2275,  0.0030,  ...,  1.1329,  1.1329,  1.1329],
          [-1.5722, -1.0407, -0.7107,  ...,  1.1329,  1.1329,  1.1329],
          [-2.3284, -2.1476, -1.5814,  ...,  1.1329,  1.1329,  1.1329]]

In [36]:
file = "./data/toy_dataset/audio/3CHor3uzS00.flac"
cur_audio_input, audio_info = load_audio(file)
cur_audio_input = cur_audio_input.unsqueeze(0).to(device)
prompt_text = "what can be infered from this audio following"
input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids.to(device)
target_text = ""
target_ids = tokenizer(target_text, return_tensors='pt').input_ids.to(device)
decoder_input_ids = model._shift_right(target_ids).to(device)

outputs = model(input_ids=input_ids, audio_input=cur_audio_input,labels=decoder_input_ids)
logits = outputs.logits  
predicted_ids = torch.argmax(logits, dim=-1)
decoded_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
print(decoded_text)

There


In [None]:
"""
audio_encoder = CAVMAEFTAudio()
prompt_text = "what can be infered from this audio following"
input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids.to(device)
audio_input = return_audio("./data/toy_dataset/audio/_4X8RNeWeDI.flac")

with torch.no_grad():
    prompt_embeddings = model.shared(input_ids)  # Shape: (1, sequence_length, 1024)
    prompt_embeddings = prompt_embeddings.to(device)


target_text = "It can be inferred that the audio is a recording of a musical performance or a rehearsal.The combination of music, speech,"\
"and sound effects suggests that the audio is a representation of a musical performance or a rehearsal, where the music is being played"\
"and the performers are practicing their performance or rehearsing."

target_ids = tokenizer(target_text, return_tensors='pt').input_ids.to(device)

decoder_input_ids = model._shift_right(target_ids)

audio_embeddings = audio_input.to(device)  # Shape: (1, 32, 1024)

# Concatenate prompt and audio embeddings
combined_embeddings = torch.cat((prompt_embeddings, audio_embeddings), dim=1)  # Shape: (1, sequence_length + 32, 1024)

max_length = 512


if combined_embeddings.size(1) > max_length:
    combined_embeddings = combined_embeddings[:, :max_length, :]

padding_length = max_length - combined_embeddings.size(1)
if padding_length > 0:
    padding_tensor = torch.zeros((combined_embeddings.size(0), padding_length, combined_embeddings.size(2))).to(device)
    combined_embeddings = torch.cat((combined_embeddings, padding_tensor), dim=1)


attention_mask = torch.ones(combined_embeddings.size(0), combined_embeddings.size(1)).to(device)
if padding_length > 0:
    attention_mask[:, -padding_length:] = 0


outputs = model.generate(inputs_embeds=combined_embeddings, attention_mask=attention_mask,labels=decoder_input_ids)
"""