# ASR Benchmarking Notebook for Multitask-National-Speech-Corpus-v1


In [113]:
import os
import time
import torch
import pandas as pd
import numpy as np
from datasets import load_dataset, Audio
from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, WhisperProcessor, WhisperForConditionalGeneration
from jiwer import wer, mer, wip, wil
from tqdm import tqdm

In [114]:
# ----------------------- Configurations -----------------------
MODEL_ID = "openai/whisper-small"  # Changeable to local or finetuned models
DATA_DIR = "ASR-PART2-Test"        # Changeable for different partitions
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_SAMPLES = 100                  # Number of samples for quick benchmarking

In [115]:
# ----------------------- Load Dataset -----------------------
def load_data(data_dir, split='train', num_samples=None):
    dataset = load_dataset('MERaLiON/Multitask-National-Speech-Corpus-v1', data_dir=data_dir)[split]
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    if num_samples:
        dataset = dataset.select(range(min(num_samples, len(dataset))))
    return dataset

In [116]:
# ----------------------- Load ASR Model -----------------------
def load_asr_model(model_id):
    try:
        # Try using a generic pipeline-based ASR (e.g. Whisper, wav2vec2)
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device=0 if DEVICE == "cuda" else -1
        )

        def wrapped_pipeline(audio_array, sampling_rate=16000):
            return asr_pipeline(audio_array)

        return wrapped_pipeline

    except Exception as e:
        print(f"Falling back to custom processor loading for model {model_id} due to: {e}")
        # Custom processor and model logic
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id).to(DEVICE)

        def custom_asr_pipeline(audio_array, sampling_rate=16000):
            input_features = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").input_features.to(DEVICE)
            with torch.no_grad():
                generated_ids = model.generate(input_features=input_features)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return {"text": transcription}

        return custom_asr_pipeline


In [117]:
# ----------------------- Benchmarking Function -----------------------
def benchmark_asr(asr_pipeline, dataset):
    results = []
    total_start_time = time.time()

    for example in tqdm(dataset, desc="Transcribing"):
        audio_array = example["context"]["array"]
        reference = example["answer"]

        start = time.time()
        prediction = asr_pipeline(audio_array, sampling_rate=16000)["text"]
        end = time.time()

        results.append({
            "reference": reference,
            "prediction": prediction,
            "time": end - start
        })

    total_time = time.time() - total_start_time
    result_df = pd.DataFrame(results)
    return result_df, total_time

In [118]:
# ----------------------- Evaluation Metrics -----------------------
def evaluate_metrics(result_df):
    result_df["wer"] = result_df.apply(lambda x: wer(x["reference"], x["prediction"]), axis=1)
    result_df["mer"] = result_df.apply(lambda x: mer(x["reference"], x["prediction"]), axis=1)
    result_df["wil"] = result_df.apply(lambda x: wil(x["reference"], x["prediction"]), axis=1)
    result_df["wip"] = result_df.apply(lambda x: wip(x["reference"], x["prediction"]), axis=1)

    average_metrics = {
        "Average WER": result_df["wer"].mean(),
        "Average MER": result_df["mer"].mean(),
        "Average WIL": result_df["wil"].mean(),
        "Average WIP": result_df["wip"].mean(),
        "Average Time per Sample (s)": result_df["time"].mean()
    }

    return result_df, average_metrics

In [119]:
# ----------------------- Identify Best/Worst Words -----------------------
def word_analysis(result_df):
    from collections import Counter
    word_errors = []
    for _, row in result_df.iterrows():
        ref_words = row["reference"].lower().split()
        pred_words = row["prediction"].lower().split()
        missed = set(ref_words) - set(pred_words)
        for word in missed:
            word_errors.append(word)

    error_counter = Counter(word_errors)
    most_common_errors = error_counter.most_common(10)
    least_common_errors = error_counter.most_common()[-10:]
    return most_common_errors, least_common_errors

In [120]:
from tqdm import tqdm
import time
from datasets import load_dataset
from jiwer import wer
import numpy as np
import re

def clean_text(text):
    """
    Clean the reference text by:
    - Removing speaker labels like <speaker1>: or <speaker2>:
    - Removing round or square brackets but keeping the content inside.
    - Converting to lowercase.
    """
    # Remove speaker tags like <speaker1>: or <speaker2>:
    text = re.sub(r"<speaker\d+>:\s*", "", text, flags=re.IGNORECASE)

    # Remove round or square brackets but keep content inside
    text = re.sub(r"\[(.*?)\]", r"\1", text)
    text = re.sub(r"\((.*?)\)", r"\1", text)

    # Convert to lowercase and strip surrounding whitespace
    return text.lower().strip()


