### Environment Setup

In [1]:
import numpy as np
import os
from time import sleep
from typing import List, Callable
import jiwer
import openai
import pandas as pd
from datasets import Dataset
from sacrebleu import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from bert_score import BERTScorer
import evaluate
from tqdm import tqdm
from transformers import AutoTokenizer
import torch
import time
import asyncio
import nest_asyncio
from dotenv import load_dotenv #Load the environment variables
load_dotenv()

import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')

import logging
logging.getLogger("transformers").setLevel(logging.ERROR)

[nltk_data] Downloading package wordnet to /h/omidv/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /h/omidv/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


### Helper Functions

In [2]:
def compute_wer(reference, hypothesis):
    return jiwer.wer(reference, hypothesis)

def compute_bleu(references, hypotheses):
    return corpus_bleu(hypotheses, references).score

def compute_meteor(references, hypotheses):
    scores = []
    for ref, hyp in zip(references, hypotheses):
        scores.append(meteor_score([ref.split()], hyp.split()))
    return sum(scores)/len(scores)

def compute_bertscore(references, hypotheses):
    scorer = BERTScorer(lang="en", rescale_with_baseline=True)
    p, r, f1 = scorer.score(hypotheses, references)
    bert_score = {'precision': p.mean().item(),
                  'recall': r.mean().item(),
                  'f1': f1.mean().item()}
    return bert_score

def construct_input(question):
    prompt = [{"role": "user", "content": question}]
    return prompt

def extract_hypotheses(dataset, idx):
    if 'source' in dataset.features:
        hypotheses = [h.strip() for h in dataset['source'][idx].split('.') if h.strip()]
        references = dataset['target'][idx]
    else:
        hypotheses = dataset['input'][idx]
        references = dataset['output'][idx]
        
    return hypotheses, references
    

### Iterative Evaluation

In [3]:
def evaluate_model(dataset:Dataset, model: str, client: openai.OpenAI, postprocessing: Callable[[List[str]], str], generation_config: dict, 
                   use_llm: bool = True, verbose: int = 0, step: int = 100) -> dict:
    """Evaluate model performance on the entire dataset."""
    all_predictions = []
    all_references = []
    running_wer = []
    for idx in tqdm(range(len(dataset))):
        hypotheses = [h.strip() for h in dataset[hypothesis_column][idx].split('.') if h.strip()]
        reference = dataset[reference_column][idx]
        
        # Generate prompt
        if use_llm:
            llm_prompt = postprocessing(hypotheses)
            messages = construct_input(llm_prompt)
            try:
                generation = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    **generation_config
                )
                prediction = generation.choices[0].message.content
            except Exception as e:
                print(f"Error processing example {idx}: {e}")
                prediction = ""
        else:
            prediction = postprocessing(hypotheses, reference)
        reference, prediction = reference.lower(), prediction.lower()
        wer = jiwer.wer(reference, prediction)
        running_wer.append(wer)
        all_predictions.append(prediction)
        all_references.append(reference)
        
        # Print progress update for every %step examples
        if (i + 1) % step == 0:
            print(f"Current average WER: {round(np.mean(running_wer).item(), 3):.3f}")
            if verbose == 1:
                print('-----------------------------------------------------------')
                print("Corrected: %s\nTarget:    %s\n"%(prediction, reference))
                
    # Calculate metrics
    bertscore = compute_bertscore(all_predictions, all_references)
    metrics = {
        'WER': round(np.mean(running_wer).item(), 3),
        'METEOR': round(compute_meteor(all_predictions, all_references), 3),
        'BERT Precision': round(bertscore['precision'], 3),
        'BERT Recall': round(bertscore['recall'], 3),
        'BERT F1': round(bertscore['f1'], 3),
        #'BLEU': round(compute_bleu(all_predictions, all_references), 3),
    }
    return metrics

### Asynchronous Evaluation

In [4]:
nest_asyncio.apply()

