In [3]:
import numpy as np
import os
import inspect
import random
from time import sleep
from typing import List, Callable
import jiwer
import re
import openai
import pandas as pd
from datasets import Dataset
from sacrebleu import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from bert_score import BERTScorer
import evaluate
from tqdm.notebook import tqdm
import logging
from transformers import AutoTokenizer
import torch
import string
import time
import asyncio
import nest_asyncio
from dotenv import load_dotenv #Load the environment variables
# Run before the OpenAI client is created, which reads OPENAI_API_KEY from
# os.environ further down in the notebook.
load_dotenv()

import nltk
# Download the WordNet corpora up front.
# NOTE(review): presumably required by nltk's meteor_score — confirm.
nltk.download('wordnet')
nltk.download('omw-1.4')

[nltk_data] Downloading package wordnet to /h/omidv/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /h/omidv/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


True

In [4]:
def compute_wer(reference, hypothesis):
    """Return the word error rate of `hypothesis` measured against `reference`.

    Thin wrapper that forwards both arguments, in this order, to `jiwer.wer`.
    """
    word_error_rate = jiwer.wer(reference, hypothesis)
    return word_error_rate


def compute_bleu(references, hypotheses):
    """Return the corpus-level BLEU score of `hypotheses` against `references`.

    Args:
        references: Either a flat list holding one reference string per
            hypothesis, or a list of reference streams (a list of lists of
            strings), which is the layout `corpus_bleu` takes.
        hypotheses: List of hypothesis strings.

    Returns:
        The `.score` attribute of the object returned by `corpus_bleu`.
    """
    # corpus_bleu takes a list of reference *streams*. A flat list of strings
    # (the layout compute_meteor and compute_bertscore accept) is wrapped into
    # a single stream so each reference lines up with its hypothesis.
    if len(references) > 0 and isinstance(references[0], str):
        references = [references]
    return corpus_bleu(hypotheses, references).score


def compute_meteor(references, hypotheses):
    """Return the mean sentence-level METEOR score over paired sentences.

    Args:
        references: List of reference strings.
        hypotheses: List of hypothesis strings, paired by position with
            `references` (`zip` ignores extra items on the longer side).

    Returns:
        The average of the per-pair `meteor_score` values, or 0.0 when there
        is no pair to score.
    """
    # Each side is whitespace-tokenized; the reference is passed as a
    # one-element list of token lists.
    scores = [
        meteor_score([ref.split()], hyp.split())
        for ref, hyp in zip(references, hypotheses)
    ]
    if not scores:
        # Avoid dividing by zero on empty input.
        return 0.0
    return sum(scores) / len(scores)


def compute_bertscore(references, hypotheses):
    """Score `hypotheses` against `references` with a freshly built `BERTScorer`.

    The scorer is created with `lang="en"` and `rescale_with_baseline=True`,
    and is called as `score(hypotheses, references)`.

    Returns:
        Dict with the keys 'precision', 'recall' and 'f1', each holding the
        mean of the corresponding value returned by the scorer.
    """
    bert_scorer = BERTScorer(lang="en", rescale_with_baseline=True)
    precision, recall, f_measure = bert_scorer.score(hypotheses, references)
    named_scores = (
        ('precision', precision),
        ('recall', recall),
        ('f1', f_measure),
    )
    return {name: values.mean().item() for name, values in named_scores}


def compute_levenshtein_distance(s1: str, s2: str) -> int:
    """Compute the Levenshtein (edit) distance between two strings.

    Insertions, deletions and substitutions each cost 1.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The minimum number of single-character edits turning `s1` into `s2`,
        as a plain Python `int`.
    """
    # The distance is symmetric, so keep the shorter string on the inner axis
    # to make the rolling row as small as possible.
    if len(s2) > len(s1):
        s1, s2 = s2, s1

    # previous_row[j] is the distance between the processed prefix of s1 and
    # s2[:j]; only two rows are alive at any time.
    previous_row = list(range(len(s2) + 1))
    for i, char_1 in enumerate(s1, start=1):
        current_row = [i]
        for j, char_2 in enumerate(s2, start=1):
            cost = 0 if char_1 == char_2 else 1
            current_row.append(min(previous_row[j] + 1,          # Deletion
                                   current_row[j - 1] + 1,       # Insertion
                                   previous_row[j - 1] + cost))  # Substitution
        previous_row = current_row

    return previous_row[-1]


