# Training a Whisper model for doing ASR on Tagalog Bible data
Some of this code for preparing the data and models was taken or adapted from https://huggingface.co/blog/fine-tune-whisper. All experiments were run on Google Colab, so it hasn't been tested on a local machine with a fresh Python environment.

# Install requirements

In [None]:
# Install the libraries used below: transformers, datasets, and evaluate are
# imported directly; jiwer is installed alongside them (presumably as the
# backend for the WER/CER metrics -- confirm against the `evaluate` docs).
!pip install transformers datasets evaluate jiwer
# -U upgrades accelerate if an older version is already installed.
!pip install -U accelerate

# Imports

In [None]:
import csv
from pathlib import Path

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import evaluate
import torch

from transformers import (
    WhisperProcessor,
    WhisperTokenizer,
    WhisperForConditionalGeneration,
    WhisperFeatureExtractor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from datasets import load_dataset, concatenate_datasets

# Load data

In [None]:
# Create one directory per data split in the current working directory.
!mkdir train dev test

**Make sure to upload the data to the `train/`, `dev/`, and `test/` directories before moving forward.**

## Generate metadata
Create a `metadata.csv` file in each data directory used to load the audio and text data. This contains the name of each `.wav` file and the content of the corresponding `.txt` file.

In [None]:
def generate_metadata(data_dir: Path):
    """Write a ``metadata.csv`` that pairs each ``.wav`` file with its transcript.

    For every ``*.wav`` file directly inside ``data_dir`` that has a sibling
    ``.txt`` file with the same stem, one CSV row is written containing the
    audio file name and the transcript text. Audio files without a matching
    ``.txt`` file are skipped.

    Args:
        data_dir: Directory holding the ``.wav``/``.txt`` pairs. A plain string
            path is accepted as well as a ``Path``.

    Returns:
        None. The result is the ``metadata.csv`` file created (or overwritten)
        inside ``data_dir``.
    """
    data_dir = Path(data_dir)

    # Write with an explicit encoding: transcripts are read as UTF-8, so the
    # CSV must not fall back to a platform-dependent default encoding.
    with open(data_dir / "metadata.csv", "w", newline="", encoding="utf8") as f:
        writer = csv.writer(f)
        writer.writerow(["file_name", "transcription"])

        # Sort so the row order does not depend on filesystem listing order.
        for audio_file in sorted(data_dir.glob("*.wav")):
            txt_file = audio_file.with_suffix(".txt")
            if not txt_file.is_file():
                continue

            # Strip surrounding whitespace (typically a trailing newline) so it
            # does not become part of the transcription text.
            transcription = txt_file.read_text(encoding="utf8").strip()
            writer.writerow([audio_file.name, transcription])

In [None]:
# Build a metadata.csv inside each split directory.
generate_metadata(Path("./train"))
generate_metadata(Path("./dev"))
generate_metadata(Path("./test"))

## Load dataset from splits

In [None]:
# Load the audio + metadata.csv pairs from the current directory with the
# "audiofolder" loader. Later cells index the result with "train",
# "validation", and "test", so the dev/ directory is expected to surface as
# the "validation" split -- check the printed summary below to confirm.
ds = load_dataset("audiofolder", data_dir="./")
ds

# Load pretrained models

In [None]:
# Select which pretrained Whisper checkpoint and which language to use.
# Available sizes: tiny, base, small, medium, large, large-v2
whisper_model = "openai/whisper-base"
language = "tagalog"

In [None]:
# Load every component from the same checkpoint: the audio feature extractor,
# the text tokenizer, the processor (whose `.feature_extractor` and
# `.tokenizer` attributes are used by the data collator), and the model itself.
feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_model, task="transcribe", language=language)
tokenizer = WhisperTokenizer.from_pretrained(whisper_model, task="transcribe", language=language)
processor = WhisperProcessor.from_pretrained(whisper_model, task="transcribe", language=language)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model)

# Clear the forced decoder ids and the suppressed-token list stored in the
# checkpoint's config.
# NOTE(review): presumably so generation is not constrained by the checkpoint
# defaults -- confirm this is still the recommended setup for the installed
# transformers version.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Prepare data

In [None]:
def prepare_dataset(batch):
    """Add model inputs and label ids to a single dataset example.

    Reads the ``"audio"`` entry (its ``"array"`` and ``"sampling_rate"``) and
    the ``"transcription"`` entry of ``batch`` and stores three new entries:

    * ``"input_features"``: first element of the features returned by the
      module-level ``feature_extractor`` for the audio array.
    * ``"input_length"``: number of samples divided by the sampling rate.
    * ``"labels"``: token ids returned by the module-level ``tokenizer`` for
      the transcription.

    The example is modified in place and also returned.
    """
    clip = batch["audio"]
    samples, rate = clip["array"], clip["sampling_rate"]

    extracted = feature_extractor(samples, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]

    # Sample count over sampling rate gives the clip duration.
    batch["input_length"] = len(samples) / rate

    encoded = tokenizer(batch["transcription"])
    batch["labels"] = encoded.input_ids

    return batch