async def call_openai_with_retry(messages, model, generation_config, client):
    """Handles API retries with exponential backoff."""
    retry_delay = 0.1  # Initial delay in seconds
    while True:
        try:
            # Attempt to make the API call
            generation = await client.chat.completions.create(
                model=model,
                messages=messages,
                **generation_config
            )
            return generation

        except Exception as e:
            await asyncio.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, 10)  # Exponential backoff up to 10s

async def get_prediction(client: openai.AsyncOpenAI, model: str, messages: List[dict], generation_config: dict) -> str:
    """Asynchronously fetch predictions from OpenAI API."""
    try:
        generation = await call_openai_with_retry(messages, model, generation_config, client)
        return generation.choices[0].message.content if generation else ""
    except Exception as e:
        print(f"Error: {e}")
        return ""

async def process_batch(dataset: Dataset, model: str, client: openai.AsyncOpenAI, postprocessing: Callable[[List[str]], str], 
    generation_config: dict, use_llm: bool) -> List[str]:
    """Processes the dataset asynchronously using OpenAI API with progress tracking."""
    tasks = []
    for idx in tqdm(range(len(dataset))):
        hypotheses, reference = extract_hypotheses(dataset, idx)
        if use_llm:
            llm_prompt = postprocessing(hypotheses)
            messages = construct_input(llm_prompt)
            task = asyncio.create_task(get_prediction(client, model, messages, generation_config))
        else:
            # Synchronous postprocessing for non-LLM mode
            task = asyncio.to_thread(postprocessing, hypotheses, reference)
        
        tasks.append(task)
    if use_llm:
        return await track_progress(tasks)
    else:
        return await asyncio.gather(*tasks) 

async def track_progress(tasks):
    """Tracks progress while tasks are running."""
    total_tasks = len(tasks)
    while True:
        completed = sum(task.done() for task in tasks)
        print(f"Progress: {completed}/{total_tasks} tasks completed", end="\r")
        
        if completed == total_tasks:
            print("\nAll tasks completed.")
            break

        await asyncio.sleep(1)  # Update every second

    return await asyncio.gather(*tasks)  # Collect results after completion

async def evaluate_model_parallel(dataset: Dataset, model: str, client: openai.AsyncOpenAI, postprocessing: Callable[[List[str]], str],
                            generation_config: dict, use_llm: bool = True, verbose=0, step=100):
    """Evaluates the model asynchronously with progress tracking, handling Jupyter compatibility."""
    all_predictions = await process_batch(dataset, model, client, postprocessing, generation_config, use_llm)
    all_predictions = [pred.lower() for pred in all_predictions]  # Normalize predictions
    
    reference_column = 'target' if 'target' in dataset.features else 'output'
    all_references = [ref.lower() for ref in dataset[reference_column]]

    # Compute evaluation metrics
    wer_scores = np.array([jiwer.wer(ref, pred) for ref, pred in zip(all_references, all_predictions)])
    bertscore = compute_bertscore(all_predictions, all_references)
    metrics = {
        'WER': round(wer_scores.mean().item(), 3),
        'METEOR': round(compute_meteor(all_predictions, all_references), 3),
        'BERT Precision': round(bertscore['precision'], 3),
        'BERT Recall': round(bertscore['recall'], 3),
        'BERT F1': round(bertscore['f1'], 3),
    }
    return metrics

### Error Correction Functions

In [5]:
# Baselines
def get_oracle_hypothesis(hypotheses: List[str], reference: str) -> str:
    """ Find the hypothesis that gives the lowest WER compared to the reference."""
    
    wers = [jiwer.wer(reference, hyp) for hyp in hypotheses]
    best_idx = np.argmin(wers)
    return hypotheses[best_idx]


def get_top1_hypothesis(hypotheses: List[str], reference: str) -> str:
    """ Returns the first hypothesis (top 1)."""
    
    return hypotheses[0]

