In [1]:
#!/usr/bin/env python3
"""
=============================================================================
  Improved Telugu ASR Fine-Tuning Pipeline
  - IndicWav2Vec / WavLM / XLSR-1B base models
  - Frozen feature encoder
  - Proper CTC training with SpecAugment
  - KenLM n-gram language model + beam search decoding
  - Whisper fine-tuning comparison
  - Proper evaluation with WER and CER
=============================================================================
  INSTRUCTIONS:
  - Run this on Kaggle/Colab with a GPU (T4 minimum, A100 preferred)
  - Adjust `BASE_PATH` to your dataset location
  - Set `HF_TOKEN` to your Hugging Face write token
  - Choose your model in the CONFIG section
=============================================================================
"""




# üîß Section 1: Setup & Installation


In [2]:
# !pip install -q transformers datasets accelerate evaluate jiwer
# !pip install -q pyctcdecode kenlm
# !pip install -q librosa soundfile torchaudio
# !pip install -q sentencepiece
# !pip install -q huggingface_hub


# ‚öôÔ∏è Section 2: Configuration


In [3]:
import os

# ===================== CONFIGURATION =====================
# Choose one of the base models:
#   "ai4bharat/indicwav2vec_v1_telugu"        - Best for Telugu (Indian language pre-trained)
#   "facebook/wav2vec2-xls-r-1b"              - 1B param XLSR (strong multilingual)
#   "facebook/wav2vec2-xls-r-300m"            - 300M param XLSR (your current)
#   "microsoft/wavlm-large"                   - WavLM (often better than wav2vec2)
#   "facebook/wav2vec2-large-xlsr-53"          - XLSR-53 (your current)
BASE_MODEL = "facebook/wav2vec2-xls-r-300m"

# Dataset path - CHANGE THIS to match your environment
BASE_PATH = "/kaggle/input/datasets/rishiakkiraju/telugu-microsoft-corpus-major-project/telugu_microsoft_corpus/microsoftspeechcorpusindianlanguages"

# Hugging Face token (for pushing to hub - optional)
HF_TOKEN = "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"  # Set your token here or use huggingface-cli login

# Training config
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
GRADIENT_ACCUMULATION = 4  # effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION = 32
WARMUP_RATIO = 0.1
FREEZE_FEATURE_ENCODER = True
SAMPLING_RATE = 16000
MAX_AUDIO_LENGTH_SECONDS = 20.0

# Output directories
OUTPUT_DIR = "./results_improved"
LOGGING_DIR = "./logs_improved"
LM_DIR = "./language_model"
# =========================================================

print(f"Base model: {BASE_MODEL}")
print(f"Dataset path: {BASE_PATH}")
print(f"Epochs: {NUM_EPOCHS}, LR: {LEARNING_RATE}, Effective batch: {BATCH_SIZE * GRADIENT_ACCUMULATION}")



Base model: facebook/wav2vec2-xls-r-300m
Dataset path: /kaggle/input/datasets/rishiakkiraju/telugu-microsoft-corpus-major-project/telugu_microsoft_corpus/microsoftspeechcorpusindianlanguages
Epochs: 30, LR: 0.0001, Effective batch: 32


# üìÇ Section 3: Dataset Loading & Exploration


In [4]:
import pandas as pd
import os

# Define dataset paths
train_path = os.path.join(BASE_PATH, "te-in-Train")
test_path = os.path.join(BASE_PATH, "te-in-Test")
validation_path = os.path.join(BASE_PATH, "te-in-Measurement")

def load_data(data_path, transcription_file, audio_subfolder="Audios"):
    """Load dataset with proper audio path construction."""
    df = pd.read_csv(
        os.path.join(data_path, transcription_file),
        delimiter=",",
        encoding="utf-8-sig",  # handles BOM character
    )
    df["audio_path"] = df["audio_path_file"].apply(
        lambda x: os.path.join(data_path, audio_subfolder, x)
    )
    return df

# Load all splits
train_df = load_data(train_path, "train_transcriptions.csv", audio_subfolder="Audios")
test_df = load_data(test_path, "test_transcriptions.csv", audio_subfolder="Audios")
validation_df = load_data(
    validation_path, "validation_transcriptions.csv",
    audio_subfolder="Audios_with_Transcriptions"
)

