In [None]:
# Import packages
import torch, os
from huggingface_hub import login
from datasets import load_dataset, load_metric, DatasetDict
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer, 
    TrainerCallback
)

# Select the compute device: CUDA when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Log into HuggingFace Hub
# The access token is read from the HUGGINGFACE_HUB_TOKEN environment variable
# instead of being written into the notebook, so the secret never ends up in
# version control or shared copies. When the variable is unset, `token` is None
# and `login` asks for the token interactively.
login(token=os.environ.get('HUGGINGFACE_HUB_TOKEN'))

In [None]:
# Load the "dev" and "test" splits of the processed EdAcc dataset, keyed by split name
edacc = DatasetDict(
    {
        split_name: load_dataset(
            "sage-bergerson/edacc_processed",
            split=split_name,
            token=True,
        )
        for split_name in ("dev", "test")
    }
)

# Whisper processor configured for English transcription
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-large",
    language="English",
    task="transcribe",
)

In [None]:
# Fetch the pretrained Whisper checkpoint and move it onto the selected device
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
model = model.to(device)

# Fix generation to English transcription and clear any forced decoder ids
generation_config = model.generation_config
generation_config.language = "english"
generation_config.task = "transcribe"
generation_config.forced_decoder_ids = None

# Define data collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech-to-text examples into one padded batch.

    Each example must provide ``"input_features"`` (audio features) and
    ``"labels"`` (token ids). Audio features are padded with
    ``processor.feature_extractor.pad`` and labels with
    ``processor.tokenizer.pad``; both are requested as PyTorch tensors.
    """

    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`
    processor: Any
    # Token id that, when it leads every label row, is stripped from the labels
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Return the padded batch with a ``"labels"`` entry added.

        Padded label positions (attention mask != 1) are set to -100 so they
        are ignored by the loss.
        """
        # Audio side: wrap each example's features and pad them into tensors
        audio_examples = []
        for example in features:
            audio_examples.append({"input_features": example["input_features"]})
        batch = self.processor.feature_extractor.pad(audio_examples, return_tensors="pt")

        # Label side: the tokenizer expects the ids under "input_ids"
        label_examples = []
        for example in features:
            label_examples.append({"input_ids": example["labels"]})
        padded_labels = self.processor.tokenizer.pad(label_examples, return_tensors="pt")

        # Overwrite every padded position with -100
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # Drop the first column only when every row starts with the decoder start token
        first_column_is_start = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if first_column_is_start else labels
        return batch

# Build the collator; it uses the model's decoder start token id to detect a
# leading start token in the label rows
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=processor,
)

# Word error rate metric
# NOTE(review): the trainer is created with compute_metrics=None, so this
# metric does not appear to be used in the visible cells — confirm whether WER
# should be wired into evaluation. `load_metric` is presumably deprecated in
# newer `datasets` releases; verify against the installed version.
wer_metric = load_metric("wer")

# Training configuration, grouped by concern (values unchanged)
training_args = Seq2SeqTrainingArguments(
    # Output location and Hub upload
    output_dir="./whisper-large-edacc-v3",
    push_to_hub=True,
    report_to=["tensorboard"],
    # Optimisation schedule
    learning_rate=5e-6,
    lr_scheduler_type="linear",
    warmup_steps=500,
    max_steps=600,
    # Batch sizes
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    # Memory / precision
    gradient_checkpointing=True,
    fp16=True,
    # Evaluation and generation
    eval_strategy="steps",
    eval_steps=600,
    predict_with_generate=True,
    generation_max_length=448,
    # Checkpointing, logging and best-model selection (lower loss is better)
    save_steps=600,
    logging_steps=600,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
)

# Container for recorded metrics: one independent, initially empty list per series
loss_values = {series: [] for series in ("steps", "train_loss", "eval_loss")}

# Assemble the trainer: fine-tune on the "dev" split, evaluate on the "test" split
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=edacc["dev"],
    eval_dataset=edacc["test"],
    tokenizer=processor.feature_extractor,
    # No metric function is supplied
    compute_metrics=None,
)

# Custom callback to log and store losses
class CustomTrainerCallback(TrainerCallback):
    """Record the step, training loss and evaluation loss after each evaluation.

    Values are appended to the module-level ``loss_values`` dict under the
    ``"steps"``, ``"train_loss"`` and ``"eval_loss"`` keys.
    """

    def on_evaluate(self, args, state, control, **kwargs):
        """Store the metrics of the evaluation that just finished.

        Args:
            args: Training arguments (unused).
            state: Trainer state; ``global_step`` and ``log_history`` are read.
            control: Trainer control object (unused).
            **kwargs: Must contain ``metrics``, a dict holding ``"eval_loss"``.

        Raises:
            KeyError: If ``metrics`` or its ``"eval_loss"`` entry is missing.
        """
        metrics = kwargs['metrics']

        # The metrics dict handed to this hook holds evaluation metrics (e.g.
        # 'eval_loss'), so a plain 'loss' key is normally absent. Fall back to
        # the most recent logged entry that contains a training 'loss'.
        train_loss = metrics.get('loss', None)
        if train_loss is None:
            train_loss = next(
                (entry['loss'] for entry in reversed(state.log_history) if 'loss' in entry),
                None,
            )
        eval_loss = metrics['eval_loss']

        # Save eval metrics
        loss_values["steps"].append(state.global_step)
        if train_loss is not None:
            loss_values["train_loss"].append(train_loss)
        loss_values["eval_loss"].append(eval_loss)

# Register the loss-recording callback (the class itself is passed, not an instance)
trainer.add_callback(CustomTrainerCallback)

# Save the processor into the training output directory
processor.save_pretrained(training_args.output_dir)

In [None]:
# Train model
# Runs fine-tuning with the configured trainer (training_args sets max_steps=600,
# with evaluation, logging and checkpointing every 600 steps).
trainer.train()

In [None]:
# Save output and push to HF hub
# Metadata forwarded to `trainer.push_to_hub`.
# `dataset_args` previously named a "train" split, which this notebook never
# loads: training uses the "dev" split and evaluation uses the "test" split.
# It now names the evaluation split.
# NOTE(review): confirm "config: en" — no dataset config is passed to
# `load_dataset` in this notebook.
kwargs = {
    "dataset_tags": "sage-bergerson/edacc_processed",
    "dataset": "EdAcc",
    "dataset_args": "config: en, split: test",
    "language": "en",
    "model_name": "Whisper Large EdAcc V3",
    "finetuned_from": "openai/whisper-large",
    "tasks": "automatic-speech-recognition",
}

trainer.push_to_hub(**kwargs)