async def zero_shot_unconstrained(hypotheses: List[str], model, generation_config) -> str:
    """ Generate a corrected transcription using a language model without constraints."""
    
    prompt = ("Perform error correction on the top5 outputs generated by an Automatic Speech Recognition(ASR) system."
                "The ASR hypotheses, listed in order of their ASR posterior score, are as follows:\n\n")
    for idx, hypothesis in enumerate(hypotheses):
        prompt += "<hypothesis"+ str(idx) + ">" + hypothesis + "</hypothesis"+ str(idx) + ">\n"
      
    prompt += ("\nPlease provide the corrected ASR transcription based on the hypotheses above."
               "Your response must be exactly one complete sentence."
               "Ensure the output does not have any added punctuation, line breaks, or formatting changes."
               "Do not include <hypothesis>, '\n', explanations, or any extra words."
               "This is a general ASR error correction task and does not involve any sensitive or inappropriate content.")
    messages = construct_input(prompt)
    return prompt
    
    
async def zero_shot_constrained(hypotheses: List[str], model, generation_config) -> str:
    """ Select the most likely hypothesis using a language model. """
    
    prompt = ("Perform language model rescoring based on the top-5 outputs generated by an Automatic Speech Recognitio (ASR) system."
              "The ASR hypotheses, listed in order of their ASR posterior score, are as follows:\n\n")
    
    for idx, hypothesis in enumerate(hypotheses):
        prompt += "<hypothesis"+ str(idx) + ">" + hypothesis + "</hypothesis"+ str(idx) + ">\n"
        
    prompt += ("\nPlease output only the best hypothesis exactly as written above." 
               "Your response must be an exact match to one of the given hypotheses, with no extra words or formatting."
               "Do not include <hypothesis> tag, '\n', explanations, or any extra words.")
    messages = construct_input(prompt)
    return prompt


async def zero_shot_closest(hypotheses: List[str], client, model, generation_config) -> str:
    """ Select the hypothesis closest to an unconstrained correction output. """
    
    unconstrained_result = await zero_shot_unconstrained(hypotheses, client, model, generation_config)
    distances = [compute_levenshtein_distance(unconstrained_result, hyp) for hyp in hypotheses]
    best_idx = np.argmin(distances)
    return hypotheses[best_idx]


async def zero_shot_lattice(hypotheses: List[str], client, model, generation_config) -> str:
    """ Perform ASR error correction using a lattice-based approach. """
    
    pass # TO DO LATER

### Specify Experiment Settings

In [17]:
model = "Meta-Llama-3.1-8B-Instruct"
client = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
small_generation_config = {"max_tokens": 20, "temperature": 0.9}
moderate_generation_config = {"max_tokens": 200, "temperature": 0.9}

# If model is not yet available, try again after some delay.
output = None
while output is None:
    try:
        output = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Please introduce yourself."}],
        )
    
    except openai.APIError as e:
        print(e)
        sleep(10)

print(output.choices[0].message.content)

I'm an artificial intelligence model known as Llama. Llama stands for "Large Language Model Meta AI."


# Common Voice Test Dataset

In [18]:
# Importing Dataset
df = pd.read_csv("/fs01/home/omidv/ASR-Error-Correction/data/test_cv.csv")
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = [h.strip() for h in dataset['source'][0].split('.') if h.strip()]
print(hypotheses)

Dataset({
    features: ['source', 'target', 'best_hypo'],
    num_rows: 1098
})
['transit road surveyed by joseph ellicott was named for an important surveying instrument', 'transit wrote surveyed by joseph ellicott was named for an important surveying instrument', 'transit road surveyed by joseph ellikot was named for an important surveying instrument', 'transit road surveyed by joseph ellicott was named for an important surveying instrument', 'transit road surveyed by joseph ellicate was named for an important surveying instrument']


In [8]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)

100%|██████████| 1098/1098 [00:03<00:00, 326.74it/s]


Progress: 1098/1098 tasks completed
All tasks completed.


In [35]:
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)