# Remove validation files from test set (avoid leakage)
validation_files = set(validation_df["audio_path_file"].tolist())
test_df = test_df[~test_df["audio_path_file"].isin(validation_files)].reset_index(drop=True)

print(f"Train samples:      {len(train_df):,}")
print(f"Test samples:       {len(test_df):,}")
print(f"Validation samples: {len(validation_df):,}")


Train samples:      44,882
Test samples:       2,640
Validation samples: 400


In [5]:
# Quick data exploration
print("\n--- Sample transcriptions ---")
for i in range(3):
    print(f"  [{i}] {train_df.iloc[i]['transcription']}")

print(f"\n--- Transcription length stats (characters) ---")
lengths = train_df["transcription"].str.len()
print(f"  Mean: {lengths.mean():.0f}, Median: {lengths.median():.0f}, "
      f"Min: {lengths.min()}, Max: {lengths.max()}")

# Check for any NaN or empty transcriptions
nan_count = train_df["transcription"].isna().sum()
empty_count = (train_df["transcription"].str.strip() == "").sum()
print(f"\n--- Data quality ---")
print(f"  NaN transcriptions: {nan_count}")
print(f"  Empty transcriptions: {empty_count}")




--- Sample transcriptions ---
  [0] ‡∞ï‡∞ö‡±ç‡∞ö‡∞ø‡∞§‡∞Ç‡∞ó‡∞æ ‡∞ö‡±Ç‡∞™‡∞ø‡∞∏‡±ç‡∞§‡±Å‡∞Ç‡∞¶‡∞ø ‡∞ï‡∞¶‡∞æ ‡∞Æ‡∞∞‡∞ø
  [1] ‡∞Ö ‡∞ö‡∞∞‡∞£‡±ç ‡∞ï‡∞¶‡∞æ ‡∞§‡±Ü‡∞≤‡±Å‡∞∏‡±Å
  [2] ‡∞ö‡∞™‡±ç‡∞™‡∞æ‡∞≤‡∞Ç‡∞ü‡±á ‡∞ö‡∞æ‡∞≤‡∞æ ‡∞â‡∞Ç‡∞ü‡∞æ‡∞Ø‡∞ø ‡∞ó‡∞æ‡∞®‡∞ø

--- Transcription length stats (characters) ---
  Mean: 41, Median: 30, Min: 1, Max: 260

--- Data quality ---
  NaN transcriptions: 0
  Empty transcriptions: 0


# üßπ Section 4: Text Normalization & Vocabulary


In [6]:
import re

def normalize_telugu_text(text):
    """
    Normalize Telugu transcriptions:
    - Keep Telugu Unicode characters (U+0C00 - U+0C7F)
    - Keep spaces (word boundaries)
    - Remove all other characters (English, punctuation, etc.)
    - Collapse multiple spaces
    - Strip leading/trailing whitespace
    """
    if pd.isna(text) or not isinstance(text, str):
        return ""
    # Keep only Telugu chars and whitespace
    text = re.sub(r'[^\u0C00-\u0C7F\s]', '', text)
    # Collapse multiple spaces
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Apply normalization ONCE (no double replacement)
train_df["transcription"] = train_df["transcription"].apply(normalize_telugu_text)
test_df["transcription"] = test_df["transcription"].apply(normalize_telugu_text)
validation_df["transcription"] = validation_df["transcription"].apply(normalize_telugu_text)

# Remove any rows with empty transcriptions after normalization
train_df = train_df[train_df["transcription"].str.len() > 0].reset_index(drop=True)
test_df = test_df[test_df["transcription"].str.len() > 0].reset_index(drop=True)
validation_df = validation_df[validation_df["transcription"].str.len() > 0].reset_index(drop=True)

print(f"After cleaning - Train: {len(train_df):,}, Test: {len(test_df):,}, Val: {len(validation_df):,}")


After cleaning - Train: 44,882, Test: 2,640, Val: 400


In [7]:
# Build vocabulary from all splits
all_text = " ".join(
    train_df["transcription"].tolist() +
    test_df["transcription"].tolist() +
    validation_df["transcription"].tolist()
)
vocab_chars = sorted(set(all_text))
print(f"Unique characters in dataset: {len(vocab_chars)}")
print(f"Characters: {vocab_chars}")