def construct_input(question):
    """Wrap `question` as a one-message chat list with the "user" role."""
    return [dict(role="user", content=question)]


def extract_hypotheses(dataset, idx):
    """Return the hypotheses and the reference stored at row `idx` of `dataset`.

    Two column layouts are handled:
      * 'source' / 'target': 'source' is one string that is split on '.', and
        every non-empty, stripped piece becomes a hypothesis.
      * 'input' / 'output': 'input' is returned unchanged as the hypotheses.

    Args:
        dataset: Object exposing `.features` (supports `in`) and row access by
            integer index that yields a mapping from column name to value.
        idx: Integer row index.

    Returns:
        Tuple `(hypotheses, references)` for the requested row.
    """
    # Read the single row once instead of going through dataset[column][idx],
    # which reads an entire column on every call.
    row = dataset[idx]
    if 'source' in dataset.features:
        hypotheses = [h.strip() for h in row['source'].split('.') if h.strip()]
        references = row['target']
    else:
        hypotheses = row['input']
        references = row['output']

    return hypotheses, references


def remove_punctuation(text):
    """Return `text` with every character listed in `string.punctuation` dropped."""
    kept_chars = [char for char in text if char not in string.punctuation]
    return "".join(kept_chars)


def clean_deepseek_output(text):
    """Remove every `<think>...</think>` span, plus the whitespace right after it."""
    # DOTALL lets the non-greedy body match across line breaks.
    think_span = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)
    return think_span.sub("", text)


def clean_asr_output(text):
    """Return the first line of `text` that has content, without surrounding whitespace.

    Args:
        text: Raw transcription text, possibly spanning several lines.

    Returns:
        The text up to the first line break that follows the first
        non-whitespace character, stripped. Empty or whitespace-only input
        yields "".
    """
    # Strip before splitting: otherwise text that starts with a line break
    # would produce an empty first segment and the whole output would be lost.
    first_line = text.strip().split("\n", 1)[0]
    return first_line.strip()


def save_results(dataset: Dataset, corrections: list, model_name: str, function_name: str, file_path: str):
    """Store `corrections` as a column of the results table at `file_path`.

    If `file_path` already exists it is loaded with `pd.read_json` and
    extended; otherwise the table starts from `dataset.to_pandas()`. The
    column is named `corrected_by_{model_name}_{function_name}` and is
    overwritten if it is already present.

    Args:
        dataset: Source dataset, used only when no results file exists yet.
        corrections: One value per row of the table.
        model_name: Model name embedded in the column name.
        function_name: Postprocessing function name embedded in the column name.
        file_path: Destination JSON file (written with `orient="records"`).
    """
    correction_column = f"corrected_by_{model_name}_{function_name}"
    if os.path.exists(file_path):
        results_df = pd.read_json(file_path)
    else:
        results_df = dataset.to_pandas()

    # Plain assignment both creates and overwrites the column.
    results_df[correction_column] = corrections

    # Create the destination directory so that already-computed predictions
    # are not lost to a missing folder.
    output_dir = os.path.dirname(file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    results_df.to_json(file_path, orient="records", indent=4)
    print(f"Results saved to {file_path}")
    
    
async def run_evaluation(dataset, model, client, generation_config, results_path, disable_zsun=False, disable_zsco=False, disable_zscl=False):
    """Evaluate the baselines and the enabled zero-shot methods, then save a summary table.

    The enabled zero-shot methods run first (unconstrained, constrained,
    closest), followed by the oracle and top-1 baselines. Every run goes
    through `evaluate_model_parallel`, which returns one metrics dict.

    Args:
        dataset: Dataset forwarded to `evaluate_model_parallel`.
        model: Model name; also embedded in the summary CSV file name.
        client: API client forwarded to `evaluate_model_parallel`.
        generation_config: Generation keyword arguments, forwarded unchanged.
        results_path: Path of the per-sample results file; the summary CSV
            path is derived from it.
        disable_zsun: Skip the zero-shot unconstrained run when true.
        disable_zsco: Skip the zero-shot constrained run when true.
        disable_zscl: Skip the zero-shot closest run when true.

    Returns:
        DataFrame indexed by method name with the columns WER, METEOR,
        BERT Precision, BERT Recall and BERT F1.
    """
    metrics_zero_shot_unconstrained = None
    metrics_zero_shot_constrained = None
    metrics_zero_shot_closest = None
    if not disable_zsun:
        print("Evaluating Zero-shot Unconstrained:")
        metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained, generation_config, results_path)

    if not disable_zsco:
        print("Evaluating Zero-shot Constrained:")
        metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained, generation_config, results_path)

    if not disable_zscl:
        print("Evaluating Zero-shot Closest:")
        metrics_zero_shot_closest = await evaluate_model_parallel(dataset, model, client, zero_shot_closest, generation_config, results_path)

    print("Evaluating Oracle:")
    metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis, generation_config, results_path)

    print("Evaluating Top 1:")
    metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis, generation_config, results_path)

    # Baselines come first in the table; skipped methods get no row.
    results_table = {
        "Top 1": metrics_get_top1_hypothesis,
        "Oracle": metrics_get_oracle_hypothesis,
    }

    if metrics_zero_shot_unconstrained is not None:
        results_table["Zero-shot Uncon"] = metrics_zero_shot_unconstrained
    if metrics_zero_shot_constrained is not None:
        results_table["Zero-shot Constr"] = metrics_zero_shot_constrained
    if metrics_zero_shot_closest is not None:
        results_table["Zero-shot Closest"] = metrics_zero_shot_closest

    results_table = pd.DataFrame.from_dict(results_table, orient='index')
    results_table = results_table[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]

    # Save as CSV next to the per-sample results file.
    # Only a trailing ".json" is replaced; for any other path the suffix is
    # appended, so the CSV can never overwrite the results file itself.
    base_path, extension = os.path.splitext(results_path)
    if extension.lower() != ".json":
        base_path = results_path
    # Path separators in the model name would point into a directory that
    # does not exist.
    safe_model = model.replace("/", "_").replace(os.sep, "_")
    csv_path = f"{base_path}_{safe_model}.csv"
    results_table.to_csv(csv_path)
    print(f"Benchmark saved to {csv_path}")
    return results_table

