# ASR Fellowship Challenge:  Adapter-Based Fine-Tuning for Low-Resource Languages
## Automatic Speech Recognition (ASR) Health Adapter


**Objective:** Implement and train adapter modules to reduce Word Error Rate (WER)  on Afrivoice Kinyarwanda Health dataset while keeping base model frozen

**Evaluation Metric:** Evaluation will consider the adapter’s accuracy and
improvement over the base model.

##### Key Constraints:
- ✅ Base model weights MUST remain frozen
- ✅ Only train adapter parameters

## 1. Environment Setup

In [1]:
# mount google drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# GPU-avaliability
!nvidia-smi

Fri Nov 21 20:54:46 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   36C    P8             11W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
!pip install evaluate jiwer

Installing collected packages: rapidfuzz, jiwer, evaluate
Successfully installed evaluate-0.4.6 jiwer-4.0.0 rapidfuzz-3.14.3


In [6]:
# import useful libraries
import os
import tarfile
import json
import io
import random
from pathlib import Path
from typing import Dict, List, Union
from dataclasses import dataclass
import tarfile
import shutil


import logging

import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
from tqdm.auto import tqdm
from pathlib import Path



import torch
import torchaudio
from torch.utils.data import IterableDataset, get_worker_info
from huggingface_hub import snapshot_download
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, PeftModel,PeftConfig, TaskType


import evaluate
import re
import unicodedata


import warnings
import logging

warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("torchaudio").setLevel(logging.ERROR)

print("Warning and logging filters applied (warnings suppressed).")



## 2. Download dataset

1.  download dataset
2.  Load base model
3.  Sanity check for Audio/text





In [None]:
# # Download dataset into google drive
# dataset_path = '/content/drive/MyDrive/Colab/Notebooks/Adapt_Data'

# # create path if it does not exit
# if not os.path.exists('/content/drive/MyDrive/Colab/Notebooks/Adapt_Data'):
#   os.makedirs('/content/drive/MyDrive/Colab/Notebooks/Adapt_Data')

# # download dataset
# print("Downloading dataset............")
# snapshot_download(repo_id="DigitalUmuganda/ASR_Fellowship_Challenge_Dataset",
#                   repo_type='dataset',
#                   local_dir=dataset_path
#                   )
# print(f"✅ Dataset downloaded to {dataset_path}")
# print(f"\n Dataset structure:")
# !ls -l {dataset_path}

