# Whisper Fine-tuning with LoRA for Indic Languages

This notebook fine-tunes OpenAI Whisper on Indic languages using LoRA (Low-Rank Adaptation) to keep GPU memory usage low. It expects the JSONL manifests prepared by notebook 01.

**Features:**
- LoRA fine-tuning (memory efficient)
- Mixed precision (fp16)
- Gradient checkpointing
- WER (word error rate) evaluation

In [None]:
# Check GPU: show which NVIDIA GPU (if any) this runtime has and its memory.
# If this command fails, no GPU runtime is attached — training below would
# fall back to CPU.
!nvidia-smi

In [None]:
# Install dependencies (-q keeps pip output quiet):
#   transformers / peft / accelerate - Whisper model, LoRA adapters, Trainer
#   datasets / evaluate / jiwer      - manifest loading and WER computation
#   soundfile / librosa              - audio file decoding
!pip install -q transformers datasets accelerate peft evaluate jiwer
!pip install -q soundfile librosa

In [None]:
# Mount Google Drive so both the prepared data and the trained models persist
# outside the (ephemeral) Colab VM.
import os

from google.colab import drive

drive.mount('/content/drive')

# Where notebook 01 left the manifests/audio, and where checkpoints go.
DATA_DIR = '/content/drive/MyDrive/indic_speech_data'
OUTPUT_DIR = '/content/drive/MyDrive/indic_stt_models'

# Create the output folder up front; no-op when it already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

## 1. Configuration

In [None]:
import torch

# ---- Model / language -------------------------------------------------------
MODEL_NAME = 'openai/whisper-medium'  # 'openai/whisper-small' trains faster
LANGUAGE = 'hi'  # Whisper language code, e.g. hi, ta, te, bn

# ---- Optimisation -----------------------------------------------------------
BATCH_SIZE = 4  # per-device batch size, used for both training and evaluation
# NOTE(review): LoRA runs commonly use a larger learning rate than full
# fine-tuning (often 1e-4 .. 1e-3); confirm 1e-5 is intentional.
LEARNING_RATE = 1e-5
# NOTE: when MAX_STEPS > 0 the Trainer stops after MAX_STEPS optimizer steps
# and ignores EPOCHS.
EPOCHS = 5
MAX_STEPS = 5000

# ---- LoRA adapters ----------------------------------------------------------
# rank, alpha (adapter output is scaled by alpha / rank), dropout on adapter input
LORA_R, LORA_ALPHA, LORA_DROPOUT = 32, 64, 0.1

# Run on the GPU when one is visible, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")

## 2. Load Model and Apply LoRA

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType

# Load processor. Passing language/task makes the tokenizer prefix every label
# with the language and "transcribe" special tokens, so training targets match
# the decoder prompt used at generation time.
processor = WhisperProcessor.from_pretrained(
    MODEL_NAME,
    language=LANGUAGE,
    task='transcribe',
)

# Load model in full precision. Mixed precision comes from the Trainer's fp16
# autocast; loading fp16 weights directly can make forward passes that run
# outside autocast (e.g. generate() on float32 input features) fail with a
# dtype mismatch.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

# Decode in the target language instead of auto-detecting it, and let the
# language/task settings above drive the decoder prompt.
model.generation_config.language = LANGUAGE
model.generation_config.task = 'transcribe'
model.generation_config.forced_decoder_ids = None

# Enable gradient checkpointing (the KV cache is not used while training).
model.config.use_cache = False
# Non-reentrant checkpointing still propagates gradients to the LoRA adapters
# when a checkpointed block's inputs do not require grad (the frozen base
# model's inputs); reentrant checkpointing would silently drop them.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={'use_reentrant': False},
)

print(f"Base model loaded: {MODEL_NAME}")

In [None]:
# Apply LoRA.
# task_type is deliberately NOT set: TaskType.SEQ_2_SEQ_LM wraps the model in
# PeftModelForSeq2SeqLM, whose forward always passes `input_ids` and
# `inputs_embeds` keywords. WhisperForConditionalGeneration.forward accepts
# neither (it takes `input_features`), so training would fail with a TypeError.
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    # Attention projections and feed-forward layers, matched by module name
    # wherever they occur in the model.
    target_modules=['q_proj', 'v_proj', 'k_proj', 'out_proj', 'fc1', 'fc2'],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

model = model.to(device)

## 3. Prepare Dataset

In [None]:
from datasets import load_dataset, Audio
import json