In [None]:
# Apply prepare_dataset to every example and drop the columns that exist in
# the "train" split, leaving only the entries prepare_dataset adds.
ds = ds.map(prepare_dataset, remove_columns=ds.column_names["train"])

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate prepared examples into a single padded batch of tensors.

    Audio inputs are padded by ``processor.feature_extractor.pad`` and label
    ids by ``processor.tokenizer.pad``. Padded label positions are replaced by
    ``-100``. If every label sequence in the batch starts with the tokenizer's
    ``bos_token_id``, that first column is removed.
    """

    # Object exposing `.feature_extractor` and `.tokenizer`, each with a `pad` method.
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio side: hand only the input features to the feature extractor.
        extractor = self.processor.feature_extractor
        audio_items = [{"input_features": item["input_features"]} for item in features]
        batch = extractor.pad(audio_items, return_tensors="pt")

        # Text side: pad the label ids, then overwrite the padded positions
        # (attention mask != 1) with -100.
        tok = self.processor.tokenizer
        token_items = [{"input_ids": item["labels"]} for item in features]
        padded = tok.pad(token_items, return_tensors="pt")
        is_padding = padded.attention_mask.ne(1)
        labels = padded["input_ids"].masked_fill(is_padding, -100)

        # Drop the leading column only when it is the BOS id in every row.
        starts_with_bos = labels[:, 0] == tok.bos_token_id
        if starts_with_bos.all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

In [None]:
# Collator instance that the trainer uses to assemble padded batches.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Prepare evaluation metrics
We will use Word Error Rate (WER) and Character Error Rate (CER) to evaluate the model.

In [None]:
# Load the word-error-rate and character-error-rate metrics used by compute_metrics.
metric_wer = evaluate.load("wer")
metric_cer = evaluate.load("cer")

In [1]:
def compute_metrics(pred):
    """Compute WER and CER (as percentages) for a batch of predictions.

    Args:
        pred: Object with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with ``-100`` at padded
            positions). Both are assumed to be NumPy arrays, as supplied by
            the trainer -- TODO confirm for the installed transformers version.

    Returns:
        dict with keys ``"wer"`` and ``"cer"``, each 100 times the value
        returned by the corresponding metric.
    """
    pad_id = tokenizer.pad_token_id

    # Work on copies so the arrays owned by the caller are not modified.
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = pad_id

    # Defensive: -100 is not a valid token id, so replace it before decoding
    # in case the predictions were also padded with it.
    pred_ids = pred.predictions.copy()
    pred_ids[pred_ids == -100] = pad_id

    # Special tokens (including the padding inserted above) are dropped here.
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)
    cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}

# Create model for training
Generate a `Trainer` object using the Whisper models and evaluation metrics defined above.

In [34]:
# Hyperparameters and run configuration for fine-tuning. Checkpoints go to
# ./result; evaluation and checkpointing both happen once per epoch, and the
# checkpoint with the lowest WER is reloaded when training finishes.
training_args = Seq2SeqTrainingArguments(
    output_dir="./result",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=10,
    num_train_epochs=30,  # delete if a step budget (max_steps) is used instead
    gradient_checkpointing=True,
    fp16=True,  # True only if training on GPU, it won't work on CPU
    evaluation_strategy="epoch",  # steps, epoch, no - must match `save_strategy`
    save_strategy="epoch",  # steps, epoch, no - must match `evaluation_strategy`
    per_device_eval_batch_size=8,
    predict_with_generate=True,  # evaluate on generated token sequences
    generation_max_length=225,
    logging_steps=1,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",  # key returned by compute_metrics
    greater_is_better=False,  # lower WER is better
    push_to_hub=False
)

In [35]:
# Assemble the trainer from the model, the prepared train/validation splits,
# the padding collator, and the WER/CER metric function.
# NOTE(review): the feature extractor (not the text tokenizer) is passed as
# `tokenizer` -- presumably so it is saved together with checkpoints; confirm
# against the Trainer documentation for the installed transformers version.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor
)

## Zero-shot learning
First, evaluate the model on the dev and test sets to get a baseline for the pretrained Whisper model.

In [None]:
# Baseline: score the pretrained model on the validation (dev) split before any fine-tuning.
trainer.evaluate(eval_dataset=ds["validation"])

In [None]:
# Baseline: score the pretrained model on the test split before any fine-tuning.
trainer.evaluate(eval_dataset=ds["test"])

## Fine-tuning
Use the trainer from above to fine-tune the Whisper model on this specific data.

In [None]:
# Run fine-tuning with the arguments configured in training_args.
trainer.train()

## Evaluation
Evaluate the performance of the best fine-tuned model on the dev and test sets. The dev score will match the best epoch's score from fine-tuning, because `load_best_model_at_end=True` reloads that checkpoint when training ends.

In [None]:
# Score the model held by the trainer after fine-tuning on the validation (dev) split.
trainer.evaluate(eval_dataset=ds["validation"])

In [None]:
# Score the model held by the trainer after fine-tuning on the test split.
trainer.evaluate(eval_dataset=ds["test"])