### Load Base model
I will be using a base model trained on digital Umuganda data for ASR
https://huggingface.co/badrex/w2v-bert-2.0-kinyarwanda-asr
@misc{w2v_bert_kinyarwanda_asr,
  author = {Badr M. Abdullah},
  title = {Adapting Wav2Vec2-BERT 2.0 for Kinyarwanda ASR},
  year = {2025},
  publisher = {Hugging Face},
  url = {https://huggingface.co/badrex/w2v-bert-2.0-kinyarwanda-asr-1000h}
}

In [6]:
# Load base model  and processor
print("Loading Base Model...")
processor = Wav2Vec2BertProcessor.from_pretrained("badrex/w2v-bert-2.0-kinyarwanda-asr-1000h")
model = Wav2Vec2BertForCTC.from_pretrained(
    "badrex/w2v-bert-2.0-kinyarwanda-asr-1000h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

Loading Base Model...


preprocessor_config.json:   0%|          | 0.00/275 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json:   0%|          | 0.00/331 [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/30.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/2.32G [00:00<?, ?B/s]

In [7]:
# print model processor
print(processor)

Wav2Vec2BertProcessor:
- feature_extractor: SeamlessM4TFeatureExtractor {
  "feature_extractor_type": "SeamlessM4TFeatureExtractor",
  "feature_size": 80,
  "num_mel_bins": 80,
  "padding_side": "right",
  "padding_value": 1,
  "processor_class": "Wav2Vec2BertProcessor",
  "return_attention_mask": true,
  "sampling_rate": 16000,
  "stride": 2
}

- tokenizer: Wav2Vec2CTCTokenizer(name_or_path='badrex/w2v-bert-2.0-kinyarwanda-asr-1000h', vocab_size=30, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '[UNK]', 'pad_token': '[PAD]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	28: AddedToken("[UNK]", rstrip=True, lstrip=True, single_word=False, normalized=False, special=False),
	29: AddedToken("[PAD]", rstrip=True, lstrip=True, single_word=False, normalized=False, special=False),
	30: AddedToken("<s>", rstrip=False, lstrip=False, single_word=F

In [8]:
#print model architecture
model

Wav2Vec2BertForCTC(
  (wav2vec2_bert): Wav2Vec2BertModel(
    (feature_projection): Wav2Vec2BertFeatureProjection(
      (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=160, out_features=1024, bias=True)
      (dropout): Dropout(p=0.05, inplace=False)
    )
    (encoder): Wav2Vec2BertEncoder(
      (dropout): Dropout(p=0.05, inplace=False)
      (layers): ModuleList(
        (0-23): 24 x Wav2Vec2BertEncoderLayer(
          (ffn1_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (ffn1): Wav2Vec2BertFeedForward(
            (intermediate_dropout): Dropout(p=0.0, inplace=False)
            (intermediate_dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): SiLU()
            (output_dense): Linear(in_features=4096, out_features=1024, bias=True)
            (output_dropout): Dropout(p=0.05, inplace=False)
          )
          (self_attn_layer_norm): LayerNorm(

In [45]:
def clean_text_winner(text):
    if not text: return ""
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    allowed_chars = set("abcdefghijklmnopqrstuvwxyz '")
    text = ''.join(c for c in text if c in allowed_chars)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

class StreamingAudioDataset(IterableDataset):
    def __init__(self, data_dir, split='train', shuffle=True):
        self.data_dir = Path(data_dir)
        self.split = split
        self.shuffle = shuffle
        self.tarred_dir = self.data_dir / f"{split}_tarred" / "sharded_manifests_with_image"
        self.audio_shards_dir = self.tarred_dir / "audio_shards"
        self.debug_missing_count = 0

        # load metadata
        print(f"  Loading {split} metadata...")
        self.meta_map = {}
        manifest_files = sorted(list(self.tarred_dir.glob("manifest_*.json")))

        if not manifest_files:
            print(f" Warning: No manifests found at {self.tarred_dir}")

        for m_file in tqdm(manifest_files, desc=f"Loading {split} manifests"):
            with open(m_file, 'r') as f:
                for line in f:
                    if not line.strip(): continue
                    try:
                        data = json.loads(line)
                        self.meta_map[data['audio_filepath']] = data
                    except: continue

        self.tar_paths = sorted(list(self.audio_shards_dir.glob("audio_*.tar.xz")))
        print(f" Loaded metadata for {len(self.meta_map)} files across {len(self.tar_paths)} shards.")

    def parse_tar_file(self, tar_path):
        try:
            with tarfile.open(tar_path, mode='r|*') as tar:
                for member in tar:
                    if not member.isfile(): continue

                    # Robust key matching
                    key = None
                    possible_keys = [
                        member.name,
                        member.name.split('/')[-1],
                        "./" + member.name,
                        member.name.lstrip("./")
                    ]

                    for k in possible_keys:
                        if k in self.meta_map:
                            key = k
                            break

                    if not key:
                        if self.debug_missing_count < 5:
                            print(f"DEBUG: Key not found for {member.name}. Tried: {possible_keys}")
                        self.debug_missing_count += 1
                        continue

                    f = tar.extractfile(member)
                    if f is None: continue

                    audio_bytes = f.read()
                    audio_buffer = io.BytesIO(audio_bytes)

                    try:
                        # Determine format from key (filename)
                        file_ext = None
                        if key:
                            file_ext = key.split('.')[-1].lower()

                        if file_ext in ['webm', 'mp3', 'wav', 'flac']:
                            waveform, sr = torchaudio.load(audio_buffer, format=file_ext)
                        else:
                            waveform, sr = torchaudio.load(audio_buffer)

                        if self.split == 'train' and waveform.abs().max() < 1e-5:
                            # print(f"  Skipping silent/corrupt audio: {key}")
                            continue

                        raw_text = self.meta_map[key].get('text', '')
                        cleaned_text = clean_text_winner(raw_text)

                        if not cleaned_text and self.split == 'train':
                            continue

                        if sr != 16000:
                            waveform = torchaudio.functional.resample(waveform, sr, 16000)

                        if waveform.shape[0] > 1:
                            waveform = waveform.mean(dim=0, keepdim=True)

                        yield {
                            "audio": waveform.squeeze().numpy(),
                            "text": cleaned_text,
                            "path": key
                        }
                    except Exception as e:
                        print(f"Error processing file {member.name}: {e}")
                        continue

        except Exception as e:
            print(f"Error opening shard {tar_path}: {e}")

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            my_shards = self.tar_paths
        else:
            per_worker = int(len(self.tar_paths) / worker_info.num_workers)
            worker_id = worker_info.id
            start = worker_id * per_worker
            end = start + per_worker if worker_id < worker_info.num_workers - 1 else len(self.tar_paths)
            my_shards = self.tar_paths[start:end]

        if self.shuffle:
            random.shuffle(my_shards)

        for tar_path in my_shards:
            yield from self.parse_tar_file(tar_path)


In [46]:
data_dir = '/content/drive/MyDrive/Colab/Adapt_Data'

### Sanity Check for audio/text pairing

In [47]:
splits = {
    "Train": StreamingAudioDataset(data_dir, split='train', shuffle=True),
    "Validation": StreamingAudioDataset(data_dir, split='val', shuffle=True),
    "Test": StreamingAudioDataset(data_dir, split='test', shuffle=True)
}

print("AUDIO & DATA SANITY CHECK\n")

for split_name, dataset in splits.items():
    print(f" Fetching 1 sample from {split_name} split...")

    iterator = iter(dataset)

    try:
        sample = next(iterator)

        print(f" Filename: {sample['path']}")

        features = processor(
            audio=sample["audio"],
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features

        # Check Shape [1, Time, 160] as specified in the model's archicture
        print(f" Feature Shape: {features.shape} (Depth: {features.shape[-1]})")

        print(" Playing Audio:")
        audio_for_display = sample["audio"]
        if isinstance(audio_for_display, torch.Tensor):
            audio_for_display = audio_for_display.numpy()

        ipd.display(ipd.Audio(data=audio_for_display, rate=16000))

        print("-" * 60 + "\n")

    except StopIteration:
        print(f" Error: Could not fetch sample from {split_name} (Dataset empty?)\n")
    except Exception as e:
        print(f" Error processing {split_name}: {e}\n")

print("✅ Check complete. Verify Train/Val text matches audio")

  Loading train metadata...


Loading train manifests: 100%|██████████| 27/27 [00:01<00:00, 17.81it/s]


 Loaded metadata for 176361 files across 27 shards.
  Loading val metadata...


Loading val manifests: 100%|██████████| 1/1 [00:00<00:00, 61.86it/s]


 Loaded metadata for 1617 files across 1 shards.
  Loading test metadata...


Loading test manifests: 100%|██████████| 1/1 [00:00<00:00, 83.11it/s]


 Loaded metadata for 1569 files across 1 shards.
AUDIO & DATA SANITY CHECK

 Fetching 1 sample from Train split...
 Filename: audio_1742857016-70FFZsxmhVX2lSZ6ioXAUmNnOtl1.webm
 Feature Shape: torch.Size([1, 485, 160]) (Depth: 160)
 Playing Audio:


------------------------------------------------------------

 Fetching 1 sample from Validation split...
 Filename: audio_1742357277-lJcmtCeFHXWoYh4am0A0uiTH14M2.webm
 Feature Shape: torch.Size([1, 734, 160]) (Depth: 160)
 Playing Audio:


------------------------------------------------------------

 Fetching 1 sample from Test split...
 Filename: audio_1743193078-RHVdK8vg94QZjU5VMSNq48Xx42u1.webm
 Feature Shape: torch.Size([1, 641, 160]) (Depth: 160)
 Playing Audio:


------------------------------------------------------------

✅ Check complete. Verify Train/Val text matches audio


### Check if validation data has numbers/unknown symbols

In [12]:
# This is preprocessing step to check if val split has numbers or not
# this will give an idea if to post processing on the transcriptions of the fineturned model before evaluation

print("DETECTING NUMBERS IN VALIDATION SET...")

val_dataset = StreamingAudioDataset(data_dir, split='val', shuffle=False)
has_numbers = False
sample_count = 0

# Check first 500 samples
iterator = iter(val_dataset)
for _ in range(500):
    try:
        sample = next(iterator)
        text = sample['text']

        # Check if any character is a digit 0-9
        if any(char.isdigit() for char in text):
            print(f"⚠️ FOUND A NUMBER! Sample: {text}")
            has_numbers = True
            break
        sample_count += 1
    except StopIteration:
        break

if not has_numbers:
    print(f"✅ Checked {sample_count} validation samples. NO numbers found.")
    print("   CONCLUSION: The dataset uses spelled-out words. SAFE to remove numbers.")
else:
    print("❌ CONCLUSION: The dataset expects digits. DO NOT remove numbers.")

DETECTING NUMBERS IN VALIDATION SET...
  Loading val metadata...
 Loaded metadata for 1617 files across 1 shards.
✅ Checked 0 validation samples. NO numbers found.
   CONCLUSION: The dataset uses spelled-out words. SAFE to remove numbers.


In [13]:

train_split = StreamingAudioDataset(data_dir, split='train', shuffle=True)
test_split = StreamingAudioDataset(data_dir, split='test', shuffle=False)
val_split = StreamingAudioDataset(data_dir, split='val', shuffle=True)

  Loading train metadata...
 Loaded metadata for 176361 files across 27 shards.
  Loading test metadata...
 Loaded metadata for 1569 files across 1 shards.
  Loading val metadata...
 Loaded metadata for 1617 files across 1 shards.


### Extractig th audio files to prep dataset

In [None]:
import shutil
import os
from pathlib import Path

# 1. Define Paths
drive_data_dir = Path('/content/drive/MyDrive/Colab/Adapt_Data')
local_data_dir = Path('/content/local_data')

# 2. Create Local Directory
if not local_data_dir.exists():
    print(f"🚀 Copying data from Drive to Local Disk ({local_data_dir})...")
    print("   This might take 2-5 minutes, but will save hours of training time.")
    shutil.copytree(drive_data_dir, local_data_dir)
    print("✅ Copy complete! Reading from local disk now.")
else:
    print("✅ Data already exists locally.")

🚀 Copying data from Drive to Local Disk (/content/local_data)...
   This might take 2-5 minutes, but will save hours of training time.
✅ Copy complete! Reading from local disk now.


In [None]:
# !rm -rf /content/local_data/.cache

### Extract all audio

In [None]:
import os
import time
from pathlib import Path
from tqdm.auto import tqdm

local_data_root = Path('/content/local_data')
extract_path = Path('/content/extracted_audio')

splits_to_process = ['train', 'test', 'val']

if not extract_path.exists():
    extract_path.mkdir(parents=True, exist_ok=True)
    print(f"🚀 Starting extraction to: {extract_path}")

    start_global = time.time()

    for split in splits_to_process:
        # Dynamic path construction based on split name
        current_tar_dir = local_data_root / f"{split}_tarred" / "sharded_manifests_with_image" / "audio_shards"

        # Safety check: Does this split exist?
        if not current_tar_dir.exists():
            print(f"⚠️ Skipping '{split}': Directory not found at {current_tar_dir}")
            continue

        # Find tars
        tar_files = list(current_tar_dir.glob("*.tar*"))

        if not tar_files:
            print(f"⚠️ Skipping '{split}': No .tar files found in folder.")
            continue

        # --- THE TQDM LOOP ---
        print(f"📦 Found {len(tar_files)} shards for '{split}'. Extracting...")

        # tqdm wraps the list to show a progress bar
        for tar in tqdm(tar_files, desc=f"Extracting {split}"):
            # -xf: extract file
            # -C: output dir
            # --skip-old-files: Don't waste time if it's already there
            os.system(f"tar -xf {str(tar)} -C {str(extract_path)} --skip-old-files")

    end_global = time.time()
    print(f"\n✅ All Extraction Finished in {(end_global-start_global)/60:.2f} minutes.")
    print(f"   Files are ready at: {extract_path}")

else:
    print("✅ Extraction folder already exists. Skipping.")

🚀 Starting extraction to: /content/extracted_audio
📦 Found 27 shards for 'train'. Extracting...


Extracting train:   0%|          | 0/27 [00:00<?, ?it/s]

📦 Found 1 shards for 'test'. Extracting...


Extracting test:   0%|          | 0/1 [00:00<?, ?it/s]

📦 Found 1 shards for 'val'. Extracting...


Extracting val:   0%|          | 0/1 [00:00<?, ?it/s]


✅ All Extraction Finished in 26.43 minutes.
   Files are ready at: /content/extracted_audio


In [None]:
# remove the .cache files
!rm -rf /content/local_data/.cache

### prepare training data

In [None]:
import json
import torchaudio
import re
import unicodedata
from pathlib import Path
from tqdm.auto import tqdm
from torch.utils.data import Dataset


def clean_text_winner(text):
    if not text: return ""
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    allowed_chars = set("abcdefghijklmnopqrstuvwxyz '")
    text = ''.join(c for c in text if c in allowed_chars)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

class MapStyleAudioDataset(Dataset):
    def __init__(self, data_list):
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        audio_path, text = self.data[idx]

        # Load Audio
        waveform, sr = torchaudio.load(audio_path)

        # Resample if needed
        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)

        return {
            "audio": waveform.squeeze().numpy(),
            "text": text, # This will be "" for test split
            "input_length": waveform.shape[1]
        }


print("🔍 Indexing ALL extracted audio files...")
audio_root = Path('/content/extracted_audio')

# Map filename -> full_path
audio_file_map = {}
for f in tqdm(audio_root.glob("**/*"), desc="Indexing Global Audio"):
    if f.is_file() and f.suffix in ['.wav', '.mp3', '.flac', '.webm']:
        audio_file_map[f.name] = f

print(f"✅ Indexed {len(audio_file_map)} files globally.")



def create_dataset_for_split(split_name, audio_map):
    print(f"\n🔗 Processing split: '{split_name}'...")

    manifest_dir = Path(f'/content/local_data/{split_name}_tarred/sharded_manifests_with_image')

    if not manifest_dir.exists():
        print(f"⚠️ Warning: Directory not found for {split_name}: {manifest_dir}")
        return None

    manifest_files = sorted(list(manifest_dir.glob("manifest_*.json")))

    if not manifest_files:
        print(f"⚠️ Warning: No JSON manifests found in {manifest_dir}")
        return None

    valid_samples = []
    skipped_long = 0
    skipped_missing = 0

    for m_file in tqdm(manifest_files, desc=f"Parsing {split_name}"):
        with open(m_file, 'r') as f:
            for line in f:
                if not line.strip(): continue
                try:
                    entry = json.loads(line)
                    filename = Path(entry['audio_filepath']).name

                    if filename in audio_map:
                        full_path = audio_map[filename]


                        if full_path.stat().st_size > 1500000:
                            skipped_long += 1
                            continue

                        raw_text = entry.get('text', '')
                        clean = clean_text_winner(raw_text)

                        if split_name == 'test':
                            # For TEST: We accept the audio even if text is empty
                            valid_samples.append((str(full_path), clean))
                        else:
                            # For TRAIN/VAL: We require valid text for loss calculation
                            if clean:
                                valid_samples.append((str(full_path), clean))

                    else:
                        skipped_missing += 1
                except: continue

    print(f"✅ {split_name.upper()} Dataset Ready: {len(valid_samples)} samples")
    if skipped_long > 0: print(f"   Skipped {skipped_long} long files")
    if skipped_missing > 0: print(f"   Skipped {skipped_missing} missing audio files")

    return MapStyleAudioDataset(valid_samples)

# --- EXECUTION LOOP ---
splits = ['train', 'val', 'test']
all_datasets = {}

for split in splits:
    ds = create_dataset_for_split(split, audio_file_map)
    if ds is not None:
        all_datasets[split] = ds

# --- EXPOSE VARIABLES ---
train_dataset = all_datasets.get('train')
val_dataset   = all_datasets.get('val')
test_dataset  = all_datasets.get('test')

print("\n🎉 All datasets initialized!")
if train_dataset: print(f"Train size: {len(train_dataset)}")
if val_dataset:   print(f"Val size:   {len(val_dataset)}")
if test_dataset:  print(f"Test size:  {len(test_dataset)}")

🔍 Indexing ALL extracted audio files...


Indexing Global Audio: 0it [00:00, ?it/s]

✅ Indexed 179314 files globally.

🔗 Processing split: 'train'...


Parsing train:   0%|          | 0/27 [00:00<?, ?it/s]

✅ TRAIN Dataset Ready: 176129 samples

🔗 Processing split: 'val'...


Parsing val:   0%|          | 0/1 [00:00<?, ?it/s]

✅ VAL Dataset Ready: 1617 samples

🔗 Processing split: 'test'...


Parsing test:   0%|          | 0/1 [00:00<?, ?it/s]

✅ TEST Dataset Ready: 1569 samples

🎉 All datasets initialized!
Train size: 176129
Val size:   1617
Test size:  1569


In [None]:
train_dataset[0]

{'audio': array([ 3.0865804e-15,  2.2852925e-13,  6.8836234e-13, ...,
        -3.7584133e-02, -2.8015159e-02, -1.7984023e-02], dtype=float32),
 'text': 'abagabo babiri aho bari gusindagiza umuntu uri hagati aho yagize ikibazo atari kubasha kugenda neza kugira ngo amugeze ahabugenewe bityo amenye',
 'input_length': 213120}

### Verify the loaded data

In [None]:
# play some audio samples to see the fi the text corresponse
import random
from IPython.display import Audio, display

def verify_dataset_samples(dataset, split_name, num_samples=3):
    """
    Picks random samples from the dataset, prints the text, and displays the audio player.
    """
    if dataset is None or len(dataset) == 0:
        print(f"⚠️ Skipping {split_name}: Dataset is empty or None.")
        return

    print(f"\n{'='*40}")
    print(f"🎧 CHECKING: {split_name.upper()} SPLIT (Total: {len(dataset)})")
    print(f"{'='*40}")

    # Pick unique random indices
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for idx in indices:
        sample = dataset[idx]

        # Extract data
        audio_data = sample['audio']
        text = sample['text']

        print(f"\n📌 Index: {idx}")

        # Handle Test split (Empty text)
        if not text:
            print(f"📝 Text: [No Text Available - Test Split]")
        else:
            print(f"📝 Text: \"{text}\"")

        # Display Audio (Rate is 16000 as per your dataset logic)
        display(Audio(audio_data, rate=16000))

# --- RUN FOR ALL SPLITS ---
# We assume train_dataset, val_dataset, and test_dataset are already defined
datasets_to_check = {
    "Train": train_dataset,
    "Validation": val_dataset,
    "Test": test_dataset
}

for name, ds in datasets_to_check.items():
    verify_dataset_samples(ds, name, num_samples=1)


### Check if data is compatible with model

In [None]:
# load 3 sample and try to do a forward and backward pass with the data
print("Loading Base Model")
model_base = Wav2Vec2BertForCTC.from_pretrained(
    "badrex/w2v-bert-2.0-kinyarwanda-asr-1000h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
model_base.to(device)
model_base.train()

print("\n Fetching 2 samples...")
iterator = iter(train_split)
batch_samples = []
for _ in range(2):
    try:
        batch_samples.append(next(iterator))
    except StopIteration:
        break

data_collator = DataCollatorCTCWithPadding(processor=processor)
batch = data_collator(batch_samples)
batch = {k: v.to(device) for k, v in batch.items()}

print(f"   Input Shape: {batch['input_features'].shape} (Expect [2, Time, 160])")

print("\nForward Pass...")
try:
    outputs = model_base(**batch)
    loss = outputs.loss

    print(f"   ✅ Loss calculated: {loss.item():.4f}")
except Exception as e:
    print(f"   ❌ Forward Pass Failed: {e}")
    raise e

print("\n Backward Pass...")
try:
    loss.backward()
    print("   ✅ Gradients calculated successfully.")
    print("   BASE MODEL IS COMPATIBLE.")
except Exception as e:
    print(f"   ❌ Backward Pass Failed: {e}")

# Cleanup
del model_base
del batch
torch.cuda.empty_cache()

## 3. Modeling (Base model + Adapter Module)


### Model setup

In [14]:
import torch
import types
from transformers import Wav2Vec2BertForCTC, Wav2Vec2BertProcessor, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

print(" Loading Model & Processor...")
model_id = "badrex/w2v-bert-2.0-kinyarwanda-asr-1000h"
processor = Wav2Vec2BertProcessor.from_pretrained(model_id)

model = Wav2Vec2BertForCTC.from_pretrained(
    model_id,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)


def make_inputs_require_grad(module, input, output):
    if isinstance(output, tuple):
        output[0].requires_grad_(True)
    else:
        output.requires_grad_(True)

def _enable_input_require_grads(self):
    self.wav2vec2_bert.feature_projection.register_forward_hook(make_inputs_require_grad)

model.enable_input_require_grads = types.MethodType(_enable_input_require_grads, model)
model.gradient_checkpointing_enable()


peft_config = LoraConfig(
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[
        "linear_q", "linear_k", "linear_v", "linear_out",
        "intermediate_dense", "output_dense"
    ]
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

 Loading Model & Processor...
trainable params: 11,010,048 || all params: 591,535,968 || trainable%: 1.8613


In [15]:
model

PeftModel(
  (base_model): LoraModel(
    (model): Wav2Vec2BertForCTC(
      (wav2vec2_bert): Wav2Vec2BertModel(
        (feature_projection): Wav2Vec2BertFeatureProjection(
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
          (projection): Linear(in_features=160, out_features=1024, bias=True)
          (dropout): Dropout(p=0.05, inplace=False)
        )
        (encoder): Wav2Vec2BertEncoder(
          (dropout): Dropout(p=0.05, inplace=False)
          (layers): ModuleList(
            (0-23): 24 x Wav2Vec2BertEncoderLayer(
              (ffn1_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (ffn1): Wav2Vec2BertFeedForward(
                (intermediate_dropout): Dropout(p=0.0, inplace=False)
                (intermediate_dense): lora.Linear(
                  (base_layer): Linear(in_features=1024, out_features=4096, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1

In [16]:

frozen_count = 0
trainable_count = 0
lora_detected = False
base_frozen = True

for name, param in model.named_parameters():

    if "lora_" in name:
        lora_detected = True
        if not param.requires_grad:
            print(f"❌ ALARM: LoRA layer {name} is FROZEN!")

    elif "lm_head" in name:
        if param.requires_grad:
            print(f"⚠️ NOTE: LM Head {name} is TRAINABLE.")
        else:
            frozen_count += 1

    else:
        if param.requires_grad:
            print(f"❌ ALARM: Base layer {name} is TRAINABLE!")
            base_frozen = False
        else:
            frozen_count += 1

    if param.requires_grad:
        trainable_count += 1

    if "layers.0.self_attn.linear_q" in name:
        status = "🟢 TRAINABLE" if param.requires_grad else "❄️ FROZEN"
        print(f"   Layer 0 (Query): {name[-60:]} -> {status}")

    if "layers.10.ffn1" in name and "output_dense" in name:
        status = "🟢 TRAINABLE" if param.requires_grad else "❄️ FROZEN"
        print(f"   Layer 10 (FFN):  {name[-60:]} -> {status}")

print("-" * 40)
print(f"   Total Frozen Parameters:    {frozen_count:,}")
print(f"   Total Trainable Parameters: {trainable_count:,}")

if base_frozen and lora_detected:
    print("\n✅ SUCCESS: Base model is strictly frozen. LoRA adapters are trainable.")
else:
    print("\n❌ FAILURE: Critical configuration error. Do not train.")

model.print_trainable_parameters()

   Layer 0 (Query): 2_bert.encoder.layers.0.self_attn.linear_q.base_layer.weight -> ❄️ FROZEN
   Layer 0 (Query): ec2_bert.encoder.layers.0.self_attn.linear_q.base_layer.bias -> ❄️ FROZEN
   Layer 0 (Query): rt.encoder.layers.0.self_attn.linear_q.lora_A.default.weight -> 🟢 TRAINABLE
   Layer 0 (Query): rt.encoder.layers.0.self_attn.linear_q.lora_B.default.weight -> 🟢 TRAINABLE
   Layer 10 (FFN):  2_bert.encoder.layers.10.ffn1.output_dense.base_layer.weight -> ❄️ FROZEN
   Layer 10 (FFN):  ec2_bert.encoder.layers.10.ffn1.output_dense.base_layer.bias -> ❄️ FROZEN
   Layer 10 (FFN):  rt.encoder.layers.10.ffn1.output_dense.lora_A.default.weight -> 🟢 TRAINABLE
   Layer 10 (FFN):  rt.encoder.layers.10.ffn1.output_dense.lora_B.default.weight -> 🟢 TRAINABLE
----------------------------------------
   Total Frozen Parameters:    775
   Total Trainable Parameters: 384

✅ SUCCESS: Base model is strictly frozen. LoRA adapters are trainable.
trainable params: 11,010,048 || all params: 591,535,968 |

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2BertProcessor
    def __call__(self, features):
        audio = [f["audio"] for f in features]
        text = [f["text"] for f in features]

        # Pad Audio
        batch = self.processor(audio=audio, sampling_rate=16000, return_tensors="pt", padding=True)

        # Pad Labels
        labels_batch = self.processor.tokenizer(text, return_tensors="pt", padding=True)

        # Mask padding with -100
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor)

In [None]:
import evaluate
import numpy as np

# Load both metrics
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Replace -100 with pad_token_id
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    # Compute both
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}

### Training loop

In [None]:

def make_inputs_require_grad(module, input, output):
    if isinstance(output, tuple):
        output[0].requires_grad_(True)
    else:
        output.requires_grad_(True)

model.wav2vec2_bert.feature_projection.register_forward_hook(make_inputs_require_grad)

model.gradient_checkpointing_enable()

BATCH_SIZE = 4
GRAD_ACC = 4
MAX_STEPS = 4000

training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_3",
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACC,

    # checkpointing
    gradient_checkpointing=False,
    fp16=True,

    # lr
    learning_rate=5e-5,

    # scheduler
    warmup_steps=100,
    max_steps=MAX_STEPS,

    # logs
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    logging_steps=50,

    dataloader_num_workers=4,
    report_to="none",
    remove_unused_columns=False,

    group_by_length=False
)

trainer = Trainer(
    model=model,
    data_collator=DataCollatorCTCWithPadding(processor=processor),
    args=training_args,
    train_dataset=train_split,
    eval_dataset=val_split,
)

trainer.train()

  Skipping silent/corrupt audio: audio_1744998237-QH9XmV1mBxY3FgxjsAvWcJVWQzV2.webm


Step,Training Loss,Validation Loss
500,0.0548,0.061358
1000,0.0554,0.061418
1500,0.0779,0.060509
2000,0.0549,0.060306
2500,0.0597,0.060431


  Skipping silent/corrupt audio: audio_1743784255-5dAr2souD5gNWyQ0mzPFf13Kvd92.webm


KeyboardInterrupt: 

## 4. Generate Transcripts with Greedy decoding


1.   Evaluate the base model on Dev_test(Val)- report WER/CER
2.   Generate transcripts.txt using Test split

### Evaluate Base and Fineturned model

In [18]:
test_dataset = StreamingAudioDataset(data_dir, split='test', shuffle=False)

adapter_path = "/content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_3/checkpoint-2500"

base_model_id = "badrex/w2v-bert-2.0-kinyarwanda-asr-1000h"

def clean_text_winner(text):
    if not text: return ""

    text = text.lower()

    text = unicodedata.normalize("NFD", text)

    allowed_chars = set("abcdefghijklmnopqrstuvwxyz '")
    text = ''.join(c for c in text if c in allowed_chars)

    text = re.sub(r'\s+', ' ', text).strip()
    return text

def run_inference_and_save(model, processor, dataset, output_filename, desc="Evaluating"):
    print(f"\n {desc}...")
    model.eval()
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    predictions = []

    with open(output_filename, "w", encoding="utf-8") as f:
        for i, sample in tqdm(enumerate(dataset), desc=desc):
            try:
                input_features = processor(
                    audio=sample["audio"],
                    sampling_rate=16000,
                    return_tensors="pt"
                ).input_features.to(model.device)

                with torch.no_grad():
                    logits = model(input_features).logits

                pred_ids = torch.argmax(logits, dim=-1)
                pred_text = processor.batch_decode(pred_ids)[0]

                cleaned_pred = clean_text_winner(pred_text)
                predictions.append(cleaned_pred)

                # Extract Filename Only (Format Requirement)
                # /content/drive/.../audio_123.webm -> audio_123.webm
                full_path = sample['path']
                filename_only = os.path.basename(full_path)

                # Write to File: "ID <space> Transcript"
                f.write(f"{filename_only} {cleaned_pred}\n")

            except Exception as e:
                print(f"Error on sample {i}: {e}")
                filename_only = os.path.basename(sample['path']) if 'path' in sample else f"unknown_{i}"
                f.write(f"{filename_only} \n")
                predictions.append("")
                continue

    print(f"✅ Saved transcriptions to {output_filename}")
    return predictions

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")


  Loading test metadata...
 Loaded metadata for 1569 files across 1 shards.


### Compare Model performance. Generate all files

In [None]:
print("\n PREPARING VALIDATION DATA FOR REPORT...")
val_dataset = StreamingAudioDataset(data_dir, split='val', shuffle=False)
ground_truths = []

# reference text
for sample in tqdm(val_dataset, desc="Extracting Labels"):
    ground_truths.append(clean_text_winner(sample['text']))

# evaluate base model
print("\n Loading Base Model...")
processor = Wav2Vec2BertProcessor.from_pretrained(base_model_id)
model_base = Wav2Vec2BertForCTC.from_pretrained(base_model_id).to("cuda")

base_preds = run_inference_and_save(model_base, processor, val_dataset, "base_val_transcriptions.txt", "Base Inference (Val)")

base_wer = wer_metric.compute(predictions=base_preds, references=ground_truths)
base_cer = cer_metric.compute(predictions=base_preds, references=ground_truths)

# evaluate adaper
print("\n Loading Adapter Model...")
model_ft = Wav2Vec2BertForCTC.from_pretrained(base_model_id)
model_ft = PeftModel.from_pretrained(model_ft, adapter_path).to("cuda")

ft_preds = run_inference_and_save(model_ft, processor, val_dataset, "finetuned_val_transcriptions.txt", "Adapter Inference (Val)")

ft_wer = wer_metric.compute(predictions=ft_preds, references=ground_truths)
ft_cer = cer_metric.compute(predictions=ft_preds, references=ground_truths)

# print report
print("\n" + "="*65)
print(f" FINAL REPORT RESULTS (Validation Set)")
print("="*65)
print(f"{'METRIC':<10} | {'BASE MODEL':<15} | {'FINE-TUNED':<15} | {'IMPROVEMENT':<15}")
print("-" * 65)
print(f"{'WER':<10} | {base_wer:.6%}          | {ft_wer:.6%}          | {base_wer - ft_wer:.6%} (Abs)")
print(f"{'CER':<10} | {base_cer:.6%}          | {ft_cer:.6%}          | {base_cer - ft_cer:.6%} (Abs)")
print("="*65)


print("\n Generating FINAL Test files for submission...")

del model_ft
model_base = Wav2Vec2BertForCTC.from_pretrained(base_model_id).to("cuda")
run_inference_and_save(model_base, processor, test_dataset, "base_transcriptions.txt", "Testing Base Model")

model_ft = Wav2Vec2BertForCTC.from_pretrained(base_model_id)
model_ft = PeftModel.from_pretrained(model_ft, adapter_path).to("cuda")
run_inference_and_save(model_ft, processor, test_dataset, "finetuned_transcriptions.txt", "Testing Adapter Model")

print("\n✅ All files generated successfully! Check your folder.")


 PREPARING VALIDATION DATA FOR REPORT...
  Loading val metadata...
 Loaded metadata for 1617 files across 1 shards.


Extracting Labels: 0it [00:00, ?it/s]


 Loading Base Model...

 Base Inference (Val)...


Base Inference (Val): 0it [00:00, ?it/s]

✅ Saved transcriptions to base_val_transcriptions.txt

 Loading Adapter Model...

 Adapter Inference (Val)...


Adapter Inference (Val): 0it [00:00, ?it/s]

✅ Saved transcriptions to finetuned_val_transcriptions.txt

 FINAL REPORT RESULTS (Validation Set)
METRIC     | BASE MODEL      | FINE-TUNED      | IMPROVEMENT    
-----------------------------------------------------------------
WER        | 6.509828%          | 6.451538%          | 0.058290% (Abs)
CER        | 1.408015%          | 1.402160%          | 0.005856% (Abs)

 Generating FINAL Test files for submission...

 Testing Base Model...


Testing Base Model: 0it [00:00, ?it/s]

✅ Saved transcriptions to base_transcriptions.txt

 Testing Adapter Model...


Testing Adapter Model: 0it [00:00, ?it/s]

✅ Saved transcriptions to finetuned_transcriptions.txt

✅ All files generated successfully! Check your folder.


## 5. Generate Transcripts with KenLM decoding


1.  Install KenLM
2.  Train a 2-gram Model
3.  Evaluate Base Model/ Finetuned model + KenLM
4.  Generate submition files





### Install KenLM

In [None]:
!pip install pyctcdecode kenlm

In [1]:

print(" Installing build tools...")
!apt-get install -y cmake build-essential libboost-all-dev zlib1g-dev libbz2-dev liblzma-dev > /dev/null

print(" Downloading KenLM...")
!wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz

print(" Compiling KenLM (This takes ~1-2 mins)...")
%cd kenlm
!mkdir -p build
%cd build
!cmake .. > /dev/null
!make -j 4 > /dev/null
print("✅ KenLM compiled successfully!")

# Switch back to root
%cd /content

 Installing build tools...
Extracting templates from packages: 100%
 Downloading KenLM...
--2025-11-21 21:24:11--  https://kheafield.com/code/kenlm.tar.gz
Resolving kheafield.com (kheafield.com)... 129.80.89.152, 2603:c020:4009:8710:ca:11:17:0
Connecting to kheafield.com (kheafield.com)|129.80.89.152|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 491888 (480K) [application/octet-stream]
Saving to: ‘STDOUT’


2025-11-21 21:24:13 (680 KB/s) - written to stdout [491888/491888]

 Compiling KenLM (This takes ~1-2 mins)...
/content/kenlm
/content/kenlm/build
  Compatibility with CMake < 3.10 will be removed from a future version of
  CMake.

  Update the VERSION argument <min> value.  Or, use the <min>...<max> syntax
  to tell CMake that the project requires at least <min> but has been updated
  to work with policies introduced by <max> or earlier.

[0m
  Policy CMP0167 is not set: The FindBoost module is removed.  Run "cmake
  --help-policy CMP0167" for policy de

### Train a 2-gram Model

In [7]:
from pathlib import Path

# re-run the install library cell at top of the notebook
data_dir = Path('/content/drive/MyDrive/Colab/Adapt_Data')

manifest_dir = data_dir / "train_tarred" / "sharded_manifests_with_image"
output_corpus = "/content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/kinyarwanda_corpus.txt"

def clean_text_winner(text):
    if not text: return ""
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    allowed_chars = set("abcdefghijklmnopqrstuvwxyz '")
    text = ''.join(c for c in text if c in allowed_chars)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

print(f"  Reading manifests from: {manifest_dir}")

# Ensure directory exists
if not manifest_dir.exists():
    print(f"❌ ERROR: Directory not found: {manifest_dir}")
    print("   Please check your Google Drive paths.")
else:
    manifest_files = sorted(list(manifest_dir.glob("manifest_*.json")))

    if not manifest_files:
        print("❌ ERROR: No 'manifest_*.json' files found in the directory.")
    else:
        line_count = 0
        with open(output_corpus, "w", encoding="utf-8") as out_f:
            for m_file in tqdm(manifest_files, desc="Extracting text"):
                with open(m_file, "r") as in_f:
                    for line in in_f:
                        if not line.strip(): continue
                        try:
                            data = json.loads(line)
                            if "text" in data:
                                cleaned = clean_text_winner(data["text"])
                                if cleaned:
                                    out_f.write(cleaned + "\n")
                                    line_count += 1
                        except json.JSONDecodeError:
                            continue

        print(f"✅ Success! Extracted {line_count} sentences to {output_corpus}")

        # -o 2  : Order 2 (2-gram)

        print("\n Training 2-gram Language Model...")

        !/content/kenlm/build/bin/lmplz -o 2 < {output_corpus} > 2gram_correct.arpa

        print("\n✅ Training Complete!")


        import shutil

        dest_folder = data_dir
        dest_folder.mkdir(exist_ok=True)
        dest_path = dest_folder / "2gram_correct.arpa"

        print(f"  Saving to: {dest_path}")
        shutil.copy("2gram_correct.arpa", dest_path)

  Reading manifests from: /content/drive/MyDrive/Colab/Adapt_Data/train_tarred/sharded_manifests_with_image


Extracting text:   0%|          | 0/27 [00:00<?, ?it/s]

✅ Success! Extracted 176129 sentences to /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/kinyarwanda_corpus.txt

 Training 2-gram Language Model...
=== 1/5 Counting and sorting n-grams ===
Reading /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/kinyarwanda_corpus.txt
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Unigram tokens 4997407 types 126024
=== 2/5 Calculating and sorting adjusted counts ===
Chain sizes: 1:1512288 2:45489160192
Statistics:
1 126024 D1=0.688896 D2=1.04489 D3+=1.33102
2 1063653 D1=0.760927 D2=1.04304 D3+=1.29685
Memory estimate for binary LM:
type       kB
probing 21896 assuming -p 1.5
probing 22389 assuming -r models -p 1.5
trie     9186 without quantization
trie     6200 assuming -q 8 -b 8 quantization 
trie     9186 assuming -a 22 array pointer compression
trie     6200 assuming -a 22 -q 8 -b 8 a

### Evaluate Base Model/ Finetuned model + KenLM

In [30]:
!pip install torchcodec

# !pip install pyctcdecode kenlm https://github.com/kpu/kenlm/archive/master.zip > /dev/null
from pyctcdecode import build_ctcdecoder

Collecting torchcodec
  Downloading torchcodec-0.8.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.7 kB)
Downloading torchcodec-0.8.1-cp312-cp312-manylinux_2_28_x86_64.whl (2.0 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m81.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchcodec
Successfully installed torchcodec-0.8.1


In [32]:
import os
import io
import json
import tarfile
import unicodedata
import re
import random
import torch
import torchaudio
import evaluate
from pathlib import Path
from torch.utils.data import IterableDataset, get_worker_info
from tqdm import tqdm
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
from peft import PeftModel
from pyctcdecode import build_ctcdecoder

# Configuration
DATA_DIR = Path('/content/drive/MyDrive/Colab/Adapt_Data')
BASE_MODEL_ID = "badrex/w2v-bert-2.0-kinyarwanda-asr-1000h"
ADAPTER_PATH = "/content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_3/checkpoint-2500"
KENLM_PATH = "/content/drive/MyDrive/Colab/Adapt_Data/2gram_correct.arpa"

def evaluate_with_kenlm(dataset, base_model_id, adapter_path, kenlm_path):
    print(f"  Starting Evaluation with KenLM Decoding")

    if not os.path.exists(kenlm_path):
        # Check if it's a colab path that might not exist locally, but warn user
        print(f"WARNING: KenLM file not found at {kenlm_path}. Please check path.")

    print("   Loading Processor and Base Model...")
    try:
        processor = Wav2Vec2BertProcessor.from_pretrained(base_model_id)
        model = Wav2Vec2BertForCTC.from_pretrained(base_model_id)
    except Exception as e:
        print(f"Error loading base model: {e}")
        return None

    # Load Adapter & Merge
    print(f"   Loading Adapter from {adapter_path}...")
    try:
        model = PeftModel.from_pretrained(model, adapter_path)
        model = model.merge_and_unload()
    except Exception as e:
        print(f"Error loading adapter: {e}")
        return None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    print("   Building KenLM Decoder...")
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_vocab_list = [k for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])]

    try:
        decoder = build_ctcdecoder(
            labels=sorted_vocab_list,
            kenlm_model_path=kenlm_path,
            alpha=0.5,
            beta=1.0
        )
    except Exception as e:
        print(f"Error building decoder: {e}")
        return None

    wer_metric = evaluate.load("wer")
    predictions = []
    references = []

    print("   Processing audio stream and decoding...")
    count = 0
    total_samples = len(dataset.meta_map) if hasattr(dataset, 'meta_map') else None
    for i, sample in tqdm(enumerate(dataset), desc="Decoding", total=total_samples):
        count += 1
        try:
            audio_input = sample["audio"]
            input_features = processor(
                audio=audio_input,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features.to(device)

            with torch.no_grad():
                logits = model(input_features).logits

            logits_np = logits.cpu().numpy()[0]
            transcription = decoder.decode(logits_np, beam_width=100)

            clean_pred = clean_text_winner(transcription)
            clean_ref = clean_text_winner(sample["text"])

            if clean_ref:
                predictions.append(clean_pred)
                references.append(clean_ref)
            else:
                print(f"Debug: Empty reference for sample {i} path {sample.get('path')}")

            if i < 2:
                print(f"\n--- Sample {i} ---")
                print(f"Ref:  {clean_ref}")
                print(f"Pred: {clean_pred}")

        except Exception as e:
            print(f"Skipping sample {i} due to error: {e}")
            continue

    if count == 0:
        print("❌ Dataset yielded 0 samples. Check data paths and metadata matching.")

    if len(references) > 0:
        final_wer = wer_metric.compute(predictions=predictions, references=references)
        print("\n" + "="*40)
        print(f" RESULTS WITH KENLM")
        print("="*40)
        print(f"WER: {final_wer:.6%}")
        print("="*40)
        return final_wer
    else:
        print("❌ No valid references found to compute WER.")
        return None

def evaluate_base_only(dataset, base_model_id, kenlm_path):
    print(f"  Starting BASE MODEL Evaluation (No Adapter) with KenLM")

    if not os.path.exists(kenlm_path):
        print(f"WARNING: KenLM file not found at {kenlm_path}. Please check path.")

    print("   Loading Processor and Base Model...")
    try:
        processor = Wav2Vec2BertProcessor.from_pretrained(base_model_id)
        model = Wav2Vec2BertForCTC.from_pretrained(base_model_id)
    except Exception as e:
        print(f"Error loading base model: {e}")
        return None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    print("   Building KenLM Decoder...")
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_vocab_list = [k for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])]

    try:
        decoder = build_ctcdecoder(
            labels=sorted_vocab_list,
            kenlm_model_path=kenlm_path,
            alpha=0.5,
            beta=1.0
        )
    except Exception as e:
        print(f"Error building decoder: {e}")
        return None

    wer_metric = evaluate.load("wer")
    predictions = []
    references = []

    print("   Processing audio stream (Base Model)...")
    total_samples = len(dataset.meta_map) if hasattr(dataset, 'meta_map') else None
    for i, sample in tqdm(enumerate(dataset), desc="Decoding Base", total=total_samples):
        try:
            audio_input = sample["audio"]
            input_features = processor(
                audio=audio_input,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features.to(device)

            with torch.no_grad():
                logits = model(input_features).logits

            logits_np = logits.cpu().numpy()[0]
            transcription = decoder.decode(logits_np, beam_width=100)

            clean_pred = clean_text_winner(transcription)
            clean_ref = clean_text_winner(sample["text"])

            if clean_ref:
                predictions.append(clean_pred)
                references.append(clean_ref)

        except Exception as e:
            print(f"Skipping sample {i}: {e}")
            continue

    if len(references) > 0:
        final_wer = wer_metric.compute(predictions=predictions, references=references)
        print("\n" + "="*40)
        print(f" RESULTS: BASE MODEL ONLY")
        print("="*40)
        print(f"WER: {final_wer:.6%}")
        print("="*40)
        return final_wer
    else:
        return None


print("--- Loading Validation Data ---")
val_dataset_base = StreamingAudioDataset(DATA_DIR, split='val', shuffle=False)

base_wer = evaluate_base_only(val_dataset_base, BASE_MODEL_ID, KENLM_PATH)
val_dataset_adapter = StreamingAudioDataset(DATA_DIR, split='val', shuffle=False)
adapter_wer = evaluate_with_kenlm(val_dataset_adapter, BASE_MODEL_ID, ADAPTER_PATH, KENLM_PATH)

print("\n" + "#"*30)
print("FINAL COMPARISON")
print("#"*30)

if base_wer is not None:
    print(f"Base Model WER:    {base_wer:.6%}")
else:
    print("Base Model WER:    N/A (Evaluation failed or no valid references)")

if adapter_wer is not None:
    print(f"Adapter Model WER: {adapter_wer:.6%}")
else:
    print("Adapter Model WER: N/A (Evaluation failed or no valid references)")

if base_wer is not None and adapter_wer is not None:
    improvement = base_wer - adapter_wer
    print(f"⚠️ Improvement:    {improvement:.6%} (Higher is better)")
else:
    print("⚠️ Improvement:    N/A (Cannot calculate improvement due to failed evaluations)")

--- Loading Validation Data ---
  Loading val metadata...


Loading val manifests: 100%|██████████| 1/1 [00:00<00:00, 66.07it/s]

 Loaded metadata for 1617 files across 1 shards.
  Starting BASE MODEL Evaluation (No Adapter) with KenLM
   Loading Processor and Base Model...





   Building KenLM Decoder...
   Processing audio stream (Base Model)...


Decoding Base: 100%|██████████| 1617/1617 [15:42<00:00,  1.71it/s]



 RESULTS: BASE MODEL ONLY
WER: 6.029518%
  Loading val metadata...


Loading val manifests: 100%|██████████| 1/1 [00:00<00:00, 68.75it/s]


 Loaded metadata for 1617 files across 1 shards.
  Starting Evaluation with KenLM Decoding
   Loading Processor and Base Model...
   Loading Adapter from /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_3/checkpoint-2500...




   Building KenLM Decoder...
   Processing audio stream and decoding...


Decoding:   0%|          | 1/1617 [00:00<12:08,  2.22it/s]


--- Sample 0 ---
Ref:  umwana ari gukaraba mu ntoki gukaraba mu ntoki amazi meza n'isabune nibwo uba ukarabye mu ntoki neza kugira isuku ni isoko y'ubuzima
Pred: umwana ari gukaraba mu ntoki gukaraba mu ntoki amazi meza n'isabune nibwo uba ukarabye mu ntoki neza kugira isuku ni isoko y'ubuzima


Decoding:   0%|          | 2/1617 [00:00<12:10,  2.21it/s]


--- Sample 1 ---
Ref:  umuganga yambaye ga ari gutera urushinge ku rutugu rw'umugabo wicaye imbere ye arimukingira ibiza
Pred: umuganga yambaye ga ari gutera urushinge ku rutugu rw'umugabo wicaye imbere ye ari kumukingira ibiza


Decoding: 100%|██████████| 1617/1617 [15:43<00:00,  1.71it/s]


 RESULTS WITH KENLM
WER: 5.985218%

##############################
FINAL COMPARISON
##############################
Base Model WER:    6.029518%
Adapter Model WER: 5.985218%
⚠️ Improvement:    0.044300% (Lower is better)





### Generate submission files

In [36]:
# !pip install pyctcdecode https://github.com/kpu/kenlm/archive/master.zip > /dev/null

# Output Filenames
base_output = "/content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/base_transcriptions_kenlm.txt"
adapted_output = "/content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/finetuned_transcriptions_kenlm.txt"

def get_kenlm_decoder(processor, kenlm_path):
    print(f"  Building KenLM Decoder...")

    # Extract vocab and sort strictly by ID
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_vocab_list = [k for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])]

    decoder = build_ctcdecoder(
        labels=sorted_vocab_list,
        kenlm_model_path=kenlm_path,
        alpha=0.5,
        beta=1.0   )
    return decoder

def run_inference_with_kenlm(model, processor, decoder, dataset, output_filename):
    print(f"\n  Generating transcriptions to: {output_filename}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # Get total samples for progress bar
    total_samples = len(dataset.meta_map) if hasattr(dataset, 'meta_map') else None

    with open(output_filename, "w", encoding="utf-8") as f:
        for i, sample in tqdm(enumerate(dataset), desc="Decoding", total=total_samples):
            try:
                input_features = processor(
                    audio=sample["audio"],
                    sampling_rate=16000,
                    return_tensors="pt"
                ).input_features.to(device)

                with torch.no_grad():
                    logits = model(input_features).logits

                logits_np = logits.cpu().numpy()[0]
                pred_text = decoder.decode(logits_np, beam_width=100)

                cleaned_pred = clean_text_winner(pred_text)

                filename_only = os.path.basename(sample['path'])
                f.write(f"{filename_only} {cleaned_pred}\n")

            except Exception as e:
                print(f"Error on sample {i}: {e}")
                f.write(f"error_{i} \n")
                continue

    print(f"✅ Saved: {output_filename}")

print("--- Loading Test Data ---")
test_dataset = StreamingAudioDataset(data_dir, split='test', shuffle=False)

print("--- Setting up Processor & KenLM ---")
processor = Wav2Vec2BertProcessor.from_pretrained(base_model_id)
kenlm_decoder = get_kenlm_decoder(processor, kenlm_path)

print("\n🔵 Evaluating BASE MODEL + KenLM...")
model_base = Wav2Vec2BertForCTC.from_pretrained(base_model_id)
run_inference_with_kenlm(model_base, processor, kenlm_decoder, test_dataset, base_output)
print(f"\n Done! Transcriptions saved to {base_output}")

del model_base
torch.cuda.empty_cache()

print("\n🟢 Evaluating ADAPTED MODEL + KenLM...")
model_adapted = Wav2Vec2BertForCTC.from_pretrained(base_model_id)
model_adapted = PeftModel.from_pretrained(model_adapted, adapter_path)

# Re-init dataset to restart iterator
test_dataset_2 = StreamingAudioDataset(data_dir, split='test', shuffle=False)
run_inference_with_kenlm(model_adapted, processor, kenlm_decoder, test_dataset_2, adapted_output)
print(f"\n Done! Transcriptions saved to {adapted_output}")

--- Loading Test Data ---
  Loading test metadata...


Loading test manifests: 100%|██████████| 1/1 [00:00<00:00, 78.60it/s]

 Loaded metadata for 1569 files across 1 shards.
--- Setting up Processor & KenLM ---





  Building KenLM Decoder...

🔵 Evaluating BASE MODEL + KenLM...

  Generating transcriptions to: /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/base_transcriptions_kenlm.txt...


Decoding: 100%|██████████| 1569/1569 [13:24<00:00,  1.95it/s]


✅ Saved: /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/base_transcriptions_kenlm.txt

 Done! Transcriptions saved to /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/base_transcriptions_kenlm.txt

🟢 Evaluating ADAPTED MODEL + KenLM...
  Loading test metadata...


Loading test manifests: 100%|██████████| 1/1 [00:00<00:00, 76.16it/s]

 Loaded metadata for 1569 files across 1 shards.

  Generating transcriptions to: /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/finetuned_transcriptions_kenlm.txt...



Decoding: 100%|██████████| 1569/1569 [14:00<00:00,  1.87it/s]

✅ Saved: /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/finetuned_transcriptions_kenlm.txt

 Done! Transcriptions saved to /content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_2/finetuned_transcriptions_kenlm.txt





### Verify if the Lora was applied

In [42]:
import torch
from transformers import Wav2Vec2BertForCTC
from peft import PeftModel

base_model_id = "badrex/w2v-bert-2.0-kinyarwanda-asr-1000h"
adapter_path = "/content/drive/MyDrive/w2v-bert-kinyarwanda-adapter_3/checkpoint-2500"

print("1. Loading Base Model...")
base_model = Wav2Vec2BertForCTC.from_pretrained(base_model_id)

layer_to_check = base_model.wav2vec2_bert.encoder.layers[0].self_attn.linear_q
original_weights = layer_to_check.weight.clone()

print(f"   Original Weight Mean: {original_weights.mean().item():.6f}")

print("\n2. Loading and Merging Adapter...")
model_with_lora = PeftModel.from_pretrained(base_model, adapter_path)
model_merged = model_with_lora.merge_and_unload()

new_weights = model_merged.wav2vec2_bert.encoder.layers[0].self_attn.linear_q.weight

print(f"   Merged Weight Mean:   {new_weights.mean().item():.6f}")

diff = torch.abs(original_weights - new_weights).sum().item()
if diff > 1e-6:
    print(f"\n✅ SUCCESS: Weights have changed! (Difference: {diff:.6f})")
    print("   The LoRA adapter was successfully merged.")
else:
    print("\n❌ WARNING: Weights are identical. The adapter might not have been applied.")

1. Loading Base Model...
   Original Weight Mean: 0.000014

2. Loading and Merging Adapter...
   Merged Weight Mean:   0.000014

✅ SUCCESS: Weights have changed! (Difference: 225.685837)
   The LoRA adapter was successfully merged.


## 6. Push Model to hub


1.   Merge Model and upload to huggingface
2.   Create a gradio demo app



In [43]:
from huggingface_hub import login, HfApi
login()

repo_id = "ElvisTata2024/Kinyarwanda-Health-ASR"
model = Wav2Vec2BertForCTC.from_pretrained(base_model_id)

model = PeftModel.from_pretrained(model, adapter_path)
model = model.merge_and_unload()
processor = Wav2Vec2BertProcessor.from_pretrained(base_model_id)

print(f"🚀 Pushing model to {repo_id}...")

model.push_to_hub(repo_id)
processor.push_to_hub(repo_id)

print(" Uploading KenLM ARPA file...")
api = HfApi()
api.upload_file(
    path_or_fileobj=kenlm_path,
    path_in_repo="2gram_correct.arpa",
    repo_id=repo_id,
    repo_type="model"
)

print("✅ Done! Your model is live at: https://huggingface.co/" + repo_id)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

🚀 Pushing model to ElvisTata2024/Kinyarwanda-Health-ASR...


README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...redyfix/model.safetensors:   0%|          | 14.7kB / 2.32GB            

No files have been modified since last commit. Skipping to prevent empty commit.


 Uploading KenLM ARPA file...


Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...t_Data/2gram_correct.arpa: 100%|##########| 33.1MB / 33.1MB            

No files have been modified since last commit. Skipping to prevent empty commit.


✅ Done! Your model is live at: https://huggingface.co/ElvisTata2024/Kinyarwanda-Health-ASR


### Gradio App
https://huggingface.co/spaces/ElvisTata2024/Kinyarwanda-Health-ASR

In [44]:
import gradio as gr
import torch
import librosa
import numpy as np
import re
import unicodedata
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download

MODEL_ID = "ElvisTata2024/Kinyarwanda-Health-ASR"
KENLM_FILENAME = "2gram_correct.arpa"

print("⏳ Loading Model and Tools...")

# 1. Load Model & Processor
processor = Wav2Vec2BertProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2BertForCTC.from_pretrained(MODEL_ID)
model.eval()

# 2. Download and Build KenLM Decoder
print("   Downloading ARPA file...")
kenlm_path = hf_hub_download(repo_id=MODEL_ID, filename=KENLM_FILENAME)

print("   Building Decoder...")
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_list = [k for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])]

decoder = build_ctcdecoder(
    labels=sorted_vocab_list,
    kenlm_model_path=kenlm_path,
    alpha=0.5,
    beta=1.0
)

def clean_text_winner(text):
    if not text: return ""
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    allowed_chars = set("abcdefghijklmnopqrstuvwxyz '")
    text = ''.join(c for c in text if c in allowed_chars)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# 4. Inference Function
def transcribe(audio_filepath):
    if audio_filepath is None:
        return "Please upload an audio file."

    try:
        # Load audio and resample to 16k
        audio_input, _ = librosa.load(audio_filepath, sr=16000)

        # Tokenize
        inputs = processor(audio_input, sampling_rate=16000, return_tensors="pt")

        # Forward Pass
        with torch.no_grad():
            logits = model(**inputs).logits

        # Decode with KenLM
        logits_np = logits.cpu().numpy()[0]
        transcription = decoder.decode(logits_np, beam_width=50)

        # Clean
        final_text = clean_text_winner(transcription)
        return final_text

    except Exception as e:
        return f"Error during transcription: {str(e)}"

# Build Gradio Interface
desc = """
### Kinyarwanda ASR (Wav2Vec2-BERT + Adapter + KenLM)
Upload a .wav or .mp3 file, or record your voice to generate a transcription.
"""

iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Upload or Record Audio"),
    outputs="text",
    title="Kinyarwanda Speech Recognition",
    description=desc,
    examples=[], )

if __name__ == "__main__":
    iface.launch()

⏳ Loading Model and Tools...


preprocessor_config.json:   0%|          | 0.00/275 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json:   0%|          | 0.00/331 [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/30.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/2.32G [00:00<?, ?B/s]

   Downloading ARPA file...


2gram_correct.arpa:   0%|          | 0.00/33.1M [00:00<?, ?B/s]



   Building Decoder...
It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://f658de5fc0906ccb45.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