def load_manifest_dataset(manifest_path):
    """Load a JSONL manifest as a Hugging Face dataset.

    Each JSON line becomes one row. The 'audio_filepath' column is cast to an
    Audio feature at 16 kHz, so accessing a row yields a decoded audio dict
    instead of a path string.

    Args:
        manifest_path: path to the .jsonl manifest file.

    Returns:
        The loaded dataset (single split).
    """
    split_name = 'data'
    loaded = load_dataset('json', data_files={split_name: manifest_path})
    return loaded[split_name].cast_column(
        'audio_filepath',
        Audio(sampling_rate=16000),
    )

# Load datasets
train_manifest = f"{DATA_DIR}/manifests/stt_{LANGUAGE}_train.jsonl"
val_manifest = f"{DATA_DIR}/manifests/stt_{LANGUAGE}_val.jsonl"

# Fail fast: every later cell needs train_dataset, so carrying on without it
# would only surface as a confusing NameError further down the notebook.
if not os.path.exists(train_manifest):
    raise FileNotFoundError(
        f"Manifest not found: {train_manifest}\n"
        "Run notebook 01 first to prepare data."
    )

train_dataset = load_manifest_dataset(train_manifest)
# The validation split is optional; later cells check val_dataset first.
val_dataset = (
    load_manifest_dataset(val_manifest) if os.path.exists(val_manifest) else None
)

print(f"Train: {len(train_dataset)} samples")
if val_dataset is not None:
    print(f"Val: {len(val_dataset)} samples")

In [None]:
def prepare_dataset(batch):
    """Convert one manifest row into Whisper model inputs.

    Args:
        batch: a single dataset row. 'audio_filepath' holds the decoded audio
            dict (with 'array' and 'sampling_rate') produced by the Audio
            cast; 'text' holds the reference transcript.

    Returns:
        dict with 'input_features' (the processor's features for this clip)
        and 'labels' (the tokenizer's input ids for the transcript).
    """
    clip = batch['audio_filepath']
    extracted = processor(
        clip['array'],
        sampling_rate=clip['sampling_rate'],
        return_tensors='pt',
    )
    # Token ids of the transcript become the training targets.
    token_ids = processor.tokenizer(batch['text']).input_ids
    return {'input_features': extracted.input_features[0], 'labels': token_ids}

# Process datasets: replace every raw column with the model inputs/labels that
# prepare_dataset produces.
def _to_model_inputs(raw_split):
    """Map prepare_dataset over a split, dropping its original columns."""
    return raw_split.map(prepare_dataset, remove_columns=raw_split.column_names)


train_processed = _to_model_inputs(train_dataset)

# The validation split is optional.
if val_dataset:
    val_processed = _to_model_inputs(val_dataset)