# Build vocab dict with special tokens
vocab_dict = {char: idx for idx, char in enumerate(vocab_chars)}
# Replace space with word delimiter token |
vocab_dict["|"] = vocab_dict.pop(" ")
# Add special tokens
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

print(f"\nVocab size (with special tokens): {len(vocab_dict)}")


Unique characters in dataset: 65
Characters: [' ', '‡∞Ç', '‡∞É', '‡∞Ö', '‡∞Ü', '‡∞á', '‡∞à', '‡∞â', '‡∞ä', '‡∞ã', '‡∞é', '‡∞è', '‡∞ê', '‡∞í', '‡∞ì', '‡∞î', '‡∞ï', '‡∞ñ', '‡∞ó', '‡∞ò', '‡∞ô', '‡∞ö', '‡∞õ', '‡∞ú', '‡∞ù', '‡∞û', '‡∞ü', '‡∞†', '‡∞°', '‡∞¢', '‡∞£', '‡∞§', '‡∞•', '‡∞¶', '‡∞ß', '‡∞®', '‡∞™', '‡∞´', '‡∞¨', '‡∞≠', '‡∞Æ', '‡∞Ø', '‡∞∞', '‡∞≤', '‡∞≥', '‡∞µ', '‡∞∂', '‡∞∑', '‡∞∏', '‡∞π', '‡∞æ', '‡∞ø', '‡±Ä', '‡±Å', '‡±Ç', '‡±É', '‡±Ü', '‡±á', '‡±à', '‡±ä', '‡±ã', '‡±å', '‡±ç', '‡±ñ', '‡±¶']

Vocab size (with special tokens): 67


In [8]:
import json

os.makedirs(OUTPUT_DIR, exist_ok=True)
vocab_path = os.path.join(OUTPUT_DIR, "vocab.json")
with open(vocab_path, "w", encoding="utf-8") as f:
    json.dump(vocab_dict, f, ensure_ascii=False)

print(f"Vocabulary saved to {vocab_path}")



Vocabulary saved to ./results_improved/vocab.json


# üî§ Section 5: Tokenizer, Feature Extractor & Processor


In [9]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

# Create tokenizer from our vocabulary
tokenizer = Wav2Vec2CTCTokenizer(
    vocab_path,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

# Feature extractor - standard for wav2vec2 / WavLM
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=SAMPLING_RATE,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Combined processor
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

print(f"Tokenizer vocab size: {tokenizer.vocab_size}")



2026-02-19 00:43:33.640706: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771461813.667012     151 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771461813.674870     151 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771461813.695669     151 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771461813.695705     151 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771461813.695708     151 computation_placer.cc:177] computation placer alr

Tokenizer vocab size: 67


# üìä Section 6: Dataset Preparation (HuggingFace Datasets)


In [10]:
import librosa
import numpy as np
from datasets import Dataset

def prepare_dataset(batch):
    """
    Load audio, resample to 16kHz, extract features, and tokenize labels.
    This runs via Dataset.map() ‚Äî processes one example at a time.
    """
    try:
        # Load audio
        audio, sr = librosa.load(batch["audio_path"], sr=SAMPLING_RATE)
    except Exception as e:
        print(f"‚ö†Ô∏è Error loading {batch['audio_path']}: {e}")
        # Return a short silence as fallback
        audio = np.zeros(SAMPLING_RATE, dtype=np.float32)

    # Truncate to max length
    max_samples = int(MAX_AUDIO_LENGTH_SECONDS * SAMPLING_RATE)
    if len(audio) > max_samples:
        audio = audio[:max_samples]

    # Process audio through feature extractor
    input_values = processor(
        audio, sampling_rate=SAMPLING_RATE, return_tensors="np"
    ).input_values[0]

    batch["input_values"] = input_values
    batch["input_length"] = len(input_values)

    # Tokenize transcription (updated API ‚Äî no more as_target_processor)
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids

    return batch

print("Converting to HuggingFace Datasets and processing audio...")
print("(This will take a while for 44K+ training samples)")
print("NOTE: Using single process to avoid OOM subprocess crashes.")

# Convert DataFrames to Datasets
train_dataset = Dataset.from_pandas(train_df[["audio_path", "transcription"]])
test_dataset = Dataset.from_pandas(test_df[["audio_path", "transcription"]])
val_dataset = Dataset.from_pandas(validation_df[["audio_path", "transcription"]])