In [5]:
# RateLimitError is caught by name in call_openai_with_retry below.
from openai import RateLimitError
# NOTE(review): presumably patches asyncio so coroutines can be driven from
# inside the notebook's already-running event loop — confirm against the
# nest_asyncio documentation.
nest_asyncio.apply()

async def call_openai_with_retry(messages, model, generation_config, client):
    """Call `client.chat.completions.create`, retrying until it returns.

    Retry policy:
      * `RateLimitError`: wait for `detail.wait_seconds` taken from the JSON
        body of the error response when it can be read (otherwise the current
        backoff delay), plus 0.1 s, then retry. The backoff delay is not
        increased.
      * Exceptions exposing a `status_code` of 400, 401, 403 or 422: re-raised
        at once, because sending the same request again cannot succeed.
      * Any other exception: sleep for the backoff delay, then double it
        (capped at 10 s) and retry.

    Args:
        messages: Chat messages passed as `messages=`.
        model: Model name passed as `model=`.
        generation_config: Extra keyword arguments for the API call.
        client: Object exposing an awaitable `chat.completions.create`.

    Returns:
        Whatever `client.chat.completions.create` returns.

    Raises:
        The original exception for the non-retryable status codes above.
    """
    retry_delay = 0.1  # Initial delay in seconds
    max_delay = 10
    # Malformed, unauthenticated, forbidden and unprocessable requests.
    non_retryable_statuses = (400, 401, 403, 422)

    while True:
        try:
            # Attempt to make the API call
            return await client.chat.completions.create(
                model=model,
                messages=messages,
                **generation_config
            )

        except Exception as e:
            if getattr(e, "status_code", None) in non_retryable_statuses:
                raise

            if isinstance(e, RateLimitError):
                wait_time = retry_delay
                response = getattr(e, "response", None)
                if response is not None:
                    try:
                        wait_time = float(response.json()["detail"]["wait_seconds"])
                    except Exception:
                        # Body missing, not JSON, or shaped differently: keep
                        # the fallback delay.
                        pass
                await asyncio.sleep(wait_time + 0.1)
            else:
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_delay)  # Exponential backoff