def run_benchmark(model_id, data_dir="ASR-PART1-Train", num_samples=100, batch_size=8):
    print(f"\n📊 Running benchmark for model: {model_id}")

    # Load dataset
    dataset = load_dataset("MERaLiON/Multitask-National-Speech-Corpus-v1", data_dir=data_dir)["train"]
    dataset = dataset.select(range(min(num_samples, len(dataset))))

    # Load ASR model pipeline
    model_fn = load_asr_model(model_id)

    # Prepare audio and references
    audio_arrays = [sample["context"]["array"] for sample in dataset]
    references = [sample["answer"] for sample in dataset]

    predictions = []
    start_time = time.time()

    # Run batched inference
    for i in tqdm(range(0, len(audio_arrays), batch_size), desc=f"Evaluating {model_id}", ncols=100):
        batch_audio = audio_arrays[i:i+batch_size]
        batch_refs = references[i:i+batch_size]

        try:
            # Support both batch and single-audio pipelines
            preds = model_fn(batch_audio)

            # Ensure output is iterable
            if isinstance(preds, dict):
                preds = [preds]

            for ref, pred_dict in zip(batch_refs, preds):
                hyp = pred_dict.get("text", "[ERROR]").strip()

                # Clean both reference and prediction
                ref_cleaned = clean_text(ref)
                hyp_cleaned = clean_text(hyp)

                predictions.append({"reference": ref_cleaned, "prediction": hyp_cleaned})

        except Exception as e:
            for ref in batch_refs:
                predictions.append({"reference": clean_text(ref), "prediction": f"[ERROR: {e}]"})

        # Dynamically print progress without breaking the progress bar
        tqdm.write(f"Processed {i + batch_size}/{len(audio_arrays)} samples.")

    end_time = time.time()
    total_time = end_time - start_time

    # Compute metrics
    references_clean = [x["reference"] for x in predictions]
    predictions_clean = [x["prediction"] for x in predictions]
    computed_wer = wer(references_clean, predictions_clean)

    # Print metrics
    print(f"\n✅ Benchmark complete for {model_id}")
    print(f"⏱️ Total time: {total_time:.2f} seconds")
    print(f"🧠 WER: {computed_wer:.4f}")

    # Sample output
    print("\n🔍 Sample predictions:")
    for sample in predictions[:5]:
        print(f"REF: {sample['reference']}")
        print(f"HYP: {sample['prediction']}")
        print("-----")

    return {
        "model_id": model_id,
        "wer": computed_wer,
        "runtime_sec": total_time,
        "results": predictions,
    }

In [121]:
# ----------------------- Multi-Model Benchmark Loop -----------------------
import gc

def benchmark_multiple_models(model_ids, data_dir="ASR-PART5-Test", num_samples=100, batch_size=8):
    comparison_results = []

    for model_id in model_ids:
        gc.collect()
        torch.cuda.empty_cache() # Clear the GPU cache

        print(f"\n===== Benchmarking Model: {model_id} =====")

        # Run benchmark and get results
        result = run_benchmark(model_id=model_id, data_dir=data_dir, num_samples=num_samples, batch_size=batch_size)

        # Collecting metrics for comparison
        metrics = {"WER": result["wer"], "Runtime (sec)": result["runtime_sec"], "Model": model_id}
        comparison_results.append(metrics)

    # Convert list of dicts to DataFrame for easy comparison
    comparison_df = pd.DataFrame(comparison_results)
    comparison_df = comparison_df.set_index("Model")

    return comparison_df



In [None]:
DEEPINFRA_API_KEY = "API KEY"  
DEEPINFRA_API_BASE = "https://api.deepinfra.com/v1/openai"

In [126]:
def correct_with_llm(hypotheses, llm_type="meta-llama/Llama-3-70b-chat-hf"):
    """Use DeepInfra-hosted LLM for correction"""
    from openai import OpenAI
    
    client = OpenAI(
        api_key=DEEPINFRA_API_KEY,
        base_url=DEEPINFRA_API_BASE
    )
    
    prompt = f"""Correct this speech transcription, which is in Singlish and likely to contain words related to Singapore. You may need to correct it based on the knowledge you have about Singapore, but you shall correct it only when it's necessary. Using these ASR hypotheses and their confidence scores:
    
    {chr(10).join([f'- {h["text"]} (confidence: {h["confidence"]:.2f})' for h in hypotheses])}
    
    Do not include any other text or explanation. Just provide the corrected transcript."""
    #
    
    response = client.chat.completions.create(
        model=llm_type,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
        max_tokens=500
    )
    
    return response.choices[0].message.content.strip()

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration

def generate_whisper_hypotheses(
    audio_array,
    model_id="openai/whisper-small.en",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    sampling_rate=16000,
    num_hypotheses=3
):
    """
    Generate Whisper ASR hypotheses with one confidence score each.

    Returns:
        List[dict]: [{"text": str, "confidence": float}, ...]
    """

    model = WhisperForConditionalGeneration.from_pretrained(model_id)
    processor = WhisperProcessor.from_pretrained(model_id)
    model.to(device)
    model.eval()

    
    inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt")
    features = inputs.input_features.to(device)

    #Generate N-best sequences with scores
    outputs = model.generate(
        features,
        do_sample=True,            
        num_return_sequences=3,    
        max_new_tokens=256,
        top_k=50,                  
        top_p=0.95,                
        temperature=0.7,
        output_scores=True,
        return_dict_in_generate=True,
    )

    hypotheses = []
    for idx in range(num_hypotheses):
        seq_ids = outputs.sequences[idx]
        text = processor.decode(seq_ids, skip_special_tokens=True).strip()


        if hasattr(outputs, "sequences_scores") and outputs.sequences_scores is not None:
            sum_logprob = outputs.sequences_scores[idx]
        else:

            sum_logprob = 0.0

            for step, step_scores in enumerate(outputs.scores):
                logprobs = torch.log_softmax(step_scores, dim=-1)
                token_id = seq_ids[step + 1]  # skip <s> token

                sum_logprob += logprobs[idx, token_id]

        token_count = seq_ids.shape[-1] - 1
        avg_logprob = (sum_logprob / token_count).item()

        hypotheses.append({"text": text, "confidence": avg_logprob})

    return hypotheses