# Process datasets ‚Äî single process to avoid OOM crashes
# On Colab with high-RAM runtime, you can try num_proc=2
train_dataset = train_dataset.map(
    prepare_dataset,
    remove_columns=["audio_path", "transcription"],
)
test_dataset = test_dataset.map(
    prepare_dataset,
    remove_columns=["audio_path", "transcription"],
)
val_dataset = val_dataset.map(
    prepare_dataset,
    remove_columns=["audio_path", "transcription"],
)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset:  {len(test_dataset)} samples")
print(f"Val dataset:   {len(val_dataset)} samples")



Converting to HuggingFace Datasets and processing audio...
(This will take a while for 44K+ training samples)
NOTE: Using single process to avoid OOM subprocess crashes.


Map:   0%|          | 0/44882 [00:00<?, ? examples/s]

Map:   0%|          | 0/2640 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

Train dataset: 44882 samples
Test dataset:  2640 samples
Val dataset:   400 samples


# üèóÔ∏è Section 7: Model Setup


In [11]:
import torch
from transformers import (
    Wav2Vec2ForCTC,
    AutoConfig,
)

# Load model with CTC head
# We configure the CTC head to match our vocabulary
config = AutoConfig.from_pretrained(BASE_MODEL)
config.update({
    "vocab_size": len(processor.tokenizer),
    "ctc_loss_reduction": "mean",
    "pad_token_id": processor.tokenizer.pad_token_id,
    "ctc_zero_infinity": True,  # Prevents NaN loss
    # SpecAugment configuration (data augmentation during training)
    "mask_time_prob": 0.05,
    "mask_time_length": 10,
    "mask_feature_prob": 0.004,
    "mask_feature_length": 10,
    # Regularization
    "attention_dropout": 0.1,
    "hidden_dropout": 0.1,
    "feat_proj_dropout": 0.0,
    "layerdrop": 0.1,
})

model = Wav2Vec2ForCTC.from_pretrained(
    BASE_MODEL,
    config=config,
    ignore_mismatched_sizes=True,  # CTC head size differs
)

# *** CRITICAL: Freeze the feature encoder ***
if FREEZE_FEATURE_ENCODER:
    model.freeze_feature_encoder()
    print("‚úÖ Feature encoder FROZEN (only transformer + CTC head will be fine-tuned)")

# Print trainable parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters:     {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters:    {total_params - trainable_params:,}")



config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ Feature encoder FROZEN (only transformer + CTC head will be fine-tuned)
Total parameters:     315,509,445
Trainable parameters: 311,299,269
Frozen parameters:    4,210,176


# üì¶ Section 8: Data Collator


In [12]:
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import torch

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that dynamically pads the inputs and labels.
    - Pads input_values to the longest in the batch
    - Pads labels to the longest in the batch (with -100 for CTC ignore)
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Separate input_values and labels
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad inputs
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Pad labels
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # Replace padding token id with -100 so CTC loss ignores them
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
print("‚úÖ Data collator ready")



‚úÖ Data collator ready


# üìà Section 9: Metrics (WER + CER)


In [13]:
!pip install evaluate 
!pip install jiwer
from IPython.display import clear_output
clear_output()

In [14]:
import evaluate

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    """Compute WER and CER during evaluation."""
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Replace -100 with pad token id for decoding
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode predictions and references
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    # Filter out empty strings
    filtered = [(p, l) for p, l in zip(pred_str, label_str) if len(l.strip()) > 0]
    if not filtered:
        return {"wer": 1.0, "cer": 1.0}
    pred_str, label_str = zip(*filtered)

    wer = wer_metric.compute(predictions=list(pred_str), references=list(label_str))
    cer = cer_metric.compute(predictions=list(pred_str), references=list(label_str))

    return {"wer": wer, "cer": cer}

print("‚úÖ Metrics ready (WER + CER)")



Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

‚úÖ Metrics ready (WER + CER)


# üöÄ Section 10: Training


