### Environment Setup

In [35]:
import numpy as np
import os
import inspect
import random
from time import sleep
from typing import List, Callable
import jiwer
import openai
import pandas as pd
from datasets import Dataset
from sacrebleu import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from bert_score import BERTScorer
import evaluate
from tqdm.notebook import tqdm
import logging
from transformers import AutoTokenizer
import torch
import string
import time
import asyncio
import nest_asyncio
from dotenv import load_dotenv #Load the environment variables
load_dotenv()

import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')

import logging
logging.getLogger("transformers").setLevel(logging.ERROR)

[nltk_data] Downloading package wordnet to /h/omidv/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /h/omidv/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


### Helper Functions

In [36]:
def compute_wer(reference, hypothesis):
    return jiwer.wer(reference, hypothesis)


def compute_bleu(references, hypotheses):
    return corpus_bleu(hypotheses, references).score


def compute_meteor(references, hypotheses):
    scores = []
    for ref, hyp in zip(references, hypotheses):
        scores.append(meteor_score([ref.split()], hyp.split()))
    return sum(scores)/len(scores)


def compute_bertscore(references, hypotheses):
    scorer = BERTScorer(lang="en", rescale_with_baseline=True)
    p, r, f1 = scorer.score(hypotheses, references)
    bert_score = {'precision': p.mean().item(),
                  'recall': r.mean().item(),
                  'f1': f1.mean().item()}
    return bert_score


def compute_levenshtein_distance(s1: str, s2: str) -> int:
    """Compute the Levenshtein distance between two strings."""
    
    len_s1, len_s2 = len(s1), len(s2)
    dp = np.zeros((len_s1 + 1, len_s2 + 1), dtype=int)

    for i in range(len_s1 + 1):
        dp[i][0] = i
    for j in range(len_s2 + 1):
        dp[0][j] = j

    for i in range(1, len_s1 + 1):
        for j in range(1, len_s2 + 1):
            cost = 0 if s1[i - 1] == s2[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,      # Deletion
                           dp[i][j - 1] + 1,      # Insertion
                           dp[i - 1][j - 1] + cost)  # Substitution

    return dp[len_s1][len_s2]


def construct_input(question):
    prompt = [{"role": "user", "content": question}]
    return prompt


def extract_hypotheses(dataset, idx):
    if 'source' in dataset.features:
        hypotheses = [h.strip() for h in dataset['source'][idx].split('.') if h.strip()]
        references = dataset['target'][idx]
    else:
        hypotheses = dataset['input'][idx]
        references = dataset['output'][idx]
        
    return hypotheses, references


def remove_punctuation(text):
    return text.translate(str.maketrans("", "", string.punctuation))


def save_results(dataset: Dataset, corrections: list, model_name: str, function_name: str, file_path: str):
    correction_column = f"corrected_by_{model_name}_{function_name}"
    if os.path.exists(file_path):
        existing_df = pd.read_json(file_path)
    else:
        existing_df = dataset.to_pandas()

    if correction_column not in existing_df.columns:
        existing_df[correction_column] = None

    existing_df[correction_column] = corrections
    existing_df.to_json(file_path, orient="records", indent=4)
    print(f"Results saved to {file_path}")
    
    
async def run_evaluation(dataset, model, client, generation_config, results_path):
    print("Evaluating Zero-shot Unconstrained:")
    metrics_zero_shot_unconstrained = await evaluate_model_parallel(dataset, model, client, zero_shot_unconstrained, generation_config, results_path)
    
    print("Evaluating Zero-shot Constrained:")
    metrics_zero_shot_constrained = await evaluate_model_parallel(dataset, model, client, zero_shot_constrained, generation_config, results_path)
    
    print("Evaluating Zero-shot Closest:")
    metrics_zero_shot_closest = await evaluate_model_parallel(dataset, model, client, zero_shot_closest, generation_config, results_path)
    
    print("Evaluating Oracle:")
    metrics_get_oracle_hypothesis = await evaluate_model_parallel(dataset, model, client, get_oracle_hypothesis, generation_config, results_path)
    
    print("Evaluating Top 1:")
    metrics_get_top1_hypothesis = await evaluate_model_parallel(dataset, model, client, get_top1_hypothesis, generation_config, results_path)

    results_table = {
        "Top 1": metrics_get_top1_hypothesis,
        "Zero-shot Uncon": metrics_zero_shot_unconstrained,
        "Zero-shot Constr": metrics_zero_shot_constrained,
        "Zero-shot Closest": metrics_zero_shot_closest,
        "Oracle": metrics_get_oracle_hypothesis,
    }
    results_table = pd.DataFrame.from_dict(results_table, orient='index')
    results_table = results_table[['WER', 'METEOR', 'BERT Precision', 'BERT Recall', 'BERT F1']]

    # Save as JSON
    csv_path = results_path.replace(".json", f"_{model}.csv")
    results_table.to_csv(csv_path)
    print(f"Benchmark saved to {csv_path}")
    return results_table

### Asynchronous Evaluation

In [37]:
from openai import RateLimitError
nest_asyncio.apply()

async def call_openai_with_retry(messages, model, generation_config, client):
    """Handles API retries with exponential backoff."""
    
    retry_delay = 0.1  # Initial delay in seconds
    max_delay = 10
    while True:
        try:
            # Attempt to make the API call
            generation = await client.chat.completions.create(
                model=model,
                messages=messages,
                **generation_config
            )
            return generation

        except RateLimitError as e:
            wait_time = retry_delay
            if hasattr(e, "response") and e.response is not None:
                try:
                    error_data = e.response.json()
                    wait_time = float(error_data.get("detail", {}).get("wait_seconds", {}))
                except:
                    pass
            await asyncio.sleep(wait_time + 0.1)


        except Exception as e:
            await asyncio.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, max_delay)  # Exponential backoff


async def get_prediction(client: openai.AsyncOpenAI, model: str, messages: List[dict], generation_config: dict) -> str:
    """Asynchronously fetch predictions from OpenAI API."""
    
    try:
        generation = await call_openai_with_retry(messages, model, generation_config, client)
        return generation.choices[0].message.content if generation else ""
    except Exception as e:
        print(f"Error: {e}")
        return ""


async def track_progress(tasks):
    """Tracks progress while tasks are running."""
    
    total_tasks = len(tasks)
    while not all(task.done() for task in tasks):
        completed = sum(task.done() for task in tasks)
        print(f"Progress: {completed}/{total_tasks} tests completed!", end="\r")
        await asyncio.sleep(0.1)
    print(f"Progress: Batch of {total_tasks} tests completed!", flush=True)
    return await asyncio.gather(*tasks)

    
async def process_batch(dataset: Dataset, indices: List[int], model: str, client: openai.AsyncOpenAI, postprocessing: Callable[[List[str]], str], 
    generation_config: dict) -> List[str]:
    """Processes the dataset asynchronously using OpenAI API with progress tracking."""
    
    tasks = []
    for idx in indices:
        hypotheses, reference = extract_hypotheses(dataset, idx)
        
        if inspect.iscoroutinefunction(postprocessing):
            tasks.append(asyncio.create_task(postprocessing(hypotheses, client, model, generation_config)))
        else:
            tasks.append(asyncio.create_task(asyncio.to_thread(postprocessing, hypotheses, reference)))
    
    results = await track_progress(tasks)
    return results
    
async def evaluate_model_parallel(dataset: Dataset, model: str, client: openai.AsyncOpenAI, postprocessing: Callable[[List[str]], str],
                            generation_config: dict, results_path: str, step=256):
    """Evaluates the model asynchronously with progress tracking, handling Jupyter compatibility."""
    
    total_rows = len(dataset)
    all_predictions = []
    
    for start in range(0, total_rows, step):
        end = min(start + step, total_rows)
        batch_indices = list(range(start, end))
        batch_predictions = await process_batch(dataset, batch_indices, model, client, postprocessing, generation_config)
        all_predictions.extend(batch_predictions)

    all_references = dataset['target'] if 'target' in dataset.features else dataset['output']
    
    # Normalize for evaluation
    all_predictions = [remove_punctuation(pred.lower()) for pred in all_predictions] 
    all_references = [remove_punctuation(ref.lower()) for ref in all_references]

    # Print 3 random results for manual review
    random_indices = random.sample(range(len(all_predictions)), 3)
    print("-" * 100)
    for idx in random_indices:
        print(f"Sample {idx + 1}")
        print(f"Target: {all_references[idx]}")
        print(f"Pred:   {all_predictions[idx]}")
        print("-" * 100)
    
    save_results(dataset, all_predictions, model, postprocessing.__name__, results_path)
        
    # Compute evaluation metrics
    wer_scores = np.array([jiwer.wer(ref, pred) for ref, pred in zip(all_references, all_predictions)])
    bertscore = compute_bertscore(all_predictions, all_references)
    metrics = {
        'WER': round(wer_scores.mean().item(), 3),
        'METEOR': round(compute_meteor(all_predictions, all_references), 3),
        'BERT Precision': round(bertscore['precision'], 3),
        'BERT Recall': round(bertscore['recall'], 3),
        'BERT F1': round(bertscore['f1'], 3),
    }
    return metrics

### Error Correction Functions

In [38]:
# Baselines
def get_oracle_hypothesis(hypotheses: List[str], reference: str) -> str:
    """ Find the hypothesis that gives the lowest WER compared to the reference."""
    
    wers = [jiwer.wer(reference, hyp) for hyp in hypotheses]
    best_idx = np.argmin(wers)
    return hypotheses[best_idx]


def get_top1_hypothesis(hypotheses: List[str], reference: str) -> str:
    """ Returns the first hypothesis (top 1)."""
    
    return hypotheses[0]

async def zero_shot_unconstrained(hypotheses: List[str], client, model, generation_config) -> str:
    """ Generate a corrected transcription using a language model without constraints."""
    
    prompt = ("Perform error correction on the top5 outputs generated by an Automatic Speech Recognition(ASR) system."
                "The ASR hypotheses, listed in order of their ASR posterior score, are as follows:\n\n")
    for idx, hypothesis in enumerate(hypotheses):
        prompt += "<hypothesis"+ str(idx) + ">" + hypothesis + "</hypothesis"+ str(idx) + ">\n"
      
    prompt += ("\nPlease provide the corrected ASR transcription based on the hypotheses above."
               "Your response must be exactly one complete sentence."
               "Ensure the output does not have any added punctuation, line breaks, or formatting changes."
               "Do not include <hypothesis>, '\n', explanations, or any extra words."
               "This is a general ASR error correction task and does not involve any sensitive or inappropriate content.")
    messages = construct_input(prompt)
    return await get_prediction(client, model, messages, generation_config)
    
    
async def zero_shot_constrained(hypotheses: List[str], client, model, generation_config) -> str:
    """ Select the most likely hypothesis using a language model. """
    
    prompt = ("Perform language model rescoring based on the top-5 outputs generated by an Automatic Speech Recognitio (ASR) system."
              "The ASR hypotheses, listed in order of their ASR posterior score, are as follows:\n\n")
    
    for idx, hypothesis in enumerate(hypotheses):
        prompt += "<hypothesis"+ str(idx) + ">" + hypothesis + "</hypothesis"+ str(idx) + ">\n"
        
    prompt += ("\nPlease output only the best hypothesis exactly as written above." 
               "Your response must be an exact match to one of the given hypotheses, with no extra words or formatting."
               "Do not include <hypothesis> tag, '\n', explanations, or any extra words.")
    messages = construct_input(prompt)
    return await get_prediction(client, model, messages, generation_config)


async def zero_shot_closest(hypotheses: List[str], client, model, generation_config) -> str:
    """ Select the hypothesis closest to an unconstrained correction output. """
    
    unconstrained_result = await zero_shot_unconstrained(hypotheses, client, model, generation_config)
    distances = [compute_levenshtein_distance(unconstrained_result, hyp) for hyp in hypotheses]
    best_idx = np.argmin(distances)
    return hypotheses[best_idx]


async def zero_shot_lattice(hypotheses: List[str], client, model, generation_config) -> str:
    """ Perform ASR error correction using a lattice-based approach."""
    
    pass # TO DO LATER

### Specify Experiment Settings

In [39]:
model = "Meta-Llama-3.1-8B-Instruct"
client = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
small_generation_config = {"max_tokens": 20, "temperature": 0.9}
moderate_generation_config = {"max_tokens": 200, "temperature": 0.9}

# If model is not yet available, try again after some delay.
output = None
while output is None:
    try:
        output = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Please introduce yourself."}],
        )
    
    except openai.APIError as e:
        print(e)
        sleep(10)

print(output.choices[0].message.content)

I'm an artificial intelligence model known as Llama. Llama stands for "Large Language Model Meta AI."


# Common Voice Test Dataset

In [42]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_cv.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_cv.json"
cv_generation_config = {"max_tokens": 25, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2000
})


In [43]:
results_table = await run_evaluation(dataset, model, client, cv_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 208 tests completed!
----------------------------------------------------------------------------------------------------
Sample 454
Target: it can be constructed as the intersection of all balanced sets containing s
Pred:   it can be constructed as the intersection of all balanced sets containing s
----------------------------------------------------------------------------------------------------
Sample 325
Target: the most contentious point was regarding the control of the island of bornholm
Pred:   the more sensitive point was regarding the control of the island of elba
------------------------------------------------------------------------------------

# Wall Street Journal Test Dataset


In [44]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_wsj_score.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_wsj_score.json"
wsj_generation_config = {"max_tokens": 30, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['input', 'output', 'score'],
    num_rows: 836
})


In [45]:
results_table = await run_evaluation(dataset, model, client, wsj_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 68 tests completed!
----------------------------------------------------------------------------------------------------
Sample 587
Target: the launch had been planned for earlier this year but was scrubbed by the space agency about five times because of design and other delays
Pred:   the launch had been planned for earlier this year but was scrubbed by the space agency about five times because of design and other delays
----------------------------------------------------------------------------------------------------
Sample 385
Target: riches will come again to bimini
Pred:   riches will come again to bimini
----------------------------------------------------------------------------------------------------
Sample 574
Target: astronomers say that the earth is fate is sealed
Pred:   astronomers say that the ear

# SwitchBoard Test Dataset

In [46]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_swbd.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_swbd.json"
swbd_generation_config = {"max_tokens": 65, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2000
})


In [47]:
results_table = await run_evaluation(dataset, model, client, swbd_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 208 tests completed!
----------------------------------------------------------------------------------------------------
Sample 1468
Target: well the hobbies that i pursue in my spare time are crafts
Pred:   well the hobbies that i pursue in my spare time are crafts
----------------------------------------------------------------------------------------------------
Sample 950
Target: and i do not know have there been any good b books published on that i know world war two my dad was a paratrooper in airborne one oh hundred and one but i do not think they have done anything good on viet nam
Pred:   and i do not know have there been any good books published 

# ATIS Test Dataset

In [48]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_atis.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_atis.json"
atis_generation_config = {"max_tokens": 45, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['id', 'input', 'output', 'am_score'],
    num_rows: 809
})


In [49]:
results_table = await run_evaluation(dataset, model, client, atis_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 41 tests completed!
----------------------------------------------------------------------------------------------------
Sample 538
Target: phoenix till to milwaukee on sunday
Pred:   phoenix is to milwaukee on sunday
----------------------------------------------------------------------------------------------------
Sample 404
Target: what airline is a a
Pred:   what airline is american airlines
----------------------------------------------------------------------------------------------------
Sample 578
Target: list all sunday flights from cleveland to nashville and their fares
Pred:   list all sunday flights from cleveland to nashville and their fares
----------------------------------------------------------------------------------------------------
Results saved to /fs01/home/omidv/ASR-Error-Correction/resul

# Tedlium-3 Test Dataset

In [50]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_td3.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_td3.json"
td3_generation_config = {"max_tokens": 130, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 1155
})


In [51]:
results_table = await run_evaluation(dataset, model, client, td3_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 131 tests completed!
----------------------------------------------------------------------------------------------------
Sample 589
Target: than the one week vacation because there are no new memories added you have not changed the story
Pred:   than the one week vacation because there are no new memories added you have changed the story
----------------------------------------------------------------------------------------------------
Sample 905
Target: biggest wiki in the world second biggest wiki in the world with nearly eighty thousand articles is the world of warcraft wiki
Pred:   biggest wiki in the world second biggest wiki in the world with nearly eighty thousand articles is the world of warcraft wiki
----------------------------------------------------------------

# Librispeech Clean Test Dataset

In [52]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_ls_clean.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_ls_clean.json"
ls_clean_generation_config = {"max_tokens": 100, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 2620
})


In [53]:
results_table = await run_evaluation(dataset, model, client, ls_clean_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 60 tests completed!
----------------------------------------------------------------------------------------------------
Sample 914
Target: i will briefly describe them to you and you shall read the account of them at your leisure in the sacred registers
Pred:   i will briefly describe them to you and you shall read the account of them at your leisure in the sacred registers
----------------------------------------------------------------------------------------------------
Sample 192
Target: we have spoken of pearls rich and luxuriant bea

# Librispeech Others Test Dataset

In [54]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_ls_other.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_ls_other.json"
ls_others_generation_config = {"max_tokens": 130, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 2939
})


In [55]:
results_table = await run_evaluation(dataset, model, client, ls_others_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 123 tests completed!
----------------------------------------------------------------------------------------------------
Sample 1510
Target: mister wicker waited patiently beside him for a few moments for chris to get up his courage
Pred:   mister wicker waited patiently beside him for a few moments for chris to get up his courage
----------------------------------------------------------------------------------------------------
Sample 2131
Target: it wouldnt do you know after that story came out f

Benchmark saved to /fs01/home/omidv/ASR-Error-Correction/results/test_ls_other_Meta-Llama-3.1-8B-Instruct.csv
                     WER  METEOR  BERT Precision  BERT Recall  BERT F1
Top 1              0.045   0.958           0.933        0.935    0.934
Zero-shot Uncon    0.091   0.932           0.896        0.902    0.899
Zero-shot Constr   0.070   0.940           0.909        0.915    0.912
Zero-shot Closest  0.066   0.944           0.912        0.919    0.916
Oracle             0.027   0.973           0.953        0.956    0.955


# LRS2 Clean Test Dataset

In [56]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_lrs2.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_lrs2.json"
lrs2_generation_config = {"max_tokens": 25, "temperature": 0.9}
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2259
})


In [57]:
results_table = await run_evaluation(dataset, model, client, lrs2_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 211 tests completed!
----------------------------------------------------------------------------------------------------
Sample 362
Target: let us try reforming
Pred:   so let us try reforming
----------------------------------------------------------------------------------------------------
Sample 1305
Target: lovely little picture
Pred:   a lovely colorful picture
----------------------------------------------------------------------------------------------------
Sample 1552
Target: along with a baronetcy by no less a personage than king george iii
Pred:   along with a baronetcy by no lesser personage than king ge

# Gemma2 9B

In [12]:
model="gemma-2-9b-it"
client = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

output = None
while output is None:
    try:
        output = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Please introduce yourself."}],
        )
    
    except openai.APIError as e:
        print(e)
        sleep(10)

print(output.choices[0].message.content)

Hello! I am Gemma, an open-weights AI assistant developed by the Gemma team at Google DeepMind.

I'm a large language model, which means I'm trained on a massive amount of text data. This allows me to communicate and generate human-like text in response to a wide range of prompts and questions. For example, I can provide summaries of factual topics, create stories, translate languages, and help you brainstorm ideas.

Being open-weights means my weights are publicly accessible. This allows anyone to study, modify, or build upon me, fostering transparency and collaboration in the AI community.

What can I do for you today?



In [9]:
df = pd.read_json("/fs01/home/omidv/ASR-Error-Correction/data/test_cv.json")
results_path = "/fs01/home/omidv/ASR-Error-Correction/results/test_cv.json"
dataset = Dataset.from_pandas(df)
print(dataset)

Dataset({
    features: ['input', 'output', 'input1', 'input2'],
    num_rows: 2000
})


In [40]:
results_table = await run_evaluation(dataset, model, client, small_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 208 tests completed!
----------------------------------------------------------------------------------------------------
Sample 84
Target: a boy is riding down a bounce slide
Pred:   the boy is riding down a bounce slide  

----------------------------------------------------------------------------------------------------
Sample 1501
Target: both william and marguerite were heavily influenced by cubism and fauvism
Pred:   both william and margaret were heavily influenced by cubism and formalism 

----------------------------------------------------------------------------------------------------
Sample 1621
Target: lava usually leaves the point of eruptio

# Llama 70B

In [None]:
model = "Meta-Llama-3.1-70B-Instruct"
client = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

output = None
while output is None:
    try:
        output = await client.chat.completions.create(
            model=model,
            messages=[{"rolve": "user", "content": "Please introduce yourself."}],
        )
    
    except openai.APIError as e:
        print(e)
        sleep(10)

print(output.choices[0].message.content)

In [None]:
results_table = await run_evaluation(dataset, model, client, small_generation_config, results_path)
print(results_table)

# Playground

In [None]:
async def CoT_task_activating(hypotheses: List[str], client, model, generation_config) -> str:
    """ Perform ASR error correction using Chain-of-Thought (CoT) reasoning."""
    
    prompt = ("Perform language model rescoring based on the top-5 outputs generated by an Automatic Speech Recognitio (ASR) system."
              "The ASR hypotheses, listed in order of their ASR posterior score, are as follows:\n\n")
    
    for idx, hypothesis in enumerate(hypotheses):
        prompt += "<hypothesis"+ str(idx) + ">" + hypothesis + "</hypothesis"+ str(idx) + ">\n"
        
    prompt += ("")
    messages = construct_input(prompt)
    return await get_prediction(client, model, messages, generation_config)

In [None]:
model="gemma-2-9b-it"
custom_generation_config, = {"max_tokens": 200, "temperature": 0.9}
messages = [
    {
        "role": "user",
        "content": "Do you know speech recognition?"
    },
    {
        "role": "assistant",
        "content": "As a large language model, I am trained on a massive dataset of text and code. This allows me to understand and generate human-like text, but I don't directly process audio input.\n\n**Speech recognition**, also known as automatic speech recognition (ASR), is a separate field of AI that deals with converting spoken language into text. It involves several steps:\n\n1. **Audio Acquisition:** Capturing the spoken audio.\n2. **Feature Extraction:** Converting the audio signal into numerical features that represent the speech sounds.\n3. **Acoustic Modeling:** Using statistical models to predict the probability of different speech units (phonemes or words) based on the extracted features.\n4. **Language Modeling:** Using statistical models to predict the probability of different word sequences, considering the context of the spoken words.\n5. **Decoding:** Combining the acoustic and language model outputs to find the most likely sequence of words.\n"
    },
    {
        "role": "user",
        "content": "Do you know language model for speech recognition?"
    },
    {
        "role": "assistant",
        "content": "You're getting into some exciting territory! \n\nWhile I'm a text-only model, I can tell you about language models used in speech recognition.  \n\nTraditionally, speech recognition heavily relied on acoustic models (like Hidden Markov Models) and pronunciation dictionaries.  \n\nHowever, the rise of Transformer-based language models has revolutionized the field. These models, trained on massive text and sometimes audio datasets, can learn complex relationships between sounds and words, leading to significant improvements in accuracy and natural language understanding.\n\nHere are some examples of language models used in speech recognition:\n\n* **Whisper (OpenAI):** A powerful, open-source model known for its impressive performance in various languages and noise conditions.\n* ** wav2vec 2.0 (Facebook AI):** Another open-source model that excels in unsupervised learning from raw audio data.\n* **FastSpeech 2 (Microsoft):** Focuses on fast and natural-sounding speech synthesis. \n"
    },
    {
        "role": "user",
        "content": "Could you give a possible example of language model rescoring with hypotheses?"
    },
    {
        "role": "assistant",
        "content": ""
    },
    {
        "role": "user",
        "content": "Nice job, I will provide some examples as a demonstration from []. The 5-best hypothesis is:[], and I would expect your output is: []. Following this example, could you report the true transcription from the following 5-best hypotheses? []"
    }
]

output = await client.chat.completions.create(
            model=model,
            messages=messages,
            **custom_generation_config,
)
print(output)

# Mistral 7B

In [41]:
model = "Mistral-7B-Instruct-v0.3"
client = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

output = None
while output is None:
    try:
        output = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Please introduce yourself."}],
        )
    
    except openai.APIError as e:
        print(e)
        sleep(10)

print(output.choices[0].message.content)

 Hello! I am a model trained by Mistral AI, designed to assist with a wide range of tasks. I can help answer questions, provide information, generate creative content, and perform various other functions to make your life easier. I am here 24/7 to serve you, so feel free to ask me anything at any time. I'm always learning and improving to better meet your needs. Enjoy our interaction!


In [42]:
results_table = await run_evaluation(dataset, model, client, small_generation_config, results_path)
print(results_table)

Evaluating Zero-shot Unconstrained:
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 256 tests completed!
Progress: Batch of 208 tests completed!
----------------------------------------------------------------------------------------------------
Sample 1056
Target: he studied at the university of athens and at the university of chicago
Pred:    he studied at the university of athens and the university of chicago
----------------------------------------------------------------------------------------------------
Sample 1019
Target: surprise is part one of a two part story
Pred:    the surprise is part one of a two part story
----------------------------------------------------------------------------------------------------
Sample 996
Target: i was held in captivity and the nazis