In [None]:
# Authenticate this kernel with the Hugging Face Hub before any checkpoint or dataset is fetched.
# NOTE(review): presumably needed because the meta-llama checkpoint used later is access-gated — confirm.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import torch
from transformers import WhisperProcessor, AutoTokenizer
from datasets import load_dataset

In [None]:
# Hub checkpoint ids. WHISPER_MODEL_NAME is used to load the WhisperProcessor and
# LLAMA_MODEL_NAME to load the tokenizer; both are also passed to SpeechToTextModel.
WHISPER_MODEL_NAME = "openai/whisper-base"
LLAMA_MODEL_NAME = "meta-llama/Llama-3.2-3B"

In [None]:
# Load the audio processor and the text tokenizer for WHISPER_MODEL_NAME and LLAMA_MODEL_NAME.
# Both objects are handed to the data collator later in the notebook.
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME)

In [None]:
# Hub id of the dataset whose samples are collated in this notebook.
dataset_name = "gpt-omni/VoiceAssistant-400K"

In [None]:
# streaming=True yields samples lazily while iterating instead of materialising the whole train split first.
dataset = load_dataset(dataset_name, split='train', streaming=True)

In [None]:
# Clear cached allocator blocks and reset the peak-memory counters so that later
# memory readings start from a clean baseline, then print the allocator summary.
# The calls are guarded because reset_peak_memory_stats() and memory_summary()
# need a usable CUDA device and fail on a CPU-only kernel.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; skipping GPU memory diagnostics.")

In [None]:
from dataclasses import dataclass
from transformers import WhisperProcessor, PreTrainedTokenizer
import torchaudio
import torch

