In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face Hub login. Presumably required because the
# meta-llama checkpoint loaded below is gated — TODO confirm.
notebook_login()

In [None]:
import torch
from transformers import WhisperProcessor, AutoTokenizer
from datasets import load_dataset

In [None]:
# Hub checkpoints for the audio encoder side and the text (LLM) side.
WHISPER_MODEL_NAME = "openai/whisper-base"
LLAMA_MODEL_NAME = "meta-llama/Llama-3.2-3B"

In [None]:
# Processor that turns raw waveforms into Whisper input features, and the
# Llama tokenizer used for the transcripts (both fetched from the Hub).
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME)

In [None]:
# Hugging Face Hub id of the LibriSpeech ASR corpus.
dataset_name = "openslr/librispeech_asr"

In [None]:
# Lazily stream the "clean" configuration's train.100 split rather than
# materialising the whole split up front.
dataset = load_dataset(
    path=dataset_name,
    name='clean',
    split='train.100',
    streaming=True,
)

In [None]:
# Release cached-but-unused GPU memory and reset the peak-memory counters, then
# print the allocator's memory summary as a baseline before loading the model.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
print(torch.cuda.memory_summary())

In [None]:
from dataclasses import dataclass
from transformers import WhisperProcessor, PreTrainedTokenizer