In [None]:
# ----------------------- Enhanced Benchmarking with Progress -----------------------
def benchmark_llm_correction(asr_model_id, llm_models, num_samples=50):
    dataset = load_data(DATA_DIR, num_samples=num_samples)

    
    results = []
    total_models = len(llm_models)
    total_samples = len(dataset)
    
    with tqdm(total=total_models * total_samples, desc="Overall Progress") as main_pbar:
        for llm_idx, llm in enumerate(llm_models):
            model_progress = {
                "processed": 0,
                "examples": []
            }
            
            with tqdm(dataset, desc=f"Model {llm_idx+1}/{total_models}: {llm[:20]}...") as model_pbar:
                corrected = []
                for sample_idx, example in enumerate(dataset):
                    # Generate ASR hypotheses
                    hypotheses = generate_whisper_hypotheses(example["context"]["array"], model_id=asr_model_id)
                    # LLM Correction
                    corrected_text = correct_with_llm(hypotheses, llm)
                    
                    # Store results
                    corrected.append({
                        "original": hypotheses[0]["text"],
                        "corrected": clean_text(corrected_text),
                        "reference": clean_text(example["answer"])
                    })
                    
                    model_progress["processed"] += 1
                    main_pbar.update(1)
                    model_pbar.update(1)
                    
                    # Show examples every 10% of samples
                    if (sample_idx + 1) % max(1, total_samples//10) == 0:
                        progress_percent = int(100 * (sample_idx + 1) / total_samples)
                        latest_example = corrected[-1]
                        
                        example_output = f"""
                        \n=== [Model: {llm} | Progress: {progress_percent}%] ===
                        [Original] {latest_example['original']}
                        [Corrected] {latest_example['corrected']}
                        [Reference] {latest_example['reference']}
                        """
                        tqdm.write(example_output)
                        model_progress["examples"].append(example_output)

                # Calculate metrics
                wer_original = wer([c["reference"] for c in corrected], [c["original"] for c in corrected])
                wer_corrected = wer([c["reference"] for c in corrected], [c["corrected"] for c in corrected])
                
                results.append({
                    "LLM": llm,
                    "Original WER": wer_original,
                    "Corrected WER": wer_corrected,
                    "Improvement": wer_original - wer_corrected,
                    "Examples": model_progress["examples"]
                })

    # Final output with examples
    print("\n=== Benchmark Complete ===")
    for result in results:
        print(f"\nModel: {result['LLM']}")
        print(f"WER Improvement: {result['Improvement']:.3f}")
        print("Sample Corrections:")
        for example in result['Examples'][::2]:  
            print(example)
    
    return pd.DataFrame(results)

In [None]:

llm_models = [
    "google/gemma-3-27b-it"

]

correction_results = benchmark_llm_correction(
    "openai/whisper-small.en",
    llm_models,
    num_samples=30
)

display(correction_results.drop(columns=['Examples']))

Overall Progress:  10%|█         | 3/30 [00:17<02:33,  5.70s/it]
Overall Progress:  10%|█         | 3/30 [00:17<02:33,  5.70s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 10%] ===
                        [Original] Castle, Tender, Beef, Raderoo, and Hellmuth.
                        [Corrected] cassu, tender beef, radarou and hellmuths.
                        [Reference] katsu tendon beef vindaloo and hummus
                        


Overall Progress:  20%|██        | 6/30 [00:33<02:13,  5.56s/it]
Overall Progress:  20%|██        | 6/30 [00:33<02:13,  5.56s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 20%] ===
                        [Original] The old MRT house, home sweet sand drive and trabeca by the waterfront
                        [Corrected] the old mrt house, hong huat drive and tribeca by the waterfront.
                        [Reference] the old admiralty house hon sui sen drive and tribeca by the waterfronts
                        


Overall Progress:  30%|███       | 9/30 [00:50<01:59,  5.71s/it]
Overall Progress:  30%|███       | 9/30 [00:50<01:59,  5.71s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 30%] ===
                        [Original] Addu Rakhim Isha, Anthony Chen and Owen Bun-Tek
                        [Corrected] addu, rahim, isha, anthony chen and ong boon tek.
                        [Reference] abdul rahim ishak anthony chen and ong boon tat
                        


Overall Progress:  40%|████      | 12/30 [01:08<01:43,  5.75s/it]
Overall Progress:  40%|████      | 12/30 [01:08<01:43,  5.75s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 40%] ===
                        [Original] How much is a cathera?
                        [Corrected] how much is a kaya?
                        [Reference] how much is air karthira
                        


Overall Progress:  50%|█████     | 15/30 [01:25<01:23,  5.60s/it]
Overall Progress:  50%|█████     | 15/30 [01:25<01:23,  5.60s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 50%] ===
                        [Original] Lift the BIMI off at 3 Guru Singh Saba
                        [Corrected] lift me up at 3 gurdwara singh sabha.
                        [Reference] arleth is dropping me off at sri guru singh sabha
                        


Overall Progress:  60%|██████    | 18/30 [01:41<01:05,  5.47s/it]
Overall Progress:  60%|██████    | 18/30 [01:41<01:05,  5.47s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 60%] ===
                        [Original] 856 41745
                        [Corrected] 856-41745
                        [Reference] eight five six four one seven four five
                        


Overall Progress:  70%|███████   | 21/30 [01:59<00:51,  5.72s/it]
Overall Progress:  70%|███████   | 21/30 [01:59<00:51,  5.72s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 70%] ===
                        [Original] John Funs, Nicola, Keok, Ben, Kim, Francis and Elek Keok.
                        [Corrected] john fungs, nicole pang, ken francis and ella keough.
                        [Reference] john fearns nicoll kwok peng kin francis and alec kuok
                        


Overall Progress:  80%|████████  | 24/30 [02:17<00:36,  6.06s/it]
Overall Progress:  80%|████████  | 24/30 [02:17<00:36,  6.06s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 80%] ===
                        [Original] Hiyashi chuka taaki komi gohan en fugo.
                        [Corrected] hiyashi chuka, takikomi gohan, enough or not?
                        [Reference] hiyashi chuka takikomi gohan and fugu
                        


Overall Progress:  90%|█████████ | 27/30 [02:34<00:17,  5.88s/it]
Overall Progress:  90%|█████████ | 27/30 [02:34<00:17,  5.88s/it]                  


                        
=== [Model: google/gemma-3-27b-it | Progress: 90%] ===
                        [Original] Campbell loves chut-papri.
                        [Corrected] campbell loves chutney papri.
                        [Reference] campbell loves chaat papri
                        


Overall Progress: 100%|██████████| 30/30 [02:58<00:00,  7.18s/it]
Model 1/1: google/gemma-3-27b-i...: 100%|██████████| 30/30 [02:58<00:00,  5.95s/it]
Overall Progress: 100%|██████████| 30/30 [02:58<00:00,  5.95s/it]


                        
=== [Model: google/gemma-3-27b-it | Progress: 100%] ===
                        [Original] Tell me the price of Guizhou.
                        [Corrected] tell me the price of guizhou.
                        [Reference] tell me the price of kuay chap
                        

=== Benchmark Complete ===

Model: google/gemma-3-27b-it
WER Improvement: 0.147
Sample Corrections:

                        
=== [Model: google/gemma-3-27b-it | Progress: 10%] ===
                        [Original] Castle, Tender, Beef, Raderoo, and Hellmuth.
                        [Corrected] cassu, tender beef, radarou and hellmuths.
                        [Reference] katsu tendon beef vindaloo and hummus
                        

                        
=== [Model: google/gemma-3-27b-it | Progress: 30%] ===
                        [Original] Addu Rakhim Isha, Anthony Chen and Owen Bun-Tek
                        [Corrected] addu, rahim, isha, anthony chen and ong boon tek.
     




Unnamed: 0,LLM,Original WER,Corrected WER,Improvement
0,google/gemma-3-27b-it,0.659341,0.512821,0.14652


In [131]:

# llm_models = [
#     "google/gemma-3-27b-it"
# ]

# correction_results = benchmark_llm_correction(
#     "mjwong/whisper-small-singlish",
#     llm_models,
#     num_samples=30
# )

# display(correction_results.drop(columns=['Examples']))