## 4. Training

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad Whisper input features and token labels into one training batch."""

    # Processor whose feature extractor pads audio features and whose
    # tokenizer pads label ids.
    processor: WhisperProcessor
    # Token the model prepends itself when it shifts labels right to build the
    # decoder inputs. None -> resolved from the tokenizer's
    # <|startoftranscript|> token.
    decoder_start_token_id: Optional[int] = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate examples into a batch of 'input_features' and 'labels'.

        Args:
            features: examples as produced by prepare_dataset, each with
                'input_features' and 'labels' (token ids).

        Returns:
            The padded feature batch with a 'labels' tensor added; label
            padding is set to -100 so the loss ignores it.
        """
        input_features = [{'input_features': f['input_features']} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        label_features = [{'input_ids': f['labels']} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        # Padding positions -> -100 so they do not contribute to the loss.
        labels = labels_batch['input_ids'].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # The tokenizer already put <|startoftranscript|> at the front of each
        # label sequence, and the model adds it again when it shifts labels
        # right to form decoder_input_ids. Drop the tokenizer's copy so the
        # decoder is not trained on a duplicated start token.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids('<|startoftranscript|>')
        if labels.size(1) > 0 and (labels[:, 0] == start_id).all().item():
            labels = labels[:, 1:]

        batch['labels'] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [None]:
import evaluate
import numpy as np
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Load WER metric
wer_metric = evaluate.load('wer')

def compute_metrics(pred):
    """Compute the word error rate of generated transcripts.

    Args:
        pred: EvalPrediction from Seq2SeqTrainer (predict_with_generate=True);
            `predictions` are generated token ids, `label_ids` the padded labels.

    Returns:
        {'wer': float} as computed by the `evaluate` WER metric.
    """
    pad_id = processor.tokenizer.pad_token_id

    # -100 marks ignored positions. It appears in label_ids (set by the
    # collator) and can also appear in predictions, because the Trainer pads
    # generated sequences of different lengths with -100 when it gathers them.
    # The tokenizer cannot decode negative ids, so map both back to the pad
    # token. np.where builds new arrays instead of mutating the EvalPrediction.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {'wer': wer}

# Training arguments.
# Everything that needs a validation split is gated on has_val: without one,
# periodic evaluation and best-checkpoint selection would fail in the Trainer.
has_val = bool(val_dataset)

training_args = Seq2SeqTrainingArguments(
    output_dir=f"{OUTPUT_DIR}/whisper_{LANGUAGE}_lora",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,  # ignored while max_steps > 0
    max_steps=MAX_STEPS,
    warmup_steps=500,
    gradient_accumulation_steps=4,  # effective batch = BATCH_SIZE * 4 per device
    fp16=torch.cuda.is_available(),  # fp16 mixed precision requires a CUDA device
    eval_strategy='steps' if has_val else 'no',
    eval_steps=500,
    save_strategy='steps',
    save_steps=500,
    logging_steps=50,
    predict_with_generate=True,
    generation_max_length=225,
    load_best_model_at_end=has_val,
    metric_for_best_model='wer' if has_val else None,
    greater_is_better=False,  # lower WER is better
    # The PEFT wrapper's forward(*args, **kwargs) hides the real argument names
    # from the Trainer's signature inspection, so name the label column
    # explicitly and keep input_features from being dropped as "unused".
    label_names=['labels'],
    remove_unused_columns=False,
    report_to=['tensorboard'],
)

In [None]:
# Create trainer.
# processing_class replaces the deprecated `tokenizer=` argument; the feature
# extractor passed here is saved alongside each checkpoint.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_processed,
    eval_dataset=val_processed if val_dataset else None,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)

# Train
print("Starting training...")
trainer.train()

## 5. Evaluation

In [None]:
# Evaluate on the validation split, when one was loaded; the trainer reports
# the metric under the 'eval_' prefix.
if val_dataset:
    results = trainer.evaluate()
    print("\nEvaluation Results:")
    print("  WER: {:.2%}".format(results['eval_wer']))

In [None]:
# Test transcription
import soundfile as sf
import librosa
import IPython.display as ipd

# Get a sample. Prefer the validation split, so the check is not on a clip the
# model was trained on; fall back to the training split if there is none.
sample_source = val_dataset if val_dataset else train_dataset
sample = sample_source[0]
audio = sample['audio_filepath']['array']
sr = sample['audio_filepath']['sampling_rate']

# Play audio (the last expression in the cell is what Jupyter displays)
print(f"Ground truth: {sample['text']}")
ipd.Audio(audio, rate=sr)

In [None]:
# Transcribe
model.eval()

# Cast the features to the model's parameter dtype so this forward pass also
# works if the base weights were loaded in fp16 (no autocast is active here).
input_features = processor(
    audio,
    sampling_rate=sr,
    return_tensors='pt',
).input_features.to(device, dtype=model.dtype)

with torch.no_grad():
    # input_features must be passed by keyword: PEFT's seq2seq wrapper exposes
    # generate(**kwargs) only. language/task pin the decoder prompt instead of
    # relying on Whisper's language auto-detection.
    predicted_ids = model.generate(
        input_features=input_features,
        max_length=225,
        language=LANGUAGE,
        task='transcribe',
    )

transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(f"Transcription: {transcription}")

## 6. Save Model

In [None]:
# Save LoRA adapters (only the adapter weights, not the base model)
lora_path = f"{OUTPUT_DIR}/whisper_{LANGUAGE}_lora_adapters"
model.save_pretrained(lora_path)
print(f"LoRA adapters saved to {lora_path}")

# Merge the adapters into the base weights and save a standalone model.
# merge_and_unload() folds the adapters into the underlying model, so `model`
# should not be used as a PEFT model after this point.
merged_model = model.merge_and_unload()
# use_cache was switched off for gradient checkpointing during training; turn
# it back on so the exported model uses the KV cache when decoding.
merged_model.config.use_cache = True
merged_path = f"{OUTPUT_DIR}/whisper_{LANGUAGE}_merged"
merged_model.save_pretrained(merged_path)
processor.save_pretrained(merged_path)
print(f"Merged model saved to {merged_path}")