# LahStats - LoRA Fine-tuning for Singlish ASR

This notebook fine-tunes **MERaLiON-2-10B-ASR** using LoRA adapters for better Singlish word recognition.

## What This Does
- Loads MERaLiON-2-10B-ASR with 8-bit quantization (fits on Colab T4/A100)
- Applies LoRA adapters to attention layers (~1% trainable params)
- Trains on your team's Singlish recordings
- Evaluates with Word Error Rate (WER)
- Saves lightweight adapter (~50MB) for easy deployment

## Expected Results
- Training time: 2-3 hours on T4, 30-60 min on A100
- Expected improvement: roughly 5-15% better recognition of Singlish words (shows up as a lower WER on the validation set)
- Output: LoRA adapter weights (~50MB)

## Data Format
Place your data in Google Drive:
```
lahstats_data/
  audio_001.wav
  audio_002.wav
  transcripts.json  # {"audio_001.wav": "walao eh why like that sia", ...}
```

In [None]:
# Cell 1: Install Dependencies
!pip install -q \
    peft>=0.7.0 \
    transformers>=4.36.0 \
    bitsandbytes>=0.41.0 \
    accelerate>=0.25.0 \
    datasets>=2.14.0 \
    librosa>=0.10.0 \
    soundfile>=0.12.0 \
    evaluate>=0.4.0 \
    jiwer>=3.0.0 \
    torch \
    torchaudio

print("Installation complete!")

In [None]:
# Cell 2: Verify GPU
import torch

# Report the first CUDA device, or explain how to enable one in Colab.
if not torch.cuda.is_available():
    print("WARNING: No GPU detected!")
    print("Go to Runtime > Change runtime type > GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    gpu_memory = _total_bytes / 1024**3  # bytes -> GiB
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")

In [None]:
# Cell 3: Mount Google Drive & Load Data
from google.colab import drive
drive.mount('/content/drive')

import os
import json

# === CHANGE THIS PATH TO YOUR DATA FOLDER ===
DATA_ROOT = "/content/drive/MyDrive/lahstats_data"

# Fail fast: later cells read `transcripts`, so stopping here with a clear
# message is better than a NameError several cells (and a large model
# download) later.
if not os.path.isdir(DATA_ROOT):
    raise FileNotFoundError(
        f"Data folder not found: {DATA_ROOT}. "
        "Update DATA_ROOT to point to your data folder"
    )

print(f"Found data folder: {DATA_ROOT}")
files = os.listdir(DATA_ROOT)
# Compare extensions case-insensitively so names like "CLIP.WAV" are counted.
audio_files = [f for f in files if f.lower().endswith(('.wav', '.mp3', '.m4a'))]
print(f"Audio files: {len(audio_files)}")

# Load transcripts
transcript_path = os.path.join(DATA_ROOT, "transcripts.json")
if not os.path.isfile(transcript_path):
    raise FileNotFoundError(
        f"transcripts.json not found at {transcript_path}. "
        "Create a JSON file mapping audio filenames to transcriptions"
    )

# Explicit UTF-8 so decoding does not depend on the platform's default locale.
with open(transcript_path, encoding="utf-8") as f:
    transcripts = json.load(f)
print(f"Loaded {len(transcripts)} transcripts")

In [None]:
# Cell 4: Configuration
# -----------------------------------------------------------------------------
# Hyperparameters and paths read by the later cells.
# Values follow the project's research notes:
#   .planning/phases/lora-finetuning/RESEARCH.md
# -----------------------------------------------------------------------------

MODEL_NAME = "MERaLiON/MERaLiON-2-10B-ASR"

# --- LoRA adapter (research notes recommend r=32, alpha=64 for ASR) ---
LORA_R = 32  # adapter rank; notes call this optimal for ~1000 samples
LORA_ALPHA = 64  # scaling factor, 2 x LORA_R
LORA_DROPOUT = 0.05  # light regularization
TARGET_MODULES = ["q_proj", "v_proj"]  # notes: minimum effective set for ASR

# --- Optimisation ---
BATCH_SIZE = 4  # drop to 2 on out-of-memory errors
GRADIENT_ACCUMULATION = 4  # effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION = 16
LEARNING_RATE = 1e-4  # notes: higher than full fine-tuning (1e-5)
NUM_EPOCHS = 3  # notes: more epochs risk overfitting on small datasets
WARMUP_STEPS = 50  # notes say ~10% of total steps; depends on dataset size

# --- Output locations on Google Drive ---
OUTPUT_DIR = "/content/drive/MyDrive/lahstats_lora_checkpoints"
FINAL_ADAPTER_DIR = "/content/drive/MyDrive/lahstats_lora_adapter"

# --- Evaluation / checkpoint cadence ---
EVAL_STEPS = 100
SAVE_STEPS = 100
VAL_SPLIT = 0.15  # fraction of samples held out for validation

# One print call; sep="\n" yields the same four-plus-one lines as separate prints.
print(
    "Config loaded:",
    f"  Model: {MODEL_NAME}",
    f"  LoRA: r={LORA_R}, alpha={LORA_ALPHA}",
    f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}",
    f"  Learning rate: {LEARNING_RATE}",
    sep="\n",
)

In [None]:
# Cell 5: Load Model with 8-bit Quantization
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

print(
    "Loading MERaLiON with 8-bit quantization...",
    "This may take a few minutes on first run (downloading ~20GB)",
    sep="\n",
)

# Request 8-bit weight loading (original note: cuts memory by roughly half).
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# trust_remote_code=True allows the model repository's own code to be executed
# when building the model and the processor.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
)
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

_param_billions = model.num_parameters() / 1e9
print(f"Model loaded: {_param_billions:.1f}B parameters")

In [None]:
# Cell 6: Apply LoRA Adapters
# Get the quantized model ready for training before any adapter is attached.
model = prepare_model_for_kbit_training(model)
# Original note: the cache must be off for gradient checkpointing.
model.config.use_cache = False

# Adapter hyperparameters all come from the configuration cell.
lora_config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",  # original note: for encoder-decoder ASR models
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

# Wrap the base model so that only the adapter weights are trainable,
# then report the trainable share (original note expects ~1%).
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Cell 7: Prepare Dataset
from datasets import Dataset, Audio
import os

def load_audio_dataset(data_root, transcripts):
    """
    Load audio files and transcripts into a HuggingFace Dataset.

    Args:
        data_root: Directory that contains the audio files.
        transcripts: Mapping of audio filename to transcript text, e.g.
            {"audio_001.wav": "walao eh the food damn shiok sia", ...}

    Returns:
        A Dataset with an "audio" column (cast to Audio at 16 kHz) and a
        "transcript" column (lower-cased, surrounding whitespace stripped).

    Raises:
        ValueError: If no transcript entry points at an existing audio file.
    """
    data = []
    missing = []

    for audio_file, transcript in transcripts.items():
        audio_path = os.path.join(data_root, audio_file)
        # isfile (not exists) so a directory with a matching name is not
        # mistaken for an audio clip.
        if os.path.isfile(audio_path):
            data.append({
                "audio": audio_path,
                "transcript": transcript.lower().strip()
            })
        else:
            missing.append(audio_file)

    if missing:
        print(f"Warning: {len(missing)} audio files not found")
        print(f"  First few: {missing[:5]}")

    # Stop with a clear message instead of letting an empty dataset fail
    # later in cast_column / train_test_split with an unrelated-looking error.
    if not data:
        raise ValueError(
            f"No usable audio/transcript pairs found under {data_root}; "
            "check DATA_ROOT and the filenames in transcripts.json"
        )

    # Create dataset and cast audio column
    dataset = Dataset.from_list(data)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    return dataset

# Build the full dataset from the transcripts loaded earlier.
full_dataset = load_audio_dataset(DATA_ROOT, transcripts)
print(f"Total samples: {len(full_dataset)}")

# Hold out VAL_SPLIT of the samples; the fixed seed keeps the split repeatable.
dataset_split = full_dataset.train_test_split(test_size=VAL_SPLIT, seed=42)
train_dataset, val_dataset = dataset_split["train"], dataset_split["test"]

print(f"Train: {len(train_dataset)}, Validation: {len(val_dataset)}")

In [None]:
# Cell 8: Preprocess Audio
def prepare_dataset(batch):
    """
    Turn one raw example into model-ready fields.

    Reads batch["audio"] (its "array" and "sampling_rate") and
    batch["transcript"], then adds:
    - "input_features": first item of the feature extractor's output
      (log-Mel spectrogram per the original notes)
    - "labels": token ids of the transcript

    Returns the same mapping with the two fields added.
    """
    clip = batch["audio"]

    # Audio -> features via the processor's feature extractor
    extracted = processor.feature_extractor(
        clip["array"],
        sampling_rate=clip["sampling_rate"],
    )
    batch["input_features"] = extracted.input_features[0]

    # Transcript -> token ids via the processor's tokenizer
    encoded = processor.tokenizer(batch["transcript"])
    batch["labels"] = encoded.input_ids

    return batch

# Run prepare_dataset over both splits, dropping the raw columns afterwards.
# num_proc=1 is kept for Colab stability (original note).
print("Preprocessing training data...")
_train_columns = train_dataset.column_names
train_dataset = train_dataset.map(
    prepare_dataset,
    num_proc=1,
    remove_columns=_train_columns,
)

print("Preprocessing validation data...")
_val_columns = val_dataset.column_names
val_dataset = val_dataset.map(
    prepare_dataset,
    num_proc=1,
    remove_columns=_val_columns,
)

print("Preprocessing complete!")

In [None]:
# Cell 9: Data Collator
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Collate preprocessed speech examples into one padded batch.

    Audio features are padded by the processor's feature extractor and label
    ids by its tokenizer. Label positions that are padding are set to -100 so
    the loss calculation ignores them.
    """
    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`.
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio side: pad only the "input_features" field of each example.
        audio_items = [{"input_features": example["input_features"]} for example in features]
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Text side: the tokenizer pads under the "input_ids" key.
        text_items = [{"input_ids": example["labels"]} for example in features]
        padded_text = self.processor.tokenizer.pad(text_items, return_tensors="pt")

        # Wherever the attention mask is not 1 the position is padding.
        is_padding = padded_text.attention_mask != 1
        batch["labels"] = padded_text["input_ids"].masked_fill(is_padding, -100)
        return batch

# Single collator instance built around the loaded processor.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)
print("Data collator ready!")

In [None]:
# Cell 10: Evaluation Metrics
import evaluate

# Load Word Error Rate metric
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    """
    Compute Word Error Rate (WER), as a percentage, for one evaluation pass.

    Args:
        pred: Prediction object exposing `predictions` (generated token ids,
            or a tuple whose first item is those ids) and `label_ids`.

    Returns:
        {"wer": <percentage>}. Lower is better; 0 means a perfect
        transcription.
    """
    import numpy as np  # local import: numpy is not imported before this cell

    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some models return extra outputs; the token ids come first.
        pred_ids = pred_ids[0]

    # -100 marks ignored positions and is not a real token id, so it must be
    # swapped for the pad token before decoding. np.where builds new arrays,
    # leaving the arrays owned by `pred` untouched (the previous code wrote
    # into pred.label_ids in place). Predictions get the same treatment
    # because gathered generations can also be padded with -100.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    # Decode predictions and labels
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Compute WER (multiply by 100 for percentage)
    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

print("Metrics ready!")

In [None]:
# Cell 11: Training Setup
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import os

# Make sure the checkpoint directory exists before the trainer writes to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Arguments are grouped by concern, then merged into a single call.
_batching = dict(
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
)
_schedule = dict(
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    num_train_epochs=NUM_EPOCHS,
)
_memory = dict(
    fp16=True,  # mixed precision training
    gradient_checkpointing=True,  # trade compute for memory
)
_evaluation = dict(
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    # Evaluation decodes with generate() so WER is computed on real output.
    predict_with_generate=True,
    generation_max_length=225,
)
_checkpointing = dict(
    save_steps=SAVE_STEPS,
    save_total_limit=3,  # keep only the last 3 checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
)
_logging = dict(
    logging_steps=25,
    report_to="tensorboard",
)
_misc = dict(
    remove_unused_columns=False,
    label_names=["labels"],
    seed=42,
)

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    **_batching,
    **_schedule,
    **_memory,
    **_evaluation,
    **_checkpointing,
    **_logging,
    **_misc,
)

# NOTE(review): newer transformers releases rename the `tokenizer` argument
# to `processing_class` - confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.feature_extractor,
)

print("Trainer ready!")
print(f"Checkpoints will be saved to: {OUTPUT_DIR}")

In [None]:
# Cell 12: TRAIN!
_banner = "=" * 60

# Run summary, printed line by line before training starts.
for _line in (
    _banner,
    "STARTING TRAINING",
    _banner,
    f"Training samples: {len(train_dataset)}",
    f"Validation samples: {len(val_dataset)}",
    f"Epochs: {NUM_EPOCHS}",
    f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}",
    _banner,
    "",
):
    print(_line)

# Run the training loop configured in the trainer.
trainer.train()

for _line in ("", _banner, "TRAINING COMPLETE!", _banner):
    print(_line)

In [None]:
# Cell 13: Save Final Adapter
import os

os.makedirs(FINAL_ADAPTER_DIR, exist_ok=True)

# Adapter weights (~50MB per the original note) and the processor files go
# into the same folder so that inference can load both from one place.
model.save_pretrained(FINAL_ADAPTER_DIR)
processor.save_pretrained(FINAL_ADAPTER_DIR)

print(
    "",
    f"LoRA adapter saved to: {FINAL_ADAPTER_DIR}",
    "",
    "Contents:",
    sep="\n",
)
# List every entry in the folder with its size in MB.
for _entry in os.listdir(FINAL_ADAPTER_DIR):
    _entry_path = os.path.join(FINAL_ADAPTER_DIR, _entry)
    _size_mb = os.path.getsize(_entry_path) / 1024 / 1024
    print(f"  {_entry}: {_size_mb:.1f} MB")

In [None]:
# Cell 14: Test the Fine-tuned Model
import librosa
import numpy as np

def transcribe_test(audio_path):
    """
    Transcribe one audio file with the fine-tuned model.

    Args:
        audio_path: Path to an audio file that librosa can read.

    Returns:
        The decoded text of the first generated sequence.
    """
    # Load audio, resampled to 16 kHz
    audio, _ = librosa.load(audio_path, sr=16000)

    # Prepare input
    # NOTE(review): the MERaLiON model card appears to call the processor with
    # a text prompt plus audio; confirm this Whisper-style call matches the
    # model's remote code.
    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt"
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate in eval mode: right after training the model may still be in
    # train mode, which keeps LoRA dropout active and makes output vary from
    # call to call. The previous mode is restored afterwards so a later
    # training run is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=256)
    finally:
        model.train(was_training)

    # Decode
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription

# Smoke-test on the first entry of the transcripts mapping.
test_file = list(transcripts)[0]
test_path = os.path.join(DATA_ROOT, test_file)

if not os.path.exists(test_path):
    print(f"Test file not found: {test_path}")
else:
    print(f"Testing with: {test_file}")
    print(f"Expected: {transcripts[test_file]}")
    print(f"Got: {transcribe_test(test_path)}")

## How to Use Your Trained Adapter

```python
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from peft import PeftModel

# Load base model
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "MERaLiON/MERaLiON-2-10B-ASR",
    device_map="auto",
    trust_remote_code=True,
)

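# Load your LoRA adapter (point both paths at your copy of the
# lahstats_lora_adapter folder that the notebook saved to Google Drive)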
model = PeftModel.from_pretrained(base_model, "./lahstats_lora_adapter")
processor = AutoProcessor.from_pretrained("./lahstats_lora_adapter")

# Optional: Merge adapter into base model for faster inference
# model = model.merge_and_unload()

# Use for transcription, e.g. with the transcribe_test() helper from Cell 14
```