In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face Hub authentication for this kernel session.
# NOTE(review): presumably needed because the meta-llama checkpoint used below
# is access-gated on the Hub — confirm. The call prompts for a token, so a
# headless "Run All" will stall here.
notebook_login()

In [None]:
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, WhisperProcessor, set_seed

from models import SpeechToTextModel

In [None]:
from transformers import AutoTokenizer

In [None]:
from utils import LibriSpeechDataCollator

In [None]:
# Hub checkpoint identifiers. Both are passed to SpeechToTextModel; the Whisper
# ID is also used to load the processor and the Llama ID to load the tokenizer.
WHISPER_MODEL_NAME = "openai/whisper-base"
LLAMA_MODEL_NAME = "meta-llama/Llama-3.2-3B"

In [None]:
# Seed Python, NumPy and torch RNGs *before* the model is built. Trainer only
# applies its own seed when it is constructed, which happens in a later cell,
# so any randomly initialised weights inside SpeechToTextModel would otherwise
# differ from run to run. 42 mirrors the TrainingArguments default seed.
# `set_seed` comes from the transformers import in the imports cell.
SEED = 42
set_seed(SEED)

# Build the speech-to-text model from the two pretrained checkpoints.
model = SpeechToTextModel(
    whisper_model_name=WHISPER_MODEL_NAME,
    llama_model_name=LLAMA_MODEL_NAME,
    # NOTE(review): presumably the layer widths of the module bridging Whisper
    # and Llama — confirm against models.SpeechToTextModel.
    hidden_dims=[2048, 1024, 2048],
    # NOTE(review): presumably keeps the Whisper weights frozen — confirm.
    train_whisper=False,
)

In [None]:
# Hub dataset repository ID handed to load_dataset in the next cell.
dataset_name = "openslr/librispeech_asr"

In [None]:
# Load only the "train.100" split of the "clean" configuration.
dataset = load_dataset(
    path=dataset_name,
    name="clean",
    split="train.100",
)

In [None]:
# Sanity check: how many examples the loaded split contains.
num_examples = len(dataset)
print(num_examples)

In [None]:
# Text side: tokenizer matching the Llama checkpoint.
# NOTE(review): Llama tokenizers commonly ship without a pad token — confirm
# that LibriSpeechDataCollator handles padding itself.
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME)

# Audio side: processor matching the Whisper checkpoint.
processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)

In [None]:
# Hyperparameters for a single-epoch training run.
training_args = TrainingArguments(
    # --- optimisation ---
    learning_rate=5e-5,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    # NOTE(review): bf16 needs hardware support; this raises on devices without it.
    bf16=True,
    # --- checkpointing ---
    output_dir="./v1-checkpoints",
    overwrite_output_dir=True,
    # NOTE(review): no save_total_limit is set, so every checkpoint is kept on disk.
    save_steps=200,
    # --- logging ---
    logging_steps=10,
    report_to="none",
    # --- data handling ---
    # Keep all dataset columns so the custom collator receives the raw examples.
    remove_unused_columns=False,
)

In [None]:
# Collator that batches examples using the Whisper processor and Llama tokenizer.
collator = LibriSpeechDataCollator(processor, tokenizer)

# Wire the model, hyperparameters, data and collator into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

In [None]:
# Start training with the arguments configured above; checkpoints are written
# under training_args.output_dir.
trainer.train()