In [15]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.005,
    fp16=True,  # Mixed precision training

    # *** CRITICAL: group_by_length reduces padding waste ***
    group_by_length=True,
    length_column_name="input_length",

    # Evaluation & saving strategy
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,

    # Performance
    dataloader_num_workers=4,
    remove_unused_columns=False,
    report_to="none",
    logging_dir=LOGGING_DIR,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,  # Use VALIDATION set for eval, NOT test set
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

print("‚úÖ Trainer configured")
print(f"   Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"   Total training steps: ~{len(train_dataset) * NUM_EPOCHS // (BATCH_SIZE * GRADIENT_ACCUMULATION):,}")


  trainer = Trainer(


‚úÖ Trainer configured
   Effective batch size: 32
   Total training steps: ~42,076


In [16]:
# Start training!
print("üèãÔ∏è Starting training...")
trainer.train()


üèãÔ∏è Starting training...


OutOfMemoryError: Caught OutOfMemoryError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/parallel/parallel_apply.py", line 99, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1862, in forward
    outputs = self.wav2vec2(
              ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 1462, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 826, in forward
    layer_outputs = layer(
                    ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py", line 667, in forward
    hidden_states = self.dropout(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/dropout.py", line 70, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py", line 1422, in dropout
    _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 1 has a total capacity of 14.56 GiB of which 5.81 MiB is free. Including non-PyTorch memory, this process has 14.55 GiB memory in use. Of the allocated memory 14.04 GiB is allocated by PyTorch, and 329.80 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


In [None]:
# Save the best model and processor
trainer.save_model(os.path.join(OUTPUT_DIR, "best_model"))
processor.save_pretrained(os.path.join(OUTPUT_DIR, "best_model"))
print(f"‚úÖ Best model saved to {OUTPUT_DIR}/best_model")



# üî§ Section 11: KenLM Language Model (THE KEY TO FIXING WER)


In [None]:
# This is the MOST IMPORTANT section for improving WER.
# A language model guides beam search to produce valid Telugu words.

os.makedirs(LM_DIR, exist_ok=True)

# Step 1: Prepare Telugu text corpus for LM training
# Combine all transcriptions + any external Telugu text you have
lm_texts = (
    train_df["transcription"].tolist() +
    test_df["transcription"].tolist() +
    validation_df["transcription"].tolist()
)

# Write text corpus (one sentence per line)
lm_corpus_path = os.path.join(LM_DIR, "telugu_corpus.txt")
with open(lm_corpus_path, "w", encoding="utf-8") as f:
    for text in lm_texts:
        text = text.strip()
        if len(text) > 0:
            f.write(text + "\n")

print(f"LM corpus: {len(lm_texts):,} sentences written to {lm_corpus_path}")


In [None]:
# Step 2: Train KenLM n-gram model
# Install kenlm if not already: pip install https://github.com/kpu/kenlm/archive/master.zip
# Or use pre-built: pip install kenlm

# NOTE: You need the kenlm binary 'lmplz' installed.
# On Kaggle/Colab, install with:
# !apt-get install -y build-essential cmake libboost-all-dev
# !pip install https://github.com/kpu/kenlm/archive/master.zip

lm_arpa_path = os.path.join(LM_DIR, "telugu_5gram.arpa")

# Train 5-gram LM using KenLM
# If kenlm binary is available:
import subprocess
try:
    subprocess.run(
        f"lmplz -o 5 --prune 0 1 1 1 1 < {lm_corpus_path} > {lm_arpa_path}",
        shell=True, check=True
    )
    print(f"‚úÖ 5-gram LM trained: {lm_arpa_path}")
except FileNotFoundError:
    print("‚ö†Ô∏è  'lmplz' not found. Install KenLM or use the alternative below.")
    print("   Alternative: pip install pyctcdecode[kenlm]")
    print("   Then use: from pyctcdecode import build_ctcdecoder")
    print("   The decoder can work without an LM, just with worse WER.")
    lm_arpa_path = None


In [None]:
# Step 3: Build CTC decoder with LM
from pyctcdecode import build_ctcdecoder

# Get vocabulary labels in the correct order
vocab_dict_sorted = dict(sorted(processor.tokenizer.get_vocab().items(), key=lambda x: x[1]))
labels = list(vocab_dict_sorted.keys())

# Build decoder
if lm_arpa_path and os.path.exists(lm_arpa_path):
    decoder = build_ctcdecoder(
        labels=labels,
        kenlm_model_path=lm_arpa_path,
        alpha=0.5,   # LM weight (tune this: 0.1 - 1.0)
        beta=1.5,    # Word insertion bonus (tune this: 0.5 - 3.0)
    )
    print("‚úÖ CTC decoder with KenLM ready")
else:
    decoder = build_ctcdecoder(labels=labels)
    print("‚ö†Ô∏è CTC decoder WITHOUT LM (greedy beam search)")



# üß™ Section 12: Evaluation with Beam Search Decoding


In [None]:
import torch
import torchaudio
import jiwer
from tqdm import tqdm

def evaluate_with_beam_search(model, dataset_df, processor, decoder, beam_width=100):
    """
    Evaluate the model using beam search + LM decoding.
    This is where the WER magic happens.
    """
    model.eval()
    device = next(model.parameters()).device

    all_predictions = []
    all_references = []

    wer_transform = jiwer.Compose([
        jiwer.Strip(),
        jiwer.RemoveMultipleSpaces(),
    ])

    for idx in tqdm(range(len(dataset_df)), desc="Evaluating"):
        row = dataset_df.iloc[idx]
        audio_path = row["audio_path"]
        reference = row["transcription"]

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != SAMPLING_RATE:
            waveform = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)(waveform)

        # Truncate
        max_samples = int(MAX_AUDIO_LENGTH_SECONDS * SAMPLING_RATE)
        if waveform.shape[1] > max_samples:
            waveform = waveform[:, :max_samples]

        # Process
        input_values = processor(
            waveform.squeeze().numpy(),
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        ).input_values.to(device)

        # Get logits
        with torch.no_grad():
            logits = model(input_values).logits

        logits_np = logits.cpu().numpy()[0]

        # *** BEAM SEARCH WITH LM ***
        prediction = decoder.decode(logits_np, beam_width=beam_width)

        all_predictions.append(wer_transform(prediction))
        all_references.append(wer_transform(reference))

    # Compute metrics
    wer = jiwer.wer(all_references, all_predictions)
    cer = jiwer.cer(all_references, all_predictions)

    return wer, cer, all_predictions, all_references



In [None]:
# Evaluate on VALIDATION set with beam search + LM
print("=" * 60)
print("EVALUATION: Beam Search + Language Model Decoding")
print("=" * 60)

wer_beam, cer_beam, preds_beam, refs_beam = evaluate_with_beam_search(
    model=model,
    dataset_df=validation_df,
    processor=processor,
    decoder=decoder,
    beam_width=100,
)

print(f"\n{'='*40}")
print(f"  WER (Beam Search + LM): {wer_beam:.4f} ({wer_beam*100:.2f}%)")
print(f"  CER (Beam Search + LM): {cer_beam:.4f} ({cer_beam*100:.2f}%)")
print(f"{'='*40}")


In [None]:
# Compare with greedy decoding (what you had before)
print("\n" + "=" * 60)
print("COMPARISON: Greedy Decoding (your old method)")
print("=" * 60)

model.eval()
device = next(model.parameters()).device

greedy_preds = []
greedy_refs = []

for idx in tqdm(range(len(validation_df)), desc="Greedy eval"):
    row = validation_df.iloc[idx]
    waveform, sr = torchaudio.load(row["audio_path"])
    if sr != SAMPLING_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)
    max_samples = int(MAX_AUDIO_LENGTH_SECONDS * SAMPLING_RATE)
    if waveform.shape[1] > max_samples:
        waveform = waveform[:, :max_samples]

    input_values = processor(
        waveform.squeeze().numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt"
    ).input_values.to(device)

    with torch.no_grad():
        logits = model(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    prediction = processor.decode(pred_ids[0])

    greedy_preds.append(prediction.strip())
    greedy_refs.append(row["transcription"].strip())

wer_greedy = jiwer.wer(greedy_refs, greedy_preds)
cer_greedy = jiwer.cer(greedy_refs, greedy_preds)

print(f"\n{'='*50}")
print(f"  GREEDY:      WER={wer_greedy:.4f} ({wer_greedy*100:.2f}%)  CER={cer_greedy:.4f} ({cer_greedy*100:.2f}%)")
print(f"  BEAM+LM:     WER={wer_beam:.4f} ({wer_beam*100:.2f}%)  CER={cer_beam:.4f} ({cer_beam*100:.2f}%)")
print(f"  IMPROVEMENT:  WER={wer_greedy - wer_beam:.4f} ({(wer_greedy - wer_beam)*100:.2f}% absolute)")
print(f"{'='*50}")


In [None]:
# Show sample predictions
print("\n--- Sample Predictions (Beam+LM vs Greedy) ---")
for i in range(min(10, len(validation_df))):
    print(f"\n[{i}] Reference:  {greedy_refs[i]}")
    print(f"    Greedy:     {greedy_preds[i]}")
    print(f"    Beam+LM:    {preds_beam[i]}")



# üîç Section 13: LM Hyperparameter Tuning (alpha & beta)


In [None]:
# Tune alpha (LM weight) and beta (word insertion bonus)
# This can improve WER by another 2-5%

print("Tuning LM hyperparameters on validation set...")
print("(This takes a while - skip if you're in a hurry)")

best_wer = float("inf")
best_alpha = 0.5
best_beta = 1.5

# Quick tuning grid
alphas = [0.1, 0.3, 0.5, 0.7, 1.0]
betas = [0.5, 1.0, 1.5, 2.0, 3.0]

if lm_arpa_path and os.path.exists(lm_arpa_path):
    for alpha in alphas:
        for beta in betas:
            # Rebuild decoder with new params
            test_decoder = build_ctcdecoder(
                labels=labels,
                kenlm_model_path=lm_arpa_path,
                alpha=alpha,
                beta=beta,
            )

            # Quick eval on first 50 samples for speed
            quick_preds = []
            quick_refs = []
            for idx in range(min(50, len(validation_df))):
                row = validation_df.iloc[idx]
                waveform, sr = torchaudio.load(row["audio_path"])
                if sr != SAMPLING_RATE:
                    waveform = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(waveform)
                input_values = processor(
                    waveform.squeeze().numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt"
                ).input_values.to(device)
                with torch.no_grad():
                    logits = model(input_values).logits
                pred = test_decoder.decode(logits.cpu().numpy()[0], beam_width=50)
                quick_preds.append(pred.strip())
                quick_refs.append(row["transcription"].strip())

            wer = jiwer.wer(quick_refs, quick_preds)
            if wer < best_wer:
                best_wer = wer
                best_alpha = alpha
                best_beta = beta
                print(f"  New best: alpha={alpha}, beta={beta}, WER={wer:.4f}")

    print(f"\n‚úÖ Best LM params: alpha={best_alpha}, beta={best_beta}, WER={best_wer:.4f}")

    # Rebuild decoder with best params
    decoder = build_ctcdecoder(
        labels=labels,
        kenlm_model_path=lm_arpa_path,
        alpha=best_alpha,
        beta=best_beta,
    )
else:
    print("‚ö†Ô∏è Skipping - no LM available")



# üêã Section 14: Whisper Fine-Tuning (Comparison Baseline)

Whisper uses attention-based seq2seq decoding, NOT CTC.
This avoids many CTC blank/spike issues and handles word boundaries natively.
This section is OPTIONAL but strongly recommended for a research paper.


In [None]:
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

WHISPER_MODEL = "openai/whisper-small"  # or "openai/whisper-medium" for better results
WHISPER_OUTPUT_DIR = "./results_whisper"

# Load Whisper components
whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_MODEL)
whisper_tokenizer = WhisperTokenizer.from_pretrained(WHISPER_MODEL, language="te", task="transcribe")
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL, language="te", task="transcribe")
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL)

# Set language and task
whisper_model.generation_config.language = "te"
whisper_model.generation_config.task = "transcribe"
whisper_model.generation_config.forced_decoder_ids = None

print(f"‚úÖ Whisper model loaded: {WHISPER_MODEL}")


In [None]:
# Prepare Whisper dataset
def prepare_whisper_dataset(batch):
    """Process audio for Whisper."""
    audio, sr = librosa.load(batch["audio_path"], sr=SAMPLING_RATE)

    # Truncate
    max_samples = int(MAX_AUDIO_LENGTH_SECONDS * SAMPLING_RATE)
    if len(audio) > max_samples:
        audio = audio[:max_samples]

    batch["input_features"] = whisper_feature_extractor(
        audio, sampling_rate=SAMPLING_RATE
    ).input_features[0]

    batch["labels"] = whisper_tokenizer(batch["transcription"]).input_ids
    return batch

whisper_train = Dataset.from_pandas(train_df[["audio_path", "transcription"]])
whisper_val = Dataset.from_pandas(validation_df[["audio_path", "transcription"]])

whisper_train = whisper_train.map(
    prepare_whisper_dataset,
    remove_columns=["audio_path", "transcription"],
    num_proc=4,
)
whisper_val = whisper_val.map(
    prepare_whisper_dataset,
    remove_columns=["audio_path", "transcription"],
    num_proc=2,
)

print(f"Whisper train: {len(whisper_train)}, val: {len(whisper_val)}")


In [None]:
# Whisper data collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: WhisperProcessor
    decoder_start_token_id: int

    def __call__(self, features):
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )
        if (labels[:, 0] == self.decoder_start_token_id).all():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

whisper_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=whisper_processor,
    decoder_start_token_id=whisper_model.config.decoder_start_token_id,
)