Progress: 1098/1098 tasks completed
All tasks completed.


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [37]:
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [38]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
print(df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']])

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.149   0.876           0.789        0.800    0.794
Zero-shot Uncon   0.185   0.853           0.774        0.795    0.784
Zero-shot Constr  0.153   0.872           0.786        0.801    0.794
Oracle            0.112   0.903           0.828        0.840    0.834


# Wall Street Journal Test Dataset


In [7]:
# Importing Dataset
df = pd.read_csv("/fs01/home/omidv/ASR-Error-Correction/data/test_wsj_score.csv")
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = [h.strip() for h in dataset['source'][0].split('.') if h.strip()]
print(hypotheses)

Dataset({
    features: ['source', 'target', 'best_hypo', 'score'],
    num_rows: 836
})
['saatchi officials said the management restructuring might accelerate its efforts to persuade clients to use the firm as a one stop shop for business services', 'sachi officials said the management restructuring might accelerate its efforts to persuade clients to use the firm as a one stop shop for business services', 'saatchi officials said the management restructuring might accelerate its efforts to persuade clients to use the firm as a one stop shop for business services', 'sachi officials said the management restructuring might accelerate its efforts to persuade clients to use the firm as a one stop shop for business services', 'saatchi officials said the management restructuring might accelerate its efforts to persuade clients to use the firm as a one stop shop for business services']




Small Progress: 100%|██████████| 20/20 [00:12<00:00, 95.28it/s][A[A

In [8]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

100%|██████████| 836/836 [00:02<00:00, 395.30it/s]


Progress: 813/836 tasks completed

CancelledError: 

In [12]:
from tqdm import tqdm
import time
overall_progress = tqdm(total=10, desc="Total Progress")

for start_idx in range(0, 10):
    time.sleep(1)
    overall_progress.update(1)
    small_progress = tqdm(total=20, desc="Small Progress")
    for start_idx in range(0, 20):
        time.sleep(0.01)
        small_progress.update(1)
overall_progress.close()

Total Progress:  10%|█         | 1/10 [00:01<00:09,  1.00s/it]
Small Progress:   0%|          | 0/20 [00:00<?, ?it/s][A
Small Progress:  50%|█████     | 10/20 [00:00<00:00, 98.69it/s][A
Total Progress:  20%|██        | 2/10 [00:02<00:09,  1.13s/it]][A

Small Progress: 100%|██████████| 20/20 [00:01<00:00, 16.38it/s]


Small Progress:  50%|█████     | 10/20 [00:00<00:00, 96.70it/s][A[A

Total Progress:  30%|███       | 3/10 [00:03<00:08,  1.17s/it]][A[A
Small Progress: 100%|██████████| 20/20 [00:01<00:00, 16.36it/s]

Small Progress:  50%|█████     | 10/20 [00:00<00:00, 97.40it/s][A
Total Progress:  40%|████      | 4/10 [00:04<00:07,  1.19s/it]][A

Small Progress: 100%|██████████| 20/20 [00:01<00:00, 16.37it/s]


Small Progress:  50%|█████     | 10/20 [00:00<00:00, 96.76it/s][A[A

Total Progress:  50%|█████     | 5/10 [00:05<00:06,  1.20s/it]][A[A
Small Progress: 100%|██████████| 20/20 [00:01<00:00, 16.36it/s]

Small Progress:  50%|█████     | 10/20 [00:00<00:00, 96.22it/s][

In [42]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
df = df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]
print(df)

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.056   0.959           0.942        0.941    0.941
Zero-shot Uncon   0.201   0.892           0.773        0.845    0.808
Zero-shot Constr  0.138   0.941           0.828        0.891    0.859
Oracle            0.041   0.970           0.959        0.956    0.957


# SwitchBoard Test Dataset

In [24]:
# Importing Dataset
df = pd.read_csv("/fs01/home/omidv/ASR-Error-Correction/data/test_swbd.csv")
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = [h.strip() for h in dataset['source'][0].split('.') if h.strip()]
print(hypotheses)

Dataset({
    features: ['source', 'target', 'best_hypo'],
    num_rows: 1234
})
['you know that did not in the home by choice anymore', 'you know that did not in the home by choice anymore', 'that did not in the home by choice anymore', 'you know that that did not in the home by choice anymore', 'that they are not in the home by choice anymore']


In [46]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

Progress: 1234/1234 tasks completed
All tasks completed.
Progress: 1234/1234 tasks completed
All tasks completed.


In [47]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
df = df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]
print(df)

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.159   0.915           0.781        0.823    0.802
Zero-shot Uncon   0.355   0.815           0.586        0.726    0.655
Zero-shot Constr  0.264   0.888           0.650        0.772    0.710
Oracle            0.124   0.936           0.822        0.859    0.841


# ATIS Test Dataset

In [33]:
# Importing Dataset
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_atis.json")
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = dataset['input'][0]
print(hypotheses)

Dataset({
    features: ['id', 'input', 'output', 'am_score'],
    num_rows: 809
})
['list all us air flights from miami to cleveland leaving on sunday afternoon', 'list all us air flights from miami to cleveland leaving on sunday afternoon', 'list all us air flights from miami to cleveland leaving on sunday afternoon', 'list all us airflights from miami to cleveland leaving on sunday afternoon', 'list all us airflights from miami to cleveland leaving on sunday afternoon']


In [12]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

Progress: 1098/1098 tasks completed
All tasks completed.
Progress: 1098/1098 tasks completed
All tasks completed.


In [13]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
df = df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]
print(df)

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.149   0.876           0.789        0.800    0.794
Zero-shot Uncon   0.186   0.853           0.775        0.795    0.785
Zero-shot Constr  0.149   0.876           0.789        0.804    0.797
Oracle            0.112   0.903           0.828        0.840    0.834


# Tedlium-3 Test Dataset

In [14]:
# Importing Dataset
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_td3.json")
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = dataset['input'][0]
print(hypotheses)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 1155
})
['i would like to share with you a discovery that i made a few months ago while writing an article for italian wired i always keep my thesaurus handy whenever i am writing anything', 'i would like to share with you a discovery that i made a few months ago while writing an article for italian wired i always keep my thesaurus handy whenever i am writing anything but .', 'i would like to share with you a discovery that i made a few months ago while writing an article for italian wired i always keep my thesaurus handy whenever i am writing anything but', 'i would like to share with you a discovery that i made a few months ago while writing an article for italianwired i always keep my thesaurus handy whenever i am writing anything', 'i would like to share with you a discovery that i made a few months ago while writing an article for italian wired i always keep my thesaurus handy whenever i am writing anyt

In [15]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

Progress: 1155/1155 tasks completed
All tasks completed.
Progress: 1155/1155 tasks completed
All tasks completed.


In [16]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
df = df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]
print(df)

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.048   0.972           0.943        0.948    0.945
Zero-shot Uncon   0.362   0.827           0.601        0.775    0.686
Zero-shot Constr  0.267   0.906           0.678        0.847    0.760
Oracle            0.030   0.981           0.962        0.966    0.964


# Librispeech Clean Test Dataset

In [18]:
# Importing Dataset
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_ls_clean.json").iloc[:1000]
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = dataset['input'][0]
print(hypotheses)

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 1000
})
['he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce', 'he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour flattened sauce', 'he hoped there would be stew for dinner turnips and carrots and bruise potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce', 'he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattening sauce', 'he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out with thick peppered flour fattened sauce']


In [20]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

100%|██████████| 1000/1000 [00:11<00:00, 89.68it/s]


Progress: 1000/1000 tasks completed
All tasks completed.


100%|██████████| 1000/1000 [00:11<00:00, 88.61it/s]


Progress: 1000/1000 tasks completed
All tasks completed.


100%|██████████| 1000/1000 [00:11<00:00, 88.32it/s]
100%|██████████| 1000/1000 [00:11<00:00, 83.43it/s]


In [21]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
df = df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]
print(df)

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.021   0.980           0.972        0.972    0.972
Zero-shot Uncon   0.241   0.893           0.704        0.848    0.774
Zero-shot Constr  0.214   0.913           0.725        0.869    0.795
Oracle            0.009   0.989           0.984        0.985    0.985


# Librispeech Others Test Dataset

In [15]:
# Importing Dataset
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_ls_other.json").iloc[:1000]
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = dataset['input'][0]
print(hypotheses)

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 1000
})
["there's iron they say in all our blood and a grain or two perhaps is good but his he makes me harshly feel has got a little too much of steel anon", "there's iron they say in all our blood and a grain or two perhaps is good but this he makes me harshly feel has got a little too much of steel anon", "there's iron they say in all our blood and a grain or two perhaps is good but he makes me harshly feel has got a little too much of steel anon", "there's iron they say in all our blood and a grain or two perhaps is good but as he makes me harshly feel has got a little too much of steel anon", "there's iron they say in all our blood an a grain or two perhaps is good but his he makes me harshly feel has got a little too much of steel anon"]


In [16]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

100%|██████████| 1000/1000 [00:10<00:00, 91.69it/s]


Progress: 1000/1000 tasks completed
All tasks completed.


100%|██████████| 1000/1000 [00:10<00:00, 91.52it/s]


Progress: 1000/1000 tasks completed
All tasks completed.


100%|██████████| 1000/1000 [00:11<00:00, 88.31it/s]
100%|██████████| 1000/1000 [00:11<00:00, 88.80it/s]


In [17]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
df = df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]
print(df)

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.057   0.948           0.914        0.916    0.915
Zero-shot Uncon   0.240   0.865           0.697        0.810    0.752
Zero-shot Constr  0.204   0.898           0.734        0.842    0.786
Oracle            0.036   0.964           0.936        0.940    0.938


# LRS2 Clean Test Dataset

In [22]:
# Importing Dataset
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_lrs2.json").iloc[:1000]
dataset = Dataset.from_pandas(df)
print(dataset)
hypotheses = dataset['input'][0]
print(hypotheses)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 1000
})
['but it really is a rolls royce version', 'but it really is a rolls royce version', 'but it really is a rolls royce version it is', 'but it really is a rolls royce version it is', 'it really is a rolls royce version']