async def get_prediction(client: openai.AsyncOpenAI, model: str, messages: List[dict], generation_config: dict) -> str:
    """Return the text of the first choice of a chat completion, or "" on failure.

    Args:
        client: Async OpenAI client used for the request.
        model: Model name.
        messages: Chat messages to send.
        generation_config: Extra keyword arguments for the API call.

    Returns:
        The content of the first choice. "" is returned when the response is
        falsy, when the content is missing, or when any exception is raised
        (the exception is printed, not propagated).
    """
    try:
        generation = await call_openai_with_retry(messages, model, generation_config, client)
        if not generation:
            return ""
        # The content field can be None; the evaluation code calls .lower()
        # on every prediction, so always hand back a string.
        return generation.choices[0].message.content or ""
    except Exception as e:
        print(f"Error: {e}")
        return ""


async def track_progress(tasks):
    """Poll `tasks` every 0.1 s, printing a progress line, then gather their results.

    Returns:
        The list produced by `asyncio.gather(*tasks)`, in task order.
    """
    n_total = len(tasks)
    while True:
        n_done = sum(1 for task in tasks if task.done())
        if n_done == n_total:
            break
        # "\r" keeps rewriting the same console line while work is pending.
        print(f"Progress: {n_done}/{n_total} tests completed!", end="\r")
        await asyncio.sleep(0.1)

    print(f"Progress: Batch of {n_total} tests completed!", flush=True)
    gathered = await asyncio.gather(*tasks)
    return gathered

    
async def process_batch(dataset: Dataset, indices: List[int], model: str, client: openai.AsyncOpenAI, postprocessing: Callable[[List[str]], str], 
    generation_config: dict) -> List[str]:
    """Run `postprocessing` concurrently on the rows of `dataset` listed in `indices`.

    Coroutine functions are called as
    `postprocessing(hypotheses, client, model, generation_config)`; plain
    functions are called as `postprocessing(hypotheses, reference)` in a
    worker thread. Results come back in the order of `indices`.
    """
    runs_as_coroutine = inspect.iscoroutinefunction(postprocessing)

    pending = []
    for row_idx in indices:
        nbest, target = extract_hypotheses(dataset, row_idx)
        if runs_as_coroutine:
            job = postprocessing(nbest, client, model, generation_config)
        else:
            job = asyncio.to_thread(postprocessing, nbest, target)
        pending.append(asyncio.create_task(job))

    return await track_progress(pending)
    
async def evaluate_model_parallel(dataset: Dataset, model: str, client: openai.AsyncOpenAI, postprocessing: Callable[[List[str]], str],
                            generation_config: dict, results_path: str, step: int=256):
    """Run `postprocessing` over `dataset` in batches and score the predictions.

    Predictions and references are lowercased, stripped of punctuation and
    reduced to their first line before scoring. The normalized predictions
    are also stored through `save_results`.

    Args:
        dataset: Dataset with a 'target' or an 'output' reference column.
        model: Model name; names containing 'DeepSeek' additionally get their
            `<think>` spans removed.
        client: Async OpenAI client forwarded to `process_batch`.
        postprocessing: Function producing one prediction per row.
        generation_config: Generation keyword arguments, forwarded unchanged.
        results_path: File that `save_results` writes to.
        step: Number of rows processed concurrently per batch.

    Returns:
        Dict with the keys 'WER', 'METEOR', 'BERT Precision', 'BERT Recall'
        and 'BERT F1', each rounded to 3 decimals.
    """
    total_rows = len(dataset)
    all_predictions = []

    for start in range(0, total_rows, step):
        end = min(start + step, total_rows)
        batch_indices = list(range(start, end))
        batch_predictions = await process_batch(dataset, batch_indices, model, client, postprocessing, generation_config)
        all_predictions.extend(batch_predictions)

    # Normalize for evaluation
    # A missing prediction (None) is scored as an empty string instead of
    # crashing on .lower() after the whole dataset has been processed.
    all_predictions = [pred if pred is not None else "" for pred in all_predictions]
    if 'DeepSeek' in model:
        all_predictions = [clean_deepseek_output(pred) for pred in all_predictions]
    all_predictions = [clean_asr_output(remove_punctuation(pred.lower())) for pred in all_predictions]

    all_references = dataset['target'] if 'target' in dataset.features else dataset['output']
    all_references = [clean_asr_output(remove_punctuation(ref.lower())) for ref in all_references]

    # Print up to 3 random results for manual review
    sample_size = min(3, len(all_predictions))
    random_indices = random.sample(range(len(all_predictions)), sample_size)
    print("-" * 100)
    for idx in random_indices:
        print(f"Sample {idx + 1}")
        print(f"Target: {all_references[idx]}")
        print(f"Pred:   {all_predictions[idx]}")
        print("-" * 100)

    save_results(dataset, all_predictions, model, postprocessing.__name__, results_path)

    # Compute evaluation metrics
    wer_scores = np.array([jiwer.wer(ref, pred) for ref, pred in zip(all_references, all_predictions)])
    # Both helpers are declared as (references, hypotheses); passing them the
    # other way round scores the references against the predictions.
    bertscore = compute_bertscore(all_references, all_predictions)
    metrics = {
        'WER': round(wer_scores.mean().item(), 3),
        'METEOR': round(compute_meteor(all_references, all_predictions), 3),
        'BERT Precision': round(bertscore['precision'], 3),
        'BERT Recall': round(bertscore['recall'], 3),
        'BERT F1': round(bertscore['f1'], 3),
    }
    return metrics