@dataclass
class LibriSpeechDataCollator:
    """Collate LibriSpeech samples into Whisper audio features plus Llama token tensors.

    Each sample in a batch is read as ``sample['audio']['array']`` (the waveform)
    and ``sample['text']`` (the transcript).

    Attributes:
        whisper_processor: turns raw waveforms into Whisper input features.
        tokenizer: tokenizes the transcripts. If it has no pad token, its EOS
            token is installed as the pad token on first use.
        separator_token_id: token id prepended to every transcript. 128000 is
            presumably Llama 3's ``<|begin_of_text|>`` id — TODO confirm against
            ``tokenizer.bos_token_id``.
        sampling_rate: sampling rate passed to the Whisper processor
            (LibriSpeech audio is 16 kHz).
        append_eos: append the tokenizer's EOS token to every transcript so the
            decoder has an explicit stop target; padding itself is excluded
            from the labels.
        label_pad_token_id: label value written at padded positions. -100 is the
            default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so padding
            does not contribute to the loss.
    """

    whisper_processor: WhisperProcessor
    tokenizer: PreTrainedTokenizer
    separator_token_id: int = 128000
    sampling_rate: int = 16000
    append_eos: bool = True
    label_pad_token_id: int = -100

    def __call__(self, batch):
        """Collate a list of LibriSpeech samples into model inputs.

        Args:
            batch: non-empty sequence of samples, each providing
                ``['audio']['array']`` and ``['text']``.

        Returns:
            dict with
              - ``input_features``: Whisper input features from ``whisper_processor``;
              - ``input_ids``: separator token followed by the transcript tokens
                (plus EOS when ``append_eos``), padded to the longest transcript;
              - ``attention_mask``: 1 for the separator and real tokens, 0 for padding;
              - ``labels``: ``input_ids`` with padded positions set to
                ``label_pad_token_id``.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("batch must contain at least one sample")

        audios = [sample['audio']['array'] for sample in batch]
        texts = [sample['text'] for sample in batch]
        batch_size = len(batch)

        # all libri speech are 16kHz
        audio_inputs = self.whisper_processor(
            audios,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        )
        # The Whisper feature extractor pads/trims every clip to a 30 s window, so
        # for whisper-base this is [B, 80, 3000] (not the encoder's 1500 frames).
        input_features = audio_inputs.input_features

        # Only install EOS as padding when the tokenizer has no pad token, instead
        # of overwriting the shared tokenizer's pad token on every call.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if self.append_eos:
            # Special-token text is still mapped to its id with add_special_tokens=False.
            texts = [text + self.tokenizer.eos_token for text in texts]

        tokenized = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = tokenized.input_ids
        attention_mask = tokenized.attention_mask

        separator = torch.full((batch_size, 1), self.separator_token_id, dtype=input_ids.dtype)
        input_ids = torch.cat([separator, input_ids], dim=1)

        attend_to_separator = torch.ones((batch_size, 1), dtype=attention_mask.dtype)
        attention_mask = torch.cat([attend_to_separator, attention_mask], dim=1)

        # The pad id may equal the EOS id, so mask by attention rather than by token
        # id: the real (appended) EOS stays a target, padding is ignored by the loss.
        labels = input_ids.masked_fill(attention_mask == 0, self.label_pad_token_id)

        return {
            "input_features": input_features,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

In [None]:
# Take the first three streamed samples as a small smoke-test batch.
iterator = iter(dataset)
batch = []
for sample_number in range(3):
    batch.append(next(iterator))

In [None]:
# Build the collator only. Do not also bind it to `input_parameters`: that name
# holds the collated batch (next cell), and re-running this cell would otherwise
# replace the batch dict with the collator and break every later cell.
ldpc = LibriSpeechDataCollator(
    whisper_processor=whisper_processor,
    tokenizer=tokenizer,
)

In [None]:
# Collate the sample batch into the dict of tensors consumed by the model.
input_parameters = ldpc(batch)

In [None]:
# Print the shape of each collated tensor, one per line.
for tensor_name in ("input_features", "labels", "input_ids", "attention_mask"):
    print(input_parameters[tensor_name].shape)

In [None]:
# Move the collated batch to cuda:0. Only the audio features are cast to
# bfloat16 (the model itself is cast to bfloat16 further down); the token ids,
# labels and attention mask keep their original dtype.
input_parameters['input_features'] = input_parameters['input_features'].cuda(0).to(torch.bfloat16)
for tensor_key in ('labels', 'input_ids', 'attention_mask'):
    input_parameters[tensor_key] = input_parameters[tensor_key].cuda(0)

In [None]:
from models import SpeechToTextModel

In [None]:
# Build the speech-to-text model and place it on cuda:0 in bfloat16.
# NOTE(review): hidden_dims presumably sets the widths of the layers bridging
# Whisper and Llama, and train_whisper/train_llama=False presumably freeze the
# two pretrained backbones — confirm in models.SpeechToTextModel.
model = SpeechToTextModel(
    whisper_model_name=WHISPER_MODEL_NAME,
    llama_model_name=LLAMA_MODEL_NAME,
    hidden_dims=[2048, 1024, 2048, 1024, 2048],
    train_whisper=False,
    train_llama=False,
).to(torch.device("cuda:0"), dtype=torch.bfloat16)

In [None]:
# Sanity check that the model and the batch live on the same device.
# Printing the set of distinct devices gives one line each, instead of one line
# per parameter tensor and per label row.
print({param.device for param in model.parameters()})
print({label_row.device for label_row in input_parameters['labels']})

In [None]:
# Smoke-test forward pass on the collated batch.
forward_kwargs = {
    key: input_parameters[key]
    for key in ('input_features', 'input_ids', 'attention_mask', 'labels')
}
outputs = model(**forward_kwargs)

In [None]:
class Collator:
    """Method-style front end for ``LibriSpeechDataCollator``.

    Keeps the ``Collator(tokenizer, whisper_processor).preprocess(batch)`` call
    shape, but all collation logic lives in ``LibriSpeechDataCollator``. The two
    code paths therefore cannot drift apart; a copy of the logic here had
    already lost ``attention_mask``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, whisper_processor):
        """Store the text tokenizer and the Whisper processor.

        Args:
            tokenizer: tokenizer used for the transcripts.
            whisper_processor: processor that turns waveforms into Whisper
                input features.
        """
        self.whisper_processor = whisper_processor
        self.tokenizer = tokenizer
        # Token prepended to every transcript; same default as LibriSpeechDataCollator.
        self.separator_token_id: int = 128000

    def preprocess(self, batch):
        """Collate a list of LibriSpeech samples.

        Returns the same dict as ``LibriSpeechDataCollator.__call__``:
        ``input_features``, ``input_ids``, ``attention_mask`` and ``labels``.
        ``attention_mask`` is needed so padded token positions can be masked.
        """
        collator = LibriSpeechDataCollator(
            whisper_processor=self.whisper_processor,
            tokenizer=self.tokenizer,
            separator_token_id=self.separator_token_id,
        )
        return collator(batch)