In [13]:
metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained,
                                                                small_generation_config, use_llm=True)
metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained,
                                                              small_generation_config, use_llm=True)
metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis,
                                                              small_generation_config, use_llm=False)
metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis,
                                                              small_generation_config, use_llm=False)

100%|██████████| 1000/1000 [00:10<00:00, 95.10it/s]


Progress: 1000/1000 tasks completed
All tasks completed.


100%|██████████| 1000/1000 [00:10<00:00, 92.50it/s]


Progress: 1000/1000 tasks completed
All tasks completed.


100%|██████████| 1000/1000 [00:11<00:00, 84.45it/s]
100%|██████████| 1000/1000 [00:10<00:00, 93.30it/s]


In [14]:
results_table = {
    "Top 1": metrics_get_top1_hypothesis,
    "Zero-shot Uncon": metrics_zero_shot_unconstrained,
    "Zero-shot Constr": metrics_zero_shot_constrained,
    "Oracle": metrics_get_oracle_hypothesis,
}
df = pd.DataFrame.from_dict(results_table, orient='index')
df = df[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]
print(df)

                    WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1             0.132   0.893           0.871        0.845    0.858
Zero-shot Uncon   0.212   0.840           0.821        0.806    0.813
Zero-shot Constr  0.134   0.895           0.870        0.851    0.860
Oracle            0.069   0.939           0.926        0.913    0.919


