# Speech transformer 

## Environment

### Imports

In [1]:
# Disabled: presumably adds a local checkout of transformer_wrappers to the import path.
# NOTE: Python does not expand '~' in sys.path entries; if this is re-enabled,
# wrap the path in os.path.expanduser(...).
# import sys
# sys.path.append('~/Projects/transformer_wrappers/src')

In [1]:
import os

In [2]:
import torch

In [3]:
from transformers import BitsAndBytesConfig
from peft import LoraConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformer_wrappers.data import MozillaCommonVoice, ProcessedMozillaCommonVoice

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformer_wrappers.wrappers import SpeechCausalLMWrapper

### Constants and globals

In [5]:
# Hugging Face access token, read from the HF_TOKEN environment variable so the
# secret never has to be typed into (and saved with) the notebook.
# Resolves to None when the variable is unset or empty, which was the previous value.
TOKEN = os.environ.get('HF_TOKEN') or None

In [3]:
# Directory of the local Common Voice release 'cv-corpus-19.0-2024-09-13', under the user's home.
# os.path.expanduser('~') is used instead of os.environ['HOME'] so this cell does not
# raise KeyError where HOME is unset (e.g. Windows); on POSIX it resolves to $HOME when set.
MCV_IT_PATH = os.path.join(
    os.path.expanduser('~'),
    'Data/mozilla_common_voice/cv-corpus-19.0-2024-09-13/'
)

In [4]:
# Language code passed as `language=` to the Common Voice dataset classes below.
LANG = 'it'

In [ ]:
# Special token string registered with the tokenizer via `additional_special_tokens`
# and passed to ProcessedMozillaCommonVoice.
AUDIO_TOKEN = '<|audio|>'

In [6]:
# Checkpoint to load; the commented lines are alternative checkpoints.
MODEL = 'gpt2'
# MODEL = 'mistralai/Mistral-7B-Instruct-v0.3'
# MODEL = 'meta-llama/Llama-3.1-8B-Instruct'
# MODEL = 'google/gemma-2-9b-it'

# Keyword arguments handed to the wrapper as `model_kwargs`.
MODEL_CONFIGS = dict(
    torch_dtype=torch.bfloat16,
    # Pinned to CPU; alternative kept for reference:
    # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device_map='cpu',
    token=TOKEN,
)

# Keyword arguments handed to the wrapper as `tokenizer_kwargs`.
# NOTE(review): the pad token is hard-coded to '<|endoftext|>'; revisit if MODEL changes.
TOKENIZER_CONFIGS = dict(
    token=TOKEN,
    pad_token='<|endoftext|>',
    additional_special_tokens=[AUDIO_TOKEN],
)

In [7]:
# 4-bit quantisation settings: 'nf4' quant type, double quantisation, bfloat16 compute dtype.
# Currently unused: the `quantization_configs` argument is commented out where the model is loaded.
QUANTIZATION_CONFIGS = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

In [8]:
# LoRA adapter settings passed to the wrapper as `lora_configs`:
# rank 16, alpha 16, dropout 0.1, bias 'none', target modules 'all-linear'.
LORA_CONFIGS = LoraConfig(
    task_type='CAUSAL_LM',
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias='none',
    target_modules='all-linear',
)

In [9]:
# Load the speech-enabled causal LM wrapper together with its tokenizer and LoRA settings.
model = SpeechCausalLMWrapper.from_pretrained(
    MODEL,
    tokenizer_kwargs=TOKENIZER_CONFIGS,
    model_kwargs=MODEL_CONFIGS,
    lora_configs=LORA_CONFIGS,
    # Quantisation is disabled for now:
    # quantization_configs=QUANTIZATION_CONFIGS,
)



### Helper functions

In [10]:
# Placeholder: no helper functions are defined in this notebook yet.
...

Ellipsis

## Data

### Dummy data

In [11]:
# Prompt for the dummy example: instruction, audio placeholder token, then the transcription.
text = (
    'Transcribe the following audio clip:\n'
    f'{model.audio_token}\n\n'
    'Transcription:\n'
    '"In a hole in the ground there lived a hobbit."'
)

In [12]:
# Audio clip for the dummy example; the path is relative to the working directory.
audio_file_path = '../audio.wav'

In [13]:
# Build the model inputs from the prompt and the audio file. The result is indexed by
# 'input_spectrograms' / 'input_ids' and unpacked into model.forward and model.generate later on.
input_encoding = model.prepare_input(text, audio_file_path)