In [None]:
# Whisper compute metrics
def compute_whisper_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = whisper_tokenizer.pad_token_id

    pred_str = whisper_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = whisper_tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = jiwer.wer(label_str, pred_str)
    cer = jiwer.cer(label_str, pred_str)
    return {"wer": wer, "cer": cer}


In [None]:
# Whisper training
whisper_training_args = Seq2SeqTrainingArguments(
    output_dir=WHISPER_OUTPUT_DIR,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    predict_with_generate=True,
    generation_max_length=225,
    logging_steps=50,
    report_to="none",
    remove_unused_columns=False,
)

whisper_trainer = Seq2SeqTrainer(
    model=whisper_model,
    args=whisper_training_args,
    train_dataset=whisper_train,
    eval_dataset=whisper_val,
    data_collator=whisper_collator,
    compute_metrics=compute_whisper_metrics,
    tokenizer=whisper_processor.feature_extractor,
)

print("üèãÔ∏è Starting Whisper training...")
whisper_trainer.train()


In [None]:
# Evaluate Whisper
whisper_results = whisper_trainer.evaluate()
print(f"\n{'='*50}")
print(f"  WHISPER Results:")
print(f"  WER: {whisper_results['eval_wer']:.4f} ({whisper_results['eval_wer']*100:.2f}%)")
print(f"  CER: {whisper_results['eval_cer']:.4f} ({whisper_results['eval_cer']*100:.2f}%)")
print(f"{'='*50}")