In [6]:
model="gemma-2-9b-it"
client = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

output = None
while output is None:
    try:
        output = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Please introduce yourself."}],
        )
    
    except openai.APIError as e:
        print(e)
        sleep(10)

print(output.choices[0].message.content)

Hello! I am Gemma, an open-weights AI assistant. I was created by the Gemma team at Google DeepMind. My purpose is to help people by understanding and responding to their requests in a helpful, informative, and comprehensive way.

Since I am open-weights, my weights are publicly available. This means anyone can access, study, and modify me, which promotes transparency and collaboration in the AI community.

I am still under development, but I'm learning new things every day. I can generate text, translate languages, write different kinds of creative content, and answer your questions in an informative way.

What can I do for you today?



In [None]:
messages = [
    {
        "role": "user",
        "content": "Do you know speech recognition?"
    },
    {
        "role": "assistant",
        "content": "As a large language model, I am trained on a massive dataset of text and code. This allows me to understand and generate human-like text, but I don't directly process audio input.\n\n**Speech recognition**, also known as automatic speech recognition (ASR), is a separate field of AI that deals with converting spoken language into text. It involves several steps:\n\n1. **Audio Acquisition:** Capturing the spoken audio.\n2. **Feature Extraction:** Converting the audio signal into numerical features that represent the speech sounds.\n3. **Acoustic Modeling:** Using statistical models to predict the probability of different speech units (phonemes or words) based on the extracted features.\n4. **Language Modeling:** Using statistical models to predict the probability of different word sequences, considering the context of the spoken words.\n5. **Decoding:** Combining the acoustic and language model outputs to find the most likely sequence of words.\n"
    },
    {
        "role": "user",
        "content": "Do you know language model for speech recognition?"
    },
    {
        "role": "assistant",
        "content": "You're getting into some exciting territory! \n\nWhile I'm a text-only model, I can tell you about language models used in speech recognition.  \n\nTraditionally, speech recognition heavily relied on acoustic models (like Hidden Markov Models) and pronunciation dictionaries.  \n\nHowever, the rise of Transformer-based language models has revolutionized the field. These models, trained on massive text and sometimes audio datasets, can learn complex relationships between sounds and words, leading to significant improvements in accuracy and natural language understanding.\n\nHere are some examples of language models used in speech recognition:\n\n* **Whisper (OpenAI):** A powerful, open-source model known for its impressive performance in various languages and noise conditions.\n* ** wav2vec 2.0 (Facebook AI):** Another open-source model that excels in unsupervised learning from raw audio data.\n* **FastSpeech 2 (Microsoft):** Focuses on fast and natural-sounding speech synthesis. \n"
    },
    {
        "role": "user",
        "content": "Could you give a possible example of language model rescoring with hypotheses?"
    },
    {
        "role": "assistant",
        "content": ""
    },
    {
        "role": "user",
        "content": "Nice job, I will provide some examples as a demonstration from []. The 5-best hypothesis is:[], and I would expect your output is: []. Following this example, could you report the true transcription from the following 5-best hypotheses? []"
    }
]