### Mozilla Common Voice

In [5]:
# Common Voice dataset rooted at MCV_IT_PATH for language LANG
# ('validation' is presumably the split name — confirm against MozillaCommonVoice).
mcv = MozillaCommonVoice(MCV_IT_PATH, 'validation', language=LANG)

In [ ]:
# Processed variant of the same dataset, built with the model's tokenizer and the audio token.
# Limited to 4 samples per task with a fixed random seed.
pmcv = ProcessedMozillaCommonVoice(
    model.tokenizer,
    AUDIO_TOKEN,
    MCV_IT_PATH,
    'validation',
    random_seed=2307,
    max_samples_per_task=4,
    language=LANG,
)

## Forward

In this first example, we show how to run a forward pass of the model on an input composed of text and audio.

In [14]:
# Size of the input spectrogram tensor.
input_encoding['input_spectrograms'].shape

torch.Size([1, 128, 570])

In [15]:
# Run the wrapper's forward pass on the prepared inputs; the result is indexed by
# 'logits' and 'spectrograms' in the cells that follow.
output = model.forward(**input_encoding)

In [16]:
# Size of the predicted spectrogram tensor.
output['spectrograms'].shape

torch.Size([1, 128, 570])

In [17]:
# Build the targets for the same prompt/audio pair; the result is indexed by
# 'token_labels' and 'target_spectrograms' when computing the loss.
target_output = model.prepare_output(text, audio_file_path)

In [18]:
# Size of the target spectrogram tensor.
target_output['target_spectrograms'].shape

torch.Size([1, 128, 570])

In [19]:
# Compute the loss from the forward outputs and the prepared targets.
# NOTE(review): `_loss` is a private method of the wrapper; the recorded output shows it
# returning a (total, per-component dict) pair.
loss = model._loss(
    predicted_spectrograms=output['spectrograms'],
    target_spectrograms=target_output['target_spectrograms'],
    token_logits=output['logits'],
    token_labels=target_output['token_labels'],
)
loss

(tensor(3634.9307, grad_fn=<AddBackward0>),
 {'language_modelling_loss': tensor(20.2639, grad_fn=<NllLossBackward0>),
  'spectrogram_generation_loss': tensor(3614.6667, grad_fn=<MseLossBackward0>)})

In [20]:
# Size of the first post-processed spectrogram recovered from the input encoding.
model.post_process_spectrograms(
    input_encoding['input_spectrograms'],
    input_encoding['input_ids'],
)[0][0].shape

torch.Size([128, 402])

In [21]:
# Shape of the encoding produced directly by the audio processor for the same file,
# for comparison with the post-processed spectrogram size shown in the previous cell.
model.audio_processor.encode(audio_file_path).shape

(128, 402)

## Generate

In [22]:
# Switch on the wrapper's benchmarking mode before generating
# (its effect is defined by SpeechCausalLMWrapper and is not visible in this notebook).
model.enable_benchmarking()

In [23]:
# Generate from the text+audio inputs with at most 4 new tokens; the call is unpacked
# into token ids and spectrograms. The commented line is a text-only variant.
ids, specs = model.generate(**input_encoding, max_new_tokens=4)
# ids, specs = model.generate(input_encoding.input_ids, max_new_tokens=4)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [24]:
# Display the generated token ids.
# NOTE(review): id 50257 is presumably the added audio token (the recorded generation
# warning reports eos_token_id 50256) — confirm.
ids

tensor([[ 8291,    66,  4892,   262,  1708,  6597, 10651,    25,   198, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257,   198,   198,  8291,  6820,
            25,   198,     1,   818,   257,  7604,   287,   262,  2323,   612,
          5615,   257, 32724,  2545,   526, 50257, 50257, 50257, 50257]])

In [25]:
# Display the generated spectrograms.
# NOTE(review): the recorded output contains NaN in the leading columns — investigate
# before relying on these values.
specs

tensor([[[    nan,     nan,     nan,  ...,  0.0524, -0.3149,  0.5077],
         [    nan,     nan,     nan,  ..., -0.7762, -0.3863,  0.4043],
         [    nan,     nan,     nan,  ..., -0.2447, -0.6477,  0.3145],
         ...,
         [    nan,     nan,     nan,  ..., -0.5020, -0.4748,  0.4301],
         [    nan,     nan,     nan,  ...,  0.0308, -1.0451,  0.3360],
         [    nan,     nan,     nan,  ..., -0.0934,  0.5593, -0.2245]]])