In [None]:
from huggingface_hub import notebook_login
# Authenticate this notebook session with the Hugging Face Hub.
# NOTE(review): presumably needed because the Llama checkpoint used below is
# access-gated — confirm.
notebook_login()

In [None]:
import torch
from datasets import load_dataset
from models import SpeechToTextModel

In [None]:
from transformers import WhisperProcessor, TrainingArguments, Trainer, AutoTokenizer

In [None]:
from utils import LibriSpeechDataCollator

In [None]:
# Checkpoint identifiers for the two pretrained components. They are passed to
# SpeechToTextModel and reused when loading the Whisper processor and the
# Llama tokenizer further down.
WHISPER_MODEL_NAME = "openai/whisper-base"
LLAMA_MODEL_NAME = "meta-llama/Llama-3.2-3B"

In [None]:
# Instantiate the project-local speech-to-text model from the two checkpoint
# names defined above.
# NOTE(review): `hidden_dims` presumably sizes the layers that bridge the
# Whisper and Llama components, and `train_whisper=False` presumably keeps the
# Whisper weights frozen — confirm against models.SpeechToTextModel.
model = SpeechToTextModel(
    whisper_model_name=WHISPER_MODEL_NAME,
    llama_model_name=LLAMA_MODEL_NAME,
    hidden_dims=[2048, 1024, 2048],
    train_whisper=False
)

In [None]:
# Dataset identifier handed to `load_dataset` in the next cell.
dataset_name = "openslr/librispeech_asr"

In [None]:
# Load the 'clean' configuration, 'train.100' split of the dataset; this is the
# training set given to the Trainer below.
dataset = load_dataset(dataset_name, 'clean', split='train.100')

In [None]:
# Preprocessing objects matching the two checkpoints: the Whisper processor and
# the Llama tokenizer. Both are passed to LibriSpeechDataCollator below.
processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME)

In [None]:
# Trainer configuration, grouped by concern. Keyword order is irrelevant to
# TrainingArguments, so the grouping is purely for readability.
training_args = TrainingArguments(
    # --- optimisation schedule ---
    num_train_epochs=1,
    per_device_train_batch_size=1,
    learning_rate=5e-5,
    bf16=True,  # bfloat16 mixed-precision training
    # --- checkpointing ---
    output_dir="./v1-checkpoints",
    overwrite_output_dir=True,
    save_steps=100,
    # Save with torch serialization instead of safetensors.
    # NOTE(review): presumably because the model shares tensors between
    # sub-modules, which safetensors rejects — confirm.
    save_safetensors=False,
    # --- logging ---
    logging_steps=10,
    report_to="none",  # no external experiment tracker
    # --- data handling ---
    # Keep every dataset column so the custom collator receives the raw
    # examples rather than a signature-filtered subset.
    remove_unused_columns=False,
)

In [None]:
# Wire the model, the training configuration and the dataset together. Batches
# are assembled by the project-local LibriSpeechDataCollator, which is given
# both the Whisper processor and the Llama tokenizer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=LibriSpeechDataCollator(processor, tokenizer),
)

In [None]:
# Run the training loop configured above; checkpoints are written under
# `output_dir` every `save_steps` steps.
trainer.train()

In [None]:
# Display the model's repr as the cell output.
model


In [None]:
# Release cached allocator blocks and reset the peak-usage counters, then print
# the CUDA memory report for the current device.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
print(torch.cuda.memory_summary())

In [None]:
import gc

# Free GPU memory between experiments.
#
# Order matters: collect unreachable Python objects first so that any CUDA
# tensors they own are actually destroyed, THEN ask the caching allocator to
# return its now-unused blocks to the driver.
#
# Note: tensors that are still referenced (e.g. through `trainer`, `model`, or
# a stored traceback) cannot be released from here. Walking gc.get_objects()
# and `del`-ing each item would only unbind a temporary name, not the tensor;
# drop the owning references at their source instead (e.g. `del trainer`).
unreachable_count = gc.collect()
torch.cuda.empty_cache()

In [None]:
import inspect  # not imported elsewhere in the notebook; required for getsource

# Show the collator's source for reference. print() renders real line breaks;
# a bare expression would display the string repr with escaped "\n".
print(inspect.getsource(LibriSpeechDataCollator))

In [None]:
# Manual sanity check: take only the first training example and build a
# standalone collator instance, independent of the Trainer.
batch = dataset.select(range(1))
lbdc = LibriSpeechDataCollator(processor, tokenizer)

In [None]:
# Collate the single-example selection. The result is indexed by the keys
# "input_features", "labels" and "input_ids" in the cells below.
input_parameters = lbdc(batch)

In [None]:
# Print the shape of each collated tensor, one per line, in the same order the
# model receives them below.
for field in ("input_features", "labels", "input_ids"):
    print(input_parameters[field].shape)

In [None]:
# Move every tensor the model consumes onto the GPU, replacing the entries in
# place.
for field in ('input_features', 'labels', 'input_ids'):
    input_parameters[field] = input_parameters[field].cuda()

In [None]:
# Make sure the model is on the GPU, alongside the inputs moved there above.
model = model.cuda()

In [None]:
# Single forward pass on the hand-built batch. Call the module instance rather
# than `.forward()` directly: `__call__` runs any registered forward/pre-forward
# hooks, which a direct `.forward()` call silently skips.
outputs = model(
    input_features=input_parameters['input_features'],
    labels=input_parameters['labels'],
    input_ids=input_parameters['input_ids'],
)

In [None]:
# Shape of the logits produced by the forward pass above.
print(outputs.logits.shape)

In [None]:
# Loss returned by the model for this batch (labels were supplied to the call).
print(outputs.loss)