output = await client.chat.completions.create(
            model=model,
            messages=messages,
            **custom_generation_config,
)
print(output)

In [36]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_cv.json")
print(df.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2000 entries, 0 to 1999
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   input   2000 non-null   object
 1   output  2000 non-null   object
 2   input1  2000 non-null   object
 3   input2  2000 non-null   object
dtypes: object(4)
memory usage: 62.6+ KB
None


In [39]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_lrs2.json")
print(df.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2259 entries, 0 to 2258
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   input   2259 non-null   object
 1   output  2259 non-null   object
 2   input1  2259 non-null   object
 3   input2  2259 non-null   object
dtypes: object(4)
memory usage: 70.7+ KB
None


In [43]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_chime4.json")
print(df.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1320 entries, 0 to 1319
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   input   1320 non-null   object
 1   output  1320 non-null   object
 2   input1  1320 non-null   object
 3   input2  1320 non-null   object
dtypes: object(4)
memory usage: 41.4+ KB
None


In [45]:
import tiktoken

tokenizer = tiktoken.encoding_for_model("gpt-4o")  # Change based on your API model

# Function to count tokens
def count_tokens(text):
    if isinstance(text, str):  # Ensure the value is a string
        return len(tokenizer.encode(zero_shot_constrained(text)))
    return 0  # Return 0 for missing values

df["token_count"] = df['output'].apply(count_tokens)
total_tokens = df["token_count"].sum()

# Print the result
print(f"Total token count: {total_tokens}")

# Calculate cost using OpenAI pricing (update as needed)
cost_per_1k_tokens = 0.01  # Example price for gpt-4-turbo (input tokens)
total_cost = (total_tokens / 1000) * cost_per_1k_tokens

print(f"Estimated OpenAI API cost: ${total_cost:.4f}")


Total token count: 1021435
Estimated OpenAI API cost: $10.2143