In [6]:
# Baselines
def get_oracle_hypothesis(hypotheses: List[str], reference: str) -> str:
    """Return the hypothesis with the lowest WER against `reference`."""

    per_hypothesis_wer = []
    for candidate in hypotheses:
        per_hypothesis_wer.append(jiwer.wer(reference, candidate))

    oracle_position = np.argmin(per_hypothesis_wer)
    return hypotheses[oracle_position]


def get_top1_hypothesis(hypotheses: List[str], reference: str) -> str:
    """Baseline that keeps the first entry of `hypotheses`; `reference` is accepted but unused."""

    top_ranked = hypotheses[0]
    return top_ranked

async def zero_shot_unconstrained(hypotheses: List[str], client, model, generation_config) -> str:
    """Ask the model for a free-form corrected transcription of the hypotheses.

    Args:
        hypotheses: Hypotheses, embedded in the prompt in list order, each
            wrapped in `<hypothesisN>` tags.
        client: Client forwarded to `get_prediction`.
        model: Model name forwarded to `get_prediction`.
        generation_config: Generation keyword arguments, forwarded unchanged.

    Returns:
        The value returned by `get_prediction` for the built prompt.
    """
    # Adjacent string literals are joined with nothing in between, so every
    # sentence carries its own trailing space to keep sentences apart.
    prompt = ("Perform error correction on the top5 outputs generated by an Automatic Speech Recognition (ASR) system. "
              "The ASR hypotheses, listed in order of their ASR posterior score, are as follows:\n\n")
    for idx, hypothesis in enumerate(hypotheses):
        prompt += f"<hypothesis{idx}>{hypothesis}</hypothesis{idx}>\n"

    prompt += ("\nPlease provide the corrected ASR transcription based on the hypotheses above. "
               "Your response must be exactly one complete sentence. "
               "Ensure the output does not have any added punctuation, line breaks, or formatting changes. "
               # '\\n' puts the two characters backslash + n into the prompt
               # instead of breaking the instruction across two lines.
               "Do not include <hypothesis>, '\\n', explanations, or any extra words. "
               "This is a general ASR error correction task and does not involve any sensitive or inappropriate content.")
    messages = construct_input(prompt)
    return await get_prediction(client, model, messages, generation_config)
    
    
