# Fine-tuning Whisper Small (Robust Version)

This notebook fine-tunes `openai/whisper-small` on the Minangkabau language.

### 🛠️ Fixes in this version:
1.  **WAV Conversion:** Converts all MP3s to WAV before processing to prevent kernel crashes.
2.  **Safety Filtering:** Removes audio clips that are 30 seconds or longer to prevent Out-Of-Memory errors.
3.  **Evaluation:** Includes WER metrics and WandB logging.


In [None]:
# 1. Install Dependencies
# ffmpeg is the command-line tool the MP3 -> WAV conversion step shells out to.
!apt-get update -y && apt-get install -y ffmpeg
# Python packages used by the rest of the notebook; -q keeps pip's output short.
!pip install -q datasets transformers torchaudio evaluate jiwer accelerate tensorboard scikit-learn wandb

In [None]:
import os
import torch
import pandas as pd
import tarfile
import wandb
import subprocess
import glob
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from datasets import Dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# --- Configuration ---
MODEL_NAME = "openai/whisper-small"
# Whisper ships no Minangkabau language token, and its tokenizer rejects
# unknown language names with "Unsupported language". Indonesian is used as
# the closest supported language token; fine-tuning then adapts it to
# Minangkabau. (The dataset itself is still filtered with the "min" code.)
LANGUAGE = "indonesian"
TASK = "transcribe"
SAMPLING_RATE = 16000  # Hz; used when casting the audio column and filtering by length

# Paths
DATA_DIR = Path("/workspace/data")
DATA_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_DIR = Path("/workspace/whisper-minang-finetuned")
LOG_DIR = Path("/workspace/logs")

# Login to WandB
wandb.login()

In [None]:
# 2. Download & Extract Data
urls = [
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/audio_train.tgz",
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/audio_test.tgz",
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/metadata_train.csv.gz",
    "https://huggingface.co/datasets/indonesian-nlp/librivox-indonesia/resolve/main/data/metadata_test.csv.gz"
]

for url in urls:
    filename = url.rsplit("/", 1)[-1]
    dest = DATA_DIR / filename
    if dest.exists():
        # Already downloaded by a previous run.
        continue

    print(f"Downloading {filename}...")
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer is never mistaken for a complete file next time.
    partial = dest.with_name(filename + ".part")
    # check=True makes a failed download stop the cell here instead of
    # surfacing later as a confusing tarfile/pandas error.
    subprocess.run(["wget", "-O", str(partial), url], check=True)
    partial.replace(dest)

print("Extracting archives...")
# Use tarfile's "data" extraction filter when this Python provides it: the
# archives come from the network, and the filter refuses path-traversal members.
extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

for archive in ["audio_train.tgz", "audio_test.tgz"]:
    archive_path = DATA_DIR / archive
    # One marker file per archive. A single shared "audio_train" directory check
    # would skip audio_test.tgz as soon as the training archive was unpacked.
    done_marker = DATA_DIR / f".{archive}.extracted"
    if done_marker.exists():
        continue
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=DATA_DIR, **extract_kwargs)
    # Written only after extractall returns, so a partial extraction is redone.
    done_marker.touch()
print("Extraction complete.")

In [None]:
# 3. CRITICAL FIX: Convert Audio to WAV (Prevents crashes)
AUDIO_ROOT = DATA_DIR / "audio_train" / "librivox-indonesia"
# NOTE(review): assumes audio_test.tgz unpacks to the mirror image of the
# training layout -- confirm against the extracted tree. A missing directory is
# simply skipped below.
TEST_AUDIO_ROOT = DATA_DIR / "audio_test" / "librivox-indonesia"
# Every tree whose MP3s must be converted. The test metadata is later filtered
# by WAV existence, so the test clips have to be converted as well.
AUDIO_ROOTS = [AUDIO_ROOT, TEST_AUDIO_ROOT]
WAV_DIR = DATA_DIR / "converted_wav"
WAV_DIR.mkdir(parents=True, exist_ok=True)


def _relative_to_audio_root(mp3_path):
    """Return mp3_path relative to whichever entry of AUDIO_ROOTS contains it."""
    for root in AUDIO_ROOTS:
        try:
            return mp3_path.relative_to(root)
        except ValueError:
            continue
    raise ValueError(f"{mp3_path} is not inside any known audio root")