@dataclass
class GPTVoiceAssistantDataCollator:
    """Collate voice-assistant samples into Whisper audio features and tokenized question text.

    Every sample in a batch must provide ``sample["question_audio"]`` with
    ``"array"`` and ``"sampling_rate"`` entries, and ``sample["question"]`` text.
    Calling the collator returns a dict with ``input_features``, ``input_ids``,
    ``attention_mask`` and ``labels``.
    """

    # Processor that converts raw waveforms into ``input_features``.
    # (Annotations are quoted so the class definition does not evaluate the types.)
    whisper_processor: "WhisperProcessor"
    # Tokenizer applied to the ``question`` text.
    tokenizer: "PreTrainedTokenizer"
    # Token id prepended to every tokenized question.
    separator_token_id: int = 128000
    # base_index and switch_frequency are not read by this collator; they are kept
    # so existing constructor calls keep working.
    base_index: int = 1616
    switch_frequency: int = 4
    # Sampling rate passed to the processor; audio at any other rate is resampled to it.
    required_sample_rate: int = 16000
    # Value written into ``labels`` wherever the attention mask is 0 (padding).
    # -100 is the default ignore index of PyTorch's cross-entropy loss.
    label_pad_token_id: int = -100

    def _load_waveform(self, audio):
        """Return one sample's audio as a float32 numpy array at ``required_sample_rate``."""
        waveform = torch.tensor(audio["array"], dtype=torch.float32)
        orig_sr = audio["sampling_rate"]
        if orig_sr != self.required_sample_rate:
            waveform = torchaudio.functional.resample(waveform, orig_sr, self.required_sample_rate)
        # The Whisper feature extractor documents numpy arrays / float lists as input;
        # a list of torch tensors of different lengths is not recognised as a batch.
        return waveform.numpy()

    def __call__(self, batch):
        """Build model inputs from a list of samples.

        Args:
            batch: Non-empty list of sample dicts (see the class docstring).

        Returns:
            dict with
            ``input_features``: processor output for the audio (the original code
                notes its shape as [B, 80, T]);
            ``input_ids``: separator token followed by the padded question tokens;
            ``attention_mask``: 1 for the separator and real tokens, 0 for padding;
            ``labels``: copy of ``input_ids`` with padded positions replaced by
                ``label_pad_token_id`` so they do not contribute to the loss.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("GPTVoiceAssistantDataCollator received an empty batch")

        audios = [self._load_waveform(sample["question_audio"]) for sample in batch]

        # Process audio with WhisperProcessor
        audio_inputs = self.whisper_processor(
            audios,
            sampling_rate=self.required_sample_rate,
            return_tensors="pt"
        )
        input_features = audio_inputs.input_features

        # Tokenize text from the 'question' field
        texts = [sample["question"] for sample in batch]

        # Padding needs a pad token; fall back to EOS only when the tokenizer has
        # none, so a tokenizer with a dedicated pad token is left untouched.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        tokenized = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=False
        )

        input_ids = tokenized.input_ids
        attention_mask = tokenized.attention_mask
        batch_size = input_ids.shape[0]

        # Prepend the separator token and mark it as attended.
        sep_token = torch.full((batch_size, 1), self.separator_token_id, dtype=input_ids.dtype)
        input_ids = torch.cat([sep_token, input_ids], dim=1)

        sep_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype)
        attention_mask = torch.cat([sep_mask, attention_mask], dim=1)

        # masked_fill returns a new tensor, so input_ids itself is not modified.
        labels = input_ids.masked_fill(attention_mask == 0, self.label_pad_token_id)

        return {
            "input_features": input_features,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }


In [None]:
# Build the collator from the processor and tokenizer loaded earlier.
# GPTVoiceAssistantDataCollator is the collator class this notebook defines;
# no LibriSpeechDataCollator is defined or imported here.
# Only `ldpc` is bound: `input_parameters` is reserved for the collated batch dict.
ldpc = GPTVoiceAssistantDataCollator(
    whisper_processor=whisper_processor,
    tokenizer=tokenizer,
)

In [None]:
# Pull the first three samples off the streamed dataset to form a small test batch.
iterator = iter(dataset)
batch = []
for _ in range(3):
    batch.append(next(iterator))

In [None]:
# Show which fields a streamed sample exposes.
print(batch[0].keys())

In [None]:
# Collate the sampled batch into a dict with input_features, input_ids, attention_mask and labels.
input_parameters = ldpc(batch)

In [None]:
# Report the shape of each collated tensor, one per line.
for key in ("input_features", "labels", "input_ids", "attention_mask"):
    print(input_parameters[key].shape)

In [None]:
# Move every collated tensor to GPU 0. Only the audio features are cast to
# bfloat16; the token ids, mask and labels keep their dtype.
for key in ("input_features", "labels", "input_ids", "attention_mask"):
    on_gpu = input_parameters[key].cuda(0)
    if key == "input_features":
        on_gpu = on_gpu.to(torch.bfloat16)
    input_parameters[key] = on_gpu

In [None]:
from models import SpeechToTextModel

In [None]:
# Build the speech-to-text model from the two checkpoint names.
# NOTE(review): hidden_dims presumably sizes the layers that connect the Whisper and
# Llama parts, and train_whisper/train_llama=False presumably freeze those backbones —
# confirm against models.SpeechToTextModel.
model = SpeechToTextModel(
    whisper_model_name=WHISPER_MODEL_NAME,
    llama_model_name=LLAMA_MODEL_NAME,
    hidden_dims=[2048, 1024, 2048, 1024, 2048],
    train_whisper=False,
    train_llama=False
)
# Move the model to GPU 0 in bfloat16, the same device and dtype the input_features are converted to.
model = model.to(torch.device("cuda:0"), dtype=torch.bfloat16)

In [None]:
# Sanity check that the model weights and the batch live on the same device.
# The distinct devices are reported once instead of one line per parameter tensor,
# and the labels tensor has a single device, so it is printed once as well.
param_devices = {param.device for param in model.parameters()}
print("model parameter devices:", param_devices)
print("labels device:", input_parameters['labels'].device)

In [None]:
# Run one forward pass of the model on the GPU-resident sample batch.
forward_keys = ("input_features", "input_ids", "attention_mask", "labels")
outputs = model(**{key: input_parameters[key] for key in forward_keys})