async def zero_shot_constrained(hypotheses: List[str], client, model, generation_config) -> str:
    """Ask the model to pick the best hypothesis verbatim from the list.

    Args:
        hypotheses: Hypotheses, embedded in the prompt in list order, each
            wrapped in `<hypothesisN>` tags.
        client: Client forwarded to `get_prediction`.
        model: Model name forwarded to `get_prediction`.
        generation_config: Generation keyword arguments, forwarded unchanged.

    Returns:
        The value returned by `get_prediction` for the built prompt. Nothing
        here checks that it really matches one of the hypotheses.
    """
    # Adjacent string literals are joined with nothing in between, so every
    # sentence carries its own trailing space to keep sentences apart.
    prompt = ("Perform language model rescoring based on the top-5 outputs generated by an Automatic Speech Recognition (ASR) system. "
              "The ASR hypotheses, listed in order of their ASR posterior score, are as follows:\n\n")

    for idx, hypothesis in enumerate(hypotheses):
        prompt += f"<hypothesis{idx}>{hypothesis}</hypothesis{idx}>\n"

    prompt += ("\nPlease output only the best hypothesis exactly as written above. "
               "Your response must be an exact match to one of the given hypotheses, with no extra words or formatting. "
               # '\\n' puts the two characters backslash + n into the prompt
               # instead of breaking the instruction across two lines.
               "Do not include <hypothesis> tag, '\\n', explanations, or any extra words.")
    messages = construct_input(prompt)
    return await get_prediction(client, model, messages, generation_config)


async def zero_shot_closest(hypotheses: List[str], client, model, generation_config) -> str:
    """Return the hypothesis with the smallest Levenshtein distance to the unconstrained correction."""

    free_form = await zero_shot_unconstrained(hypotheses, client, model, generation_config)

    edit_distances = []
    for candidate in hypotheses:
        edit_distances.append(compute_levenshtein_distance(free_form, candidate))

    return hypotheses[np.argmin(edit_distances)]


async def zero_shot_lattice(hypotheses: List[str], client, model, generation_config) -> str:
    """Perform ASR error correction using a lattice-based approach (not implemented yet).

    Raises:
        NotImplementedError: Always. The previous stub returned None, which
            breaks the `-> str` contract that the evaluation code relies on.
    """
    # TO DO LATER
    raise NotImplementedError("zero_shot_lattice is not implemented yet")

In [7]:
# Keyword arguments for the chat-completion call, one preset per dataset.
# Only the max_tokens budget differs; temperature is 0.9 everywhere.
cv_generation_config = dict(max_tokens=25, temperature=0.9)
wsj_generation_config = dict(max_tokens=30, temperature=0.9)
swbd_generation_config = dict(max_tokens=65, temperature=0.9)
atis_generation_config = dict(max_tokens=45, temperature=0.9)
td3_generation_config = dict(max_tokens=130, temperature=0.9)
ls_clean_generation_config = dict(max_tokens=100, temperature=0.9)
ls_others_generation_config = dict(max_tokens=130, temperature=0.9)
lrs2_generation_config = dict(max_tokens=25, temperature=0.9)
chime4_generation_config = dict(max_tokens=30, temperature=0.9)


# Generic presets that are not tied to a dataset name.
small_generation_config = dict(max_tokens=20, temperature=0.9)
moderate_generation_config = dict(max_tokens=200, temperature=0.9)

In [9]:
model = "o3-mini"
client = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# If model is not yet available, try again after some delay.
output = None
while output is None:
    try:
        output = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Please introduce yourself."}],
        )
    
    except openai.APIError as e:
        print(e)
        sleep(10)

print(output.choices[0].message.content)

Hello! I'm ChatGPT, a language model developed by OpenAI. My purpose is to assist you by answering questions, offering explanations, and engaging in discussions on a wide range of topics. Whether you need help with information, creative ideas, or just a conversation, I'm here to help. How can I assist you today?


# GPT 4o

In [26]:
# Model name that the evaluation cells in this section pass to run_evaluation.
model = "gpt-4o"

In [8]:
# Evaluate the currently selected `model` on test_cv.json with cv_generation_config.
# NOTE(review): this cell's execution count (8) is lower than that of the
# `model = "gpt-4o"` cell above (26), so the saved output may come from a
# different `model` value — confirm by re-running the notebook top to bottom.
# NOTE(review): both paths are absolute and point into one user's home directory.
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_cv.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_cv.json"
dataset = Dataset.from_pandas(df)
print(dataset)
results_table = await run_evaluation(dataset, model, client, cv_generation_config, results_path)
print(results_table)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2000
})
Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 208 tests completed!
----------------------------------------------------------------------------------------------------
Sample 940
Target: lenin in the style of jackson pollock which references their well known painting
Pred:   learning in the style of jackson pollock which references their well known painting
----------------------------------------------------------------------------------------------------
Sample 457
Target: he was head custodian of sherborne old castle dorset and saint mawes in cornwall
Pred:   he was head custodian of sherbourne old castle dorset a