def convert_to_wav(mp3_path):
    """Convert one MP3 to a mono WAV at SAMPLING_RATE under WAV_DIR.

    The output keeps the input's path relative to its audio root, with the
    suffix changed to ``.wav``. An existing output is reused.

    Args:
        mp3_path: Path (or string) of the source MP3.

    Returns:
        The output path as a string, or None if the conversion failed.
    """
    mp3_path = Path(mp3_path)
    partial_path = None
    try:
        output_path = WAV_DIR / _relative_to_audio_root(mp3_path).with_suffix(".wav")
        if output_path.exists():
            return str(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # ffmpeg writes to a temporary name that is renamed only on success, so
        # a crashed conversion never leaves a truncated file that the
        # exists() shortcut above would accept on the next run.
        partial_path = output_path.with_name(output_path.name + ".part")
        cmd = [
            "ffmpeg", "-y", "-v", "error",
            "-i", str(mp3_path),
            "-ac", "1",
            "-ar", str(SAMPLING_RATE),
            "-f", "wav",  # explicit format: the ".part" suffix hides the extension
            str(partial_path),
        ]
        subprocess.run(cmd, check=True)
        os.replace(partial_path, output_path)
        return str(output_path)
    except (OSError, ValueError, subprocess.CalledProcessError) as exc:
        # Best effort: report and carry on; the caller counts the failures.
        print(f"Failed to convert {mp3_path}: {exc}")
        if partial_path is not None and partial_path.exists():
            partial_path.unlink()
        return None


print("🔍 Scanning for MP3 files...")
mp3_files = [
    mp3
    for root in AUDIO_ROOTS
    if root.exists()
    for mp3 in root.rglob("*.mp3")
]
print(f"Found {len(mp3_files)} files. Converting to WAV (this may take a moment)...")

# Run conversion in parallel
with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
    results = list(tqdm(executor.map(convert_to_wav, mp3_files), total=len(mp3_files)))

# Surface failures instead of discarding the results: rows whose WAV is missing
# are silently dropped from the dataset later on.
failed = [src for src, wav in zip(mp3_files, results) if wav is None]
if failed:
    print(f"WARNING: {len(failed)} of {len(mp3_files)} files could not be converted (first: {failed[0]})")

print("✅ Audio conversion complete.")

In [None]:
# 4. Prepare Metadata & Dataset
train_meta_path = DATA_DIR / "metadata_train.csv.gz"
test_meta_path = DATA_DIR / "metadata_test.csv.gz"


def _read_minangkabau_rows(csv_path):
    """Read a metadata CSV and return a copy of the rows whose language is "min"."""
    frame = pd.read_csv(csv_path)
    is_minangkabau = frame["language"] == "min"
    return frame[is_minangkabau].copy()


# Filter for Minangkabau
df_train = _read_minangkabau_rows(train_meta_path)
df_test = _read_minangkabau_rows(test_meta_path)

# Update paths to point to the NEW WAV files
def get_wav_path(row):
    """Return the converted-WAV location (as a string) for one metadata row.

    The row's "path" value is re-rooted under WAV_DIR with its suffix
    replaced by ".wav".
    """
    relative_wav = Path(row["path"]).with_suffix(".wav")
    return str(WAV_DIR / relative_wav)

def _keep_rows_with_wav(frame):
    """Add an "audio" column of WAV paths and keep only rows whose file exists."""
    frame = frame.copy()
    # A list comprehension (rather than DataFrame.apply(axis=1)) also behaves
    # correctly when the frame is empty.
    frame["audio"] = [get_wav_path(row) for _, row in frame.iterrows()]
    # astype(bool) keeps the mask boolean even when there are no rows.
    has_wav = frame["audio"].map(os.path.exists).astype(bool)
    return frame[has_wav]


# Verify files exist
df_train = _keep_rows_with_wav(df_train)
df_test = _keep_rows_with_wav(df_test)

print(f"Training samples: {len(df_train)}")
print(f"Test samples: {len(df_test)}")

if df_train.empty:
    # Fail here with a clear message instead of deep inside train_test_split.
    raise RuntimeError(
        "No training audio found: check that the conversion step produced WAV "
        "files whose layout matches the 'path' column of the metadata."
    )
if df_test.empty:
    print("WARNING: no test audio found; the final test evaluation will have nothing to score.")

# Create Datasets (preserve_index=False avoids carrying the filtered pandas
# index along as an extra column)
train_dataset = Dataset.from_pandas(df_train, preserve_index=False)
test_dataset = Dataset.from_pandas(df_test, preserve_index=False)

# Cast to Audio (now using safe WAVs)
train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

# Train/Test Split -- seeded so the validation split is the same on every run
# and WER numbers stay comparable between runs.
train_test_split = train_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

In [None]:
# 5. Filter Long Audio Files (Safety Step)
def is_audio_in_length_range(batch, max_seconds=30):
    """Return True when the example's audio is strictly shorter than max_seconds.

    Args:
        batch: A single example with an "audio" entry that provides the
            decoded samples under "array" and the rate under "sampling_rate".
        max_seconds: Exclusive upper bound on the clip duration, in seconds.

    Returns:
        bool: True if the clip should be kept.
    """
    audio = batch["audio"]
    # Use the clip's own sampling rate rather than a module-level constant, so
    # the duration check stays correct for whatever rate the audio was decoded at.
    return len(audio["array"]) < max_seconds * audio["sampling_rate"]

print(f"Original training size: {len(train_dataset)}")
# Apply the same length filter to the training split, then the validation split.
train_dataset, eval_dataset = (
    split.filter(is_audio_in_length_range, num_proc=1)
    for split in (train_dataset, eval_dataset)
)
print(f"Filtered training size: {len(train_dataset)}")

In [None]:
# 6. Preprocessing
# The processor bundles the feature extractor (audio -> input features) and the
# tokenizer (text -> label ids) used by prepare_dataset.
processor = WhisperProcessor.from_pretrained(
    MODEL_NAME,
    task=TASK,
    language=LANGUAGE,
)

def prepare_dataset(batch):
    """Turn one raw example into model inputs.

    Adds "input_features" (computed by the processor's feature extractor from
    the decoded audio) and "labels" (token ids of the "sentence" text), then
    returns the same example dict.
    """
    waveform = batch["audio"]
    extracted = processor.feature_extractor(
        waveform["array"],
        sampling_rate=waveform["sampling_rate"],
    )
    # The extractor returns a batch; this call holds a single clip.
    batch["input_features"] = extracted.input_features[0]

    encoded_text = processor.tokenizer(batch["sentence"])
    batch["labels"] = encoded_text.input_ids
    return batch

# We use num_proc=1 to be absolutely safe, but with WAVs you could try os.cpu_count()
def _to_model_inputs(split):
    """Run prepare_dataset over a split, dropping all of its original columns."""
    return split.map(prepare_dataset, remove_columns=split.column_names, num_proc=1)


train_dataset = _to_model_inputs(train_dataset)
eval_dataset = _to_model_inputs(eval_dataset)
test_dataset_processed = _to_model_inputs(test_dataset)

In [None]:
# 7. Data Collator & Metrics
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed examples into one padded training batch.

    Audio features and label ids are padded separately (by the processor's
    feature extractor and tokenizer respectively). Label padding is replaced
    with -100, and a leading decoder start token is removed from the labels
    when every sequence in the batch begins with it.
    """

    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`.
    processor: Any
    # Id of the token that opens every label sequence. When None it is looked
    # up from the tokenizer (see _resolve_start_token_id).
    decoder_start_token_id: Union[int, None] = None

    def _resolve_start_token_id(self) -> int:
        """Return the token id to strip from the front of the labels."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        tokenizer = self.processor.tokenizer
        # Whisper labels open with <|startoftranscript|>, which is NOT the
        # tokenizer's bos token; comparing against bos never matched, so the
        # start token stayed in the labels.
        sot_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if sot_id is None or sot_id == tokenizer.unk_token_id:
            # Vocabulary without that token: keep the previous bos-based check.
            return tokenizer.bos_token_id
        return sot_id

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch from examples carrying "input_features" and "labels".

        Returns the padded feature batch with an added "labels" tensor.
        """
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # -100 marks padded positions in the labels.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Drop the start token only if every row has it; otherwise the model
        # would see it twice once the labels are shifted into decoder inputs.
        if (labels[:, 0] == self._resolve_start_token_id()).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# The evaluate metric id is "wer"; "jiwer" is only the pip package that backs
# it and cannot be loaded as a metric.
metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute word error rate for one evaluation pass.

    Args:
        pred: Prediction object exposing `.predictions` (generated token ids)
            and `.label_ids` (reference token ids, with -100 on ignored positions).

    Returns:
        dict: {"wer": <value returned by the WER metric>}.
    """
    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some configurations return (token_ids, extras); only the ids are decoded.
        pred_ids = pred_ids[0]

    pad_id = processor.tokenizer.pad_token_id
    # Work on copies so the arrays held by the caller are not modified.
    pred_ids = pred_ids.copy()
    label_ids = pred.label_ids.copy()
    # -100 is not a valid token id; swap it for the pad token, which
    # skip_special_tokens then removes during decoding.
    pred_ids[pred_ids == -100] = pad_id
    label_ids[label_ids == -100] = pad_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [None]:
# 8. Training Setup
import inspect

model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
# The generation cache is switched off because gradient checkpointing is
# enabled in the training arguments below.
model.config.use_cache = False
model.generation_config.language = LANGUAGE
model.generation_config.task = TASK

# The install cell does not pin transformers, and newer releases renamed
# `evaluation_strategy` to `eval_strategy` (and the Trainer's `tokenizer`
# argument to `processing_class`). Pick whichever name the installed version
# accepts so the cell runs on both.
_args_params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters
_eval_strategy_key = "eval_strategy" if "eval_strategy" in _args_params else "evaluation_strategy"

training_args = Seq2SeqTrainingArguments(
    output_dir=str(OUTPUT_DIR),  # plain string: the arguments get serialised for logging
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    run_name="whisper-minangkabau-fixed",
    **{_eval_strategy_key: "steps"},
)

_trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
_processor_key = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_processor_key: processor.feature_extractor},
)

In [None]:
# 9. Train
# Runs fine-tuning with the trainer built in the setup cell (evaluation and
# checkpointing every 500 steps, per the training arguments).
trainer.train()

In [None]:
# 10. Save & Finish
wandb.finish()
trainer.save_model(OUTPUT_DIR / "final_model")
processor.save_pretrained(OUTPUT_DIR / "final_model")
print("Training Complete and Model Saved!")

In [None]:
# 11. Final Evaluation
print("Evaluating on Test Set...")
test_metrics = trainer.evaluate(test_dataset_processed)
print(f"Test WER: {test_metrics['eval_wer']:.2%}")