# Speech transformer 

## Environment

### Imports

In [1]:
import os
import sys

# `sys.path` entries are used verbatim by the import system: a leading '~' is
# never expanded, so the literal '~/...' string would not point at any real
# directory. Resolve it to the user's home directory before registering the
# local source tree of `transformer_wrappers`.
sys.path.append(os.path.expanduser('~/Projects/transformer_wrappers/src'))

In [1]:
import torch

In [2]:
from transformers import BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformer_wrappers.wrappers import SpeechCausalLMWrapper

### Constants and globals

In [4]:
TOKEN = None  # Hugging Face access token, forwarded to the model and tokenizer configs; never commit a real value

In [5]:
# Identifier of the checkpoint to load. Alternative identifiers:
#   'mistralai/Mistral-7B-Instruct-v0.3'
#   'meta-llama/Llama-3.1-8B-Instruct'
#   'google/gemma-2-9b-it'
MODEL = 'gpt2'

# Keyword arguments passed to the wrapper as `model_kwargs`.
# Alternative for 'device_map' that picks the GPU when one is available:
#   torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_CONFIGS = {
    'torch_dtype': torch.bfloat16,
    'device_map': 'cpu',
    'token': TOKEN,
}

# Keyword arguments passed to the wrapper as `tokenizer_kwargs`; registers the
# padding token and the additional '<|audio|>' special token.
TOKENIZER_CONFIGS = {
    'token': TOKEN,
    'pad_token': '<|endoftext|>',
    'additional_special_tokens': ['<|audio|>'],
}

In [6]:
# Optional 4-bit quantization settings (NF4 quantization type, double
# quantization, bfloat16 compute dtype). The model-loading cell keeps the
# argument that would consume this object commented out.
QUANTIZATION_CONFIGS = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [7]:
# Instantiate the speech wrapper from the checkpoint named by MODEL.
# To use the 4-bit setup, additionally pass:
#   quantization_configs=QUANTIZATION_CONFIGS
model = SpeechCausalLMWrapper.from_pretrained(
    MODEL, tokenizer_kwargs=TOKENIZER_CONFIGS, model_kwargs=MODEL_CONFIGS
)



### Helper functions

In [8]:
# Placeholder cell: no helper functions are defined yet.
...

Ellipsis

## Speech recognition

In this first example, we show how to run a forward pass of the model on an input composed of both text and audio.

In [9]:
# Prompt for the forward pass: an instruction, the model's audio placeholder
# token marking where the clip goes, and the reference transcription.
text = (
    'Transcribe the following audio clip:\n'
    f'{model.audio_token}\n\n'
    'Transcription:\n'
    '"In a hole in the ground there lived a hobbit."'
)

In [10]:
# Audio clip paired with the prompt. The path is relative, so it presumably
# resolves against the notebook's working directory — TODO confirm.
audio_file_path = '../audio.wav'

In [11]:
# Build the model inputs from the text prompt and the audio file. The result is
# later indexed by 'input_spectrograms' and unpacked as keyword arguments into
# `model.forward`.
input_encoding = model.prepare_input(text, audio_file_path)

In [12]:
# Sanity check: size of the first entry under 'input_spectrograms'.
input_encoding['input_spectrograms'][0].size()

torch.Size([128, 402])

In [14]:
# Run one forward pass, unpacking the prepared encoding as keyword arguments.
# NOTE(review): the call is not wrapped in `torch.no_grad()`, and the
# spectrogram displayed later carries a `grad_fn`, so autograd state is kept.
# Confirm this is intended for a demonstration-only forward pass.
output = model.forward(**input_encoding)

In [15]:
# List the name of every field exposed by the forward-pass output, one per line.
for field_name in output:
    print(field_name)

position_ids
past_key_values
batch_size
prefix_length
sequence_length
cache_position
dtype
device
speech_mask
attention_mask
add_attn_residual
add_ffnn_residual
use_cache
output_attentions
output_hidden_states
return_intermediate_hidden_states
return_attention_output
return_feed_forward_up_proj_output
return_feed_forward_gate_output
return_feed_forward_inner_activations
return_feed_forward_output
output_hidden_state
logits
spectrograms
modality_mask
loss
cache
hidden_states
attention_weights
return_dict


In [16]:
# Display the 'spectrograms' field of the forward-pass output. The recorded
# output is a list holding a bfloat16 tensor.
# NOTE(review): this dumps raw tensor values; showing only the sizes may be
# easier to read.
output['spectrograms']

[tensor([[-10.7500, -10.2500,   5.1875,  ...,   2.8125,   1.5625,  -3.0938],
         [ -2.5312, -10.1875,   1.9609,  ...,   1.0625,  -3.2031,  -3.9219],
         [ -5.4375,   4.4688,  -8.2500,  ...,   0.6172,  -4.5000,  -0.3613],
         ...,
         [  1.9297,  -7.7188, -11.8125,  ...,   0.1885,  -2.0781,  -2.0156],
         [ -6.5938, -16.8750,  -0.0503,  ...,   0.4629,  -0.3027,   0.1143],
         [  4.0000,  -1.1328,  12.3750,  ...,  -2.3281,  -2.0000,  -0.5820]],
        dtype=torch.bfloat16, grad_fn=<SqueezeBackward1>)]