In [27]:
# Evaluate the currently selected `model` on test_lrs2.json with lrs2_generation_config.
# NOTE(review): both paths are absolute and point into one user's home directory.
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_lrs2.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_lrs2.json"
dataset = Dataset.from_pandas(df)
print(dataset)
results_table = await run_evaluation(dataset, model, client, lrs2_generation_config, results_path)
print(results_table)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2259
})
Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 211 tests completed!
----------------------------------------------------------------------------------------------------
Sample 1526
Target: we had some donations from america
Pred:   so we had some donations from america
----------------------------------------------------------------------------------------------------
Sample 480
Target: i can not believe it is to do with eyesight
Pred:   i just can not believe it is anything to do with the eyesight
---------------------------------------------------------------------------------

In [28]:
# Evaluate the currently selected `model` on test_chime4.json with chime4_generation_config.
# NOTE(review): both paths are absolute and point into one user's home directory.
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_chime4.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_chime4.json"
dataset = Dataset.from_pandas(df)
print(dataset)
results_table = await run_evaluation(dataset, model, client, chime4_generation_config, results_path)
print(results_table)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 1320
})
Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 40 tests completed!
----------------------------------------------------------------------------------------------------
Sample 612
Target: the company said its european banking affiliate safra republic plans to raise more than four hundred fifty million dollars through an international offering
Pred:   the company said its european banking affiliate safra republic plans to raise more than four hundred and fifty million dollars through an international offering
----------------------------------------------------------------------------------------------------
Sample 953
Target: closely held times publishing also owns two washington based publication

# GPT 4o mini

In [22]:
# Model name that the evaluation cells in this section pass to run_evaluation.
model = "gpt-4o-mini"

In [23]:
# Evaluate the currently selected `model` on test_cv.json with cv_generation_config.
# The per-sample results go to the same results_path as the other test_cv run;
# save_results stores them under a column name that includes the model name.
# NOTE(review): both paths are absolute and point into one user's home directory.
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_cv.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_cv.json"
dataset = Dataset.from_pandas(df)
print(dataset)
results_table = await run_evaluation(dataset, model, client, cv_generation_config, results_path)
print(results_table)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2000
})
Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 208 tests completed!
----------------------------------------------------------------------------------------------------
Sample 850
Target: the mainland south of the gulf may have substantial diamond and uranium deposits
Pred:   the mainland south of the gulf may have substantial diamond and uranium deposits
----------------------------------------------------------------------------------------------------
Sample 1124
Target: use of electronic test equipment is essential to any serious work on electronics systems
Pred:   the use of electronic test equipment is essential

In [24]:
# Evaluate the currently selected `model` on test_lrs2.json with lrs2_generation_config.
# NOTE(review): both paths are absolute and point into one user's home directory.
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_lrs2.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_lrs2.json"
dataset = Dataset.from_pandas(df)
print(dataset)
results_table = await run_evaluation(dataset, model, client, lrs2_generation_config, results_path)
print(results_table)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2259
})
Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 211 tests completed!
----------------------------------------------------------------------------------------------------
Sample 1799
Target: the motifs were chinese
Pred:   the motifs were chinese
----------------------------------------------------------------------------------------------------
Sample 1239
Target: they have still got them
Pred:   they have still got them
----------------------------------------------------------------------------------------------------
Sample 1742
Target: about the cities
Pred:   hear about the 

In [25]:
# Evaluate the currently selected `model` on test_chime4.json with chime4_generation_config.
# NOTE(review): both paths are absolute and point into one user's home directory.
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_chime4.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_chime4.json"
dataset = Dataset.from_pandas(df)
print(dataset)
results_table = await run_evaluation(dataset, model, client, chime4_generation_config, results_path)
print(results_table)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 1320
})
Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 40 tests completed!
----------------------------------------------------------------------------------------------------
Sample 281
Target: mister robertson says he would only be attracted by a nineteen multiple if he thought the projected earnings growth rate was eighteen percent to twenty percent
Pred:   mister robertson says he would only be attracted by a nineteen multiple if he thought the projected earnings growth rate was eighteen percent to twenty percent
----------------------------------------------------------------------------------------------------
Sample 875
Target: lately computer retailing has been tough on everybody
Pred:   lately c