# Fine-tuning Whisper Small for Minangkabau Language

This notebook fine-tunes the OpenAI Whisper Small model on the Minangkabau subset of the `indonesian-nlp/librivox-indonesia` dataset. 

**Improvements in this version:**
- **Efficient Data Loading:** Uses Hugging Face `datasets.Audio` feature to handle audio files directly without manual conversion to WAV.
- **Evaluation Metrics:** Tracks Word Error Rate (WER) during training to select the best model.
- **Optimized Training:** Adjusted hyperparameters for better convergence on low-resource data.
- **RunPod Ready:** Configured to use `/workspace` for persistent storage.
- **WandB Integration:** Tracks experiments using Weights & Biases.

In [None]:
# Install dependencies
!apt-get update -y && apt-get install -y ffmpeg
!pip install -q datasets transformers torchaudio evaluate jiwer accelerate tensorboard scikit-learn wandb

In [None]:
import os
import tarfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Union

import evaluate
import pandas as pd
import torch
import wandb
from datasets import Audio, Dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
from transformers.integrations import WandbCallback

# `load_metric` was deprecated in favor of the `evaluate` package and is no
# longer available in recent `datasets` releases. Keep the name importable on
# older installs without breaking this cell on newer ones.
try:
    from datasets import load_metric
except ImportError:
    load_metric = None

# --- Configuration ---
MODEL_NAME = "openai/whisper-small"
# Whisper has no dedicated Minangkabau language token, and the tokenizer /
# generation config reject unknown language names ("Unsupported language").
# Use the closest supported language as the decoding prompt; the fine-tuning
# data itself is still the Minangkabau subset (metadata language code "min").
LANGUAGE = "indonesian"
TASK = "transcribe"
SAMPLING_RATE = 16000

# RunPod typically uses /workspace for persistent storage
DATA_DIR = Path("/workspace/data")
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoints / final model and training logs
OUTPUT_DIR = Path("/workspace/whisper-minang-finetuned")
LOG_DIR = Path("/workspace/logs")

In [None]:
# Login to WandB so that runs created later in this notebook can be tracked.
# You will be prompted to enter your API key; do not paste the key into the
# notebook itself.
wandb.login()

## 1. Data Preparation
We download the dataset and extract it. We will use the `datasets` library to handle audio loading on-the-fly.

In [None]:
# Download files
urls = [
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/audio_train.tgz",
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/audio_test.tgz",
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/metadata_train.csv.gz",
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/metadata_test.csv.gz"
]

for url in urls:
    filename = url.split("/")[-1]
    dest = DATA_DIR / filename
    if not dest.exists():
        print(f"Downloading {filename}...")
        !wget -nc {url} -P {DATA_DIR}
    else:
        print(f"{filename} already exists.")

# Extract Audio
for archive in ["audio_train.tgz", "audio_test.tgz"]:
    archive_path = DATA_DIR / archive
    extract_path = DATA_DIR / archive.replace(".tgz", "")
    # Check if extraction folder exists and is not empty as a rough check
    if not extract_path.exists():
        print(f"Extracting {archive}...")
        with tarfile.open(archive_path, "r:gz") as tar:
            tar.extractall(path=DATA_DIR)
    else:
        print(f"{archive} seems to be already extracted.")

In [None]:
# Load Metadata
train_meta_path = DATA_DIR / "metadata_train.csv.gz"
test_meta_path = DATA_DIR / "metadata_test.csv.gz"

df_train = pd.read_csv(train_meta_path)
df_test = pd.read_csv(test_meta_path)

# Filter for Minangkabau
df_train = df_train[df_train["language"] == "min"].copy()
df_test = df_test[df_test["language"] == "min"].copy()

# Construct full paths
def get_full_path(row, split_dir):
    """Return the absolute audio path (as a string) for one metadata row.

    Args:
        row: A metadata row exposing a "path" entry (path relative to the
            "librivox-indonesia" folder).
        split_dir: Name of the extracted archive folder under DATA_DIR,
            e.g. "audio_train" or "audio_test".
    """
    return str(DATA_DIR / split_dir / "librivox-indonesia" / row["path"])

# Each split must point at its own extracted archive folder: the test clips
# come from audio_test.tgz, not audio_train.tgz.
df_train["audio"] = df_train.apply(get_full_path, axis=1, split_dir="audio_train")
df_test["audio"] = df_test.apply(get_full_path, axis=1, split_dir="audio_test")

# Verify a file exists
if len(df_train) > 0:
    print(f"Checking file: {df_train.iloc[0]['audio']}")
    print(f"Exists: {os.path.exists(df_train.iloc[0]['audio'])}")

# Same sanity check for the test split, so a wrong folder is caught here
# rather than during audio decoding.
if len(df_test) > 0:
    print(f"Checking test file: {df_test.iloc[0]['audio']}")
    print(f"Exists: {os.path.exists(df_test.iloc[0]['audio'])}")

print(f"Train samples: {len(df_train)}")
print(f"Test samples: {len(df_test)}")

In [None]:
# Create Hugging Face Datasets
# preserve_index=False keeps the (filtered, non-contiguous) pandas index from
# being materialized as an extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(df_train, preserve_index=False)
test_dataset = Dataset.from_pandas(df_test, preserve_index=False)

# Cast audio column to Audio feature
# This automatically handles loading and resampling to 16kHz
train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

# Remove unnecessary columns
train_dataset = train_dataset.remove_columns(["path", "language", "reader"])
test_dataset = test_dataset.remove_columns(["path", "language", "reader"])

# Split train into train/val
# A fixed seed keeps the validation set identical across kernel restarts, so
# WER values from different runs are comparable.
train_test_split = train_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

print(train_dataset)

## 2. Processing Pipeline
We prepare the processor and the data processing function.

In [None]:
processor = WhisperProcessor.from_pretrained(MODEL_NAME, language=LANGUAGE, task=TASK)

def prepare_dataset(batch):
    """Turn one dataset example into model inputs.

    Reads the decoded audio from ``batch["audio"]`` and the transcript from
    ``batch["sentence"]``; writes the feature extractor's output under
    ``"input_features"`` and the tokenized transcript under ``"labels"``.
    Returns the same ``batch`` object.
    """
    # Audio arrives already decoded at the sampling rate set on the Audio column.
    audio = batch["audio"]
    waveform, rate = audio["array"], audio["sampling_rate"]

    # Log-Mel input features computed from the waveform.
    extracted = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]

    # Target transcript encoded as label ids.
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    return batch

def _to_model_inputs(dataset):
    """Apply `prepare_dataset` to every example, dropping all original columns."""
    # One worker per CPU core for faster processing.
    return dataset.map(
        prepare_dataset,
        remove_columns=dataset.column_names,
        num_proc=os.cpu_count(),
    )

# Apply preprocessing to each split.
train_dataset = _to_model_inputs(train_dataset)
eval_dataset = _to_model_inputs(eval_dataset)
test_dataset_processed = _to_model_inputs(test_dataset)