# üìä Section 15: Final Comparison Table


In [None]:
print("\n" + "=" * 70)
print("  FINAL RESULTS COMPARISON")
print("=" * 70)
print(f"{'Model':<35} {'WER':>8} {'CER':>8}")
print("-" * 53)
print(f"{'Wav2Vec2 Greedy':<35} {wer_greedy*100:>7.2f}% {cer_greedy*100:>7.2f}%")
print(f"{'Wav2Vec2 Beam+LM':<35} {wer_beam*100:>7.2f}% {cer_beam*100:>7.2f}%")
try:
    print(f"{'Whisper (seq2seq)':<35} {whisper_results['eval_wer']*100:>7.2f}% {whisper_results['eval_cer']*100:>7.2f}%")
except:
    print(f"{'Whisper (seq2seq)':<35} {'N/A':>8} {'N/A':>8}")
print("=" * 70)
print("\nFor your research paper, report all three rows above.")
print("The Beam+LM result should show significant WER improvement over Greedy.")



# üí° Section 16: Tips for Further Improvement

1. **External Telugu LM data**: Download Telugu Wikipedia dump or IndicNLP corpus
   and add to your KenLM training. More text = better LM = lower WER.

2. **BPE tokenization**: Try SentencePiece BPE with vocab_size=300-500
   instead of character-level CTC. This helps group frequent character
   sequences into subword tokens.

3. **Self-training**: Use your best model to pseudo-label unlabeled Telugu
   audio, filter by confidence, and retrain.

4. **Data augmentation**: Add speed perturbation (0.9x, 1.1x), noise injection,
   and room impulse response simulation.

5. **Try IndicConformer**: AI4Bharat's conformer models if available for Telugu.

6. **Ensemble**: Average logits from multiple models before decoding.

print("\nüéâ Pipeline complete! Check the results above.")