## 3. Training Setup
We define the data collator, metrics, and training arguments.

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed examples into a padded batch.

    Audio features and label ids are padded separately (by the processor's
    feature extractor and tokenizer respectively) because they have different
    lengths and padding rules.

    Attributes:
        processor: Object exposing ``feature_extractor.pad``, ``tokenizer.pad``
            and ``tokenizer.bos_token_id``.
    """

    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build one batch from ``features``.

        Each feature must carry ``"input_features"`` and ``"labels"``. The
        returned batch is the feature extractor's padded output with a
        ``"labels"`` tensor added, in which padded positions are set to -100.
        """
        audio_items = []
        label_items = []
        for example in features:
            audio_items.append({"input_features": example["input_features"]})
            label_items.append({"input_ids": example["labels"]})

        # Audio side: let the feature extractor return torch tensors.
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Label side: pad the token id sequences to the longest one in the batch.
        padded_labels = self.processor.tokenizer.pad(label_items, return_tensors="pt")

        # Positions outside the attention mask are padding; mark them with
        # -100 so they are ignored when computing the loss.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # If every sequence already starts with the BOS token from the earlier
        # tokenization step, drop it here since it is appended again later.
        bos_id = self.processor.tokenizer.bos_token_id
        starts_with_bos = (labels[:, 0] == bos_id).all().cpu().item()
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator instance bound to the shared `processor`; handed to the trainer as `data_collator`.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [None]:
# "wer" is the metric id in the `evaluate` library; `jiwer` is only the
# backend package it depends on, not a loadable metric name.
metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute Word Error Rate for a batch of generated predictions.

    Args:
        pred: Prediction object exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids, with -100 at ignored
            positions).

    Returns:
        ``{"wer": value}`` where ``value`` is the raw result of the WER metric
        (a fraction, not multiplied by 100).
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id so the ids can be decoded
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # decode both sides to text, dropping special tokens
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [None]:
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

# Disable cache during training
model.config.use_cache = False

# Set language and task for generation
model.generation_config.language = LANGUAGE
model.generation_config.task = TASK
# Clear the legacy forced decoder prompt shipped with the checkpoint so that
# the language/task set above are what drive the decoder prompt.
model.generation_config.forced_decoder_ids = None

# Training Arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=str(OUTPUT_DIR),
    logging_dir=str(LOG_DIR),  # keep TensorBoard logs in the configured log folder
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,  # Train for longer
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard", "wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=False,
    run_name="whisper-small-minangkabau-finetune" # Name for WandB run
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

## 4. Training
Start the training process.

In [None]:
# Run fine-tuning with the `trainer` configured in the previous section.
trainer.train()

In [None]:
# Finish WandB run
# NOTE(review): this closes the run before the test-set evaluation cell below
# executes, so metrics produced there are not recorded in this run.
wandb.finish()

In [None]:
# Write the trained model and the processor side by side into one folder so
# that folder can be loaded on its own later.
final_model_dir = OUTPUT_DIR / "final_model"
trainer.save_model(final_model_dir)
processor.save_pretrained(final_model_dir)

## 5. Evaluation on Test Set
Evaluate the model on the held-out test set.

In [None]:
print("Evaluating on Test Set...")
# The W&B run was closed by `wandb.finish()` in an earlier cell. The trainer's
# W&B callback would still try to log the evaluation metrics to that closed
# run and raise, so detach it when no run is active.
if wandb.run is None:
    trainer.remove_callback(WandbCallback)
test_metrics = trainer.evaluate(test_dataset_processed)
print(f"Test WER: {test_metrics['eval_wer']:.2%}")

In [None]:
# Visual Check
from transformers import pipeline

# Load the saved model; fall back to CPU (-1) when no GPU is visible.
pipe = pipeline(
    "automatic-speech-recognition",
    model=str(OUTPUT_DIR / "final_model"),
    device=0 if torch.cuda.is_available() else -1,
)

# Take a few samples from the original test dataset (before processing).
# Cap at the dataset size so a small test split does not raise on select().
num_samples = min(5, len(test_dataset))
samples = test_dataset.select(range(num_samples))

for i, sample in enumerate(samples):
    audio = sample["audio"]
    # Pass the sampling rate alongside the waveform so the pipeline does not
    # have to assume the array already matches the model's expected rate.
    prediction = pipe(
        {"raw": audio["array"], "sampling_rate": audio["sampling_rate"]},
        generate_kwargs={"language": LANGUAGE},
    )["text"]
    print(f"\nSample {i+1}:")
    print(f"Reference: {sample['sentence']}")
    print(f"Prediction: {prediction}")