[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sithtsar/ASR-Noun-Enhancement/blob/main/notebooks/notebook.ipynb)

# ASR Noun Enhancement Assignment
Implementation of spell correction for ASR noun enhancement using baseline and advanced models.

This notebook covers:
- Data loading and preprocessing
- Exploratory Data Analysis (EDA) with visualizations
- Error extraction and categorization
- Baseline model (error dictionary lookup + `difflib` fuzzy matching against the trigram-model vocabulary)
- Advanced model (T5 fine-tuning)
- Evaluation and comparison

In [None]:
# Install dependencies (for Colab)
!pip install pandas numpy scikit-learn nltk transformers torch datasets sentencepiece matplotlib spacy
!python -m spacy download en_core_web_sm
import nltk

# Tokenizer and POS-tagger data needed by word_tokenize / pos_tag.
# Recent NLTK releases look up the "punkt_tab" and
# "averaged_perceptron_tagger_eng" packages instead of the older pickled
# ones, so both generations are requested. A package id unknown to the
# installed NLTK version is reported by the downloader and skipped.
for resource in (
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
):
    nltk.download(resource)

In [None]:
import pandas as pd
from pathlib import Path
import json
import re
import spacy
import nltk
from nltk import pos_tag, word_tokenize
from collections import Counter
import matplotlib.pyplot as plt
from difflib import SequenceMatcher
from transformers import T5Tokenizer, T5ForConditionalGeneration
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch

# Set device: use the GPU when torch reports CUDA as available, else the CPU.
# The T5 model and its tokenized inputs are moved onto this device later on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. Data Loading and Preprocessing

In [None]:
def load_df(data_path: str, extension_pattern: str) -> pd.DataFrame:
    """Load a spreadsheet from a directory into a DataFrame.

    Args:
        data_path: Directory to search, joined onto the current working
            directory (an absolute path is used as-is).
        extension_pattern: File extension without the leading dot,
            e.g. ``"xlsx"``; it is used to build the ``*.<ext>`` glob.

    Returns:
        The contents of the alphabetically first matching file, as read by
        ``pandas.read_excel``.

    Raises:
        FileNotFoundError: If no file in the directory matches the pattern.
    """
    data_dir = Path.cwd() / data_path
    pattern = f"*.{extension_pattern}"
    # glob() yields files in arbitrary order; sort so the chosen file is
    # the same on every run and every platform.
    matches = sorted(data_dir.glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f"No files matching '{pattern}' found in {data_dir}"
        )
    return pd.read_excel(matches[0])

# Read the .xlsx workbook that load_df picks from ./data, then preview it.
df = load_df("data", "xlsx")
print("Dataset loaded:")
print(df.head())
print(f"Total samples: {len(df)}")

In [None]:
def clean_text(text: str) -> str:
    """Normalise a sentence for word-level comparison.

    The text is lower-cased, every character that is neither a word
    character nor whitespace is removed, and each run of whitespace is
    collapsed to a single space with no leading or trailing whitespace.
    """
    without_punctuation = re.sub(r'[^\w\s]', '', text.lower())
    return re.sub(r'\s+', ' ', without_punctuation).strip()

def load_spacy_model():
    """Return the spaCy ``en_core_web_sm`` pipeline.

    If loading raises ``OSError``, the model package is downloaded through
    ``spacy.cli.download`` and the load is attempted a second time.
    """
    model_name = "en_core_web_sm"
    try:
        return spacy.load(model_name)
    except OSError:
        # First load failed: fetch the package, then retry below.
        spacy.cli.download(model_name)
    return spacy.load(model_name)

# Load the pipeline once at module level; extract_nouns_ner reads this global.
nlp = load_spacy_model()

def extract_nouns_ner(text: str) -> list:
    """Return the surface text of selected named entities found in ``text``.

    The sentence is run through the module-level ``nlp`` pipeline and only
    entities labelled PERSON, ORG, GPE or PRODUCT are kept, in the order
    the pipeline reports them.
    """
    wanted_labels = ('PERSON', 'ORG', 'GPE', 'PRODUCT')
    entities = []
    for ent in nlp(text).ents:
        if ent.label_ in wanted_labels:
            entities.append(ent.text)
    return entities

def extract_nouns_pos(text: str) -> list:
    """Return the tokens of ``text`` that NLTK tags as nouns.

    The text is tokenized with ``word_tokenize`` and tagged with
    ``pos_tag``; tokens tagged NN, NNS, NNP or NNPS are returned in
    sentence order (duplicates are kept).
    """
    noun_tags = ('NN', 'NNS', 'NNP', 'NNPS')
    nouns = []
    for token, tag in pos_tag(word_tokenize(text)):
        if tag in noun_tags:
            nouns.append(token)
    return nouns

# Apply preprocessing
# clean_* hold the normalised sentences produced by clean_text
# (lower-cased, punctuation removed, whitespace collapsed).
df['clean_correct'] = df['correct sentences'].apply(clean_text)
df['clean_incorrect'] = df['ASR-generated incorrect transcriptions'].apply(clean_text)
# Noun extraction runs on the raw sentence columns, not the clean_* ones,
# so both extractors see the original casing and punctuation.
df['nouns_correct_ner'] = df['correct sentences'].apply(extract_nouns_ner)
df['nouns_incorrect_ner'] = df['ASR-generated incorrect transcriptions'].apply(extract_nouns_ner)
df['nouns_correct_pos'] = df['correct sentences'].apply(extract_nouns_pos)
df['nouns_incorrect_pos'] = df['ASR-generated incorrect transcriptions'].apply(extract_nouns_pos)

print("Preprocessing completed.")

## 2. Exploratory Data Analysis (EDA)

In [None]:
# Basic stats
# Sentence length = number of whitespace-separated tokens in the cleaned text.
df['len_correct'] = df['clean_correct'].apply(lambda x: len(x.split()))
df['len_incorrect'] = df['clean_incorrect'].apply(lambda x: len(x.split()))

stats = {
    'total_samples': len(df),
    'avg_len_correct': df['len_correct'].mean(),
    'avg_len_incorrect': df['len_incorrect'].mean(),
    # Vocabulary size = number of distinct tokens across all cleaned sentences.
    'vocab_size_correct': len(set(' '.join(df['clean_correct']).split())),
    'vocab_size_incorrect': len(set(' '.join(df['clean_incorrect']).split())),
    # Average number of extracted nouns per reference sentence, per extractor.
    'avg_nouns_ner': df['nouns_correct_ner'].apply(len).mean(),
    'avg_nouns_pos': df['nouns_correct_pos'].apply(len).mean()
}

print("Stats:", json.dumps(stats, indent=2))

In [None]:
# Plots
# Overlaid histograms of token counts for the reference sentences and the
# ASR transcriptions; alpha=0.5 keeps both distributions visible.
plt.figure(figsize=(10, 6))
plt.hist(df['len_correct'], alpha=0.5, label='Correct')
plt.hist(df['len_incorrect'], alpha=0.5, label='Incorrect')
plt.xlabel('Sentence Length')
plt.ylabel('Frequency')
plt.title('Sentence Length Distribution')
plt.legend()
plt.show()

# Noun count vs length
# One point per reference sentence: x = token count, y = POS-tagged noun count.
plt.scatter(df['len_correct'], df['nouns_correct_pos'].apply(len))
plt.xlabel('Sentence Length')
plt.ylabel('Noun Count')
plt.title('Nouns vs Sentence Length')
plt.show()

## 2.5 Additional EDA Plots

In [None]:
# Load processed data for additional plots
import json
from collections import Counter
import ast

# Assuming processed data exists; if not, run the scripts first
# NOTE: this rebinds the module-level `stats` name to the contents of
# stats.json (or to None when the processed files are missing).
try:
    processed_df = pd.read_csv('data/processed/augmented_df.csv')
    with open('data/processed/stats.json', 'r') as f:
        stats = json.load(f)
    print("Loaded processed data.")
except FileNotFoundError:
    print("Processed data not found. Run preprocessing scripts first.")
    processed_df = None
    stats = None


def _list_cell_len(cell):
    """Return the number of items in a list-valued cell read back from CSV.

    A CSV round-trip stores list columns as their string representation,
    so a string cell is parsed with ``ast.literal_eval`` before it is
    measured. Cells that are not lists or tuples (e.g. missing values)
    count as 0.
    """
    if isinstance(cell, str):
        cell = ast.literal_eval(cell)
    return len(cell) if isinstance(cell, (list, tuple)) else 0


if processed_df is not None:
    # Error types plot
    # error_categories is stored as text in the CSV; parse each cell back
    # into a Python object before counting the category labels.
    counters = processed_df['error_categories'].dropna().apply(ast.literal_eval)
    total = Counter()
    for c in counters:
        total.update(c)
    plt.figure(figsize=(10, 6))
    plt.bar(list(total.keys()), list(total.values()))
    plt.xlabel('Error Type')
    plt.ylabel('Frequency')
    plt.title('Frequency of Error Types')
    plt.show()

    # Error distribution
    # One bin per integer error count, from 0 up to the observed maximum.
    plt.figure(figsize=(10, 6))
    plt.hist(processed_df['num_errors'], bins=range(0, processed_df['num_errors'].max()+2), alpha=0.7)
    plt.xlabel('Number of Errors per Sentence')
    plt.ylabel('Frequency')
    plt.title('Distribution of Errors per Sentence')
    plt.show()

    # Vocab stats pie
    if stats:
        labels = ['Correct Vocab', 'Incorrect Vocab', 'Medical Correct', 'Medical Incorrect']
        sizes = [stats['vocab_size_correct'], stats['vocab_size_incorrect'], stats['medical_terms_correct'], stats['medical_terms_incorrect']]
        plt.figure(figsize=(8, 8))
        plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=140)
        plt.title('Vocabulary Coverage Statistics')
        plt.show()

    # Length impact
    plt.figure(figsize=(10, 6))
    plt.scatter(processed_df['len_correct'], processed_df['num_errors'])
    plt.xlabel('Sentence Length (words)')
    plt.ylabel('Number of Errors')
    plt.title('Context Length Impact on Correction Errors')
    plt.show()

    # Noun error distribution
    # Taking len() of the raw CSV cell would measure the length of the
    # stringified list, so the cell is parsed before counting the nouns.
    noun_counts = processed_df['nouns_correct_pos'].apply(_list_cell_len)
    plt.figure(figsize=(10, 6))
    plt.scatter(noun_counts, processed_df['num_errors'])
    plt.xlabel('Number of Nouns in Sentence')
    plt.ylabel('Number of Errors')
    plt.title('Noun Count vs Errors')
    plt.show()

## 3. Error Extraction and Categorization

In [None]:
from difflib import SequenceMatcher

def get_differing_words(correct, incorrect):
    """Pair up the words that differ between a reference and a transcription.

    Both sentences are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. Every region where a run of reference
    words was replaced by a run of transcription words contributes
    word pairs.

    Args:
        correct: Reference sentence.
        incorrect: ASR transcription of the same sentence.

    Returns:
        A list of ``(correct_word, incorrect_word)`` tuples in sentence
        order. When the two runs of a replaced region differ in length,
        only the last ``min(len)`` words of each run are paired, aligned
        from the right. Pure insertions and pure deletions produce no
        pairs.
    """
    correct_words = correct.split()
    incorrect_words = incorrect.split()
    matcher = SequenceMatcher(None, correct_words, incorrect_words)

    errors = []
    for tag, c_start, c_end, i_start, i_end in matcher.get_opcodes():
        if tag != 'replace':
            # 'equal' has no error; 'insert'/'delete' have no counterpart
            # word on the other side to pair with.
            continue
        # Stay inside the mismatched region: never reach back into words
        # that already matched, and never wrap to a negative index.
        paired = min(c_end - c_start, i_end - i_start)
        errors.extend(
            zip(correct_words[c_end - paired:c_end],
                incorrect_words[i_end - paired:i_end])
        )
    return errors

def categorize_error(incorrect, correct):
    """Assign a coarse error category to an (incorrect, correct) text pair.

    The two strings are compared character by character with
    ``SequenceMatcher``. Rules are tried in this order:

    * either string empty                           -> 'segmentation'
    * at most 2 characters fall outside the matches -> 'character'
    * different whitespace-separated word counts    -> 'segmentation'
    * similarity ratio above 0.6                    -> 'phonetic'
    * anything else                                 -> 'word'
    """
    if not (incorrect and correct):
        return 'segmentation'

    matcher = SequenceMatcher(None, incorrect, correct)
    shared_chars = sum(block.size for block in matcher.get_matching_blocks())
    # Characters of either string that are not covered by a matching block.
    unmatched_chars = len(incorrect) + len(correct) - 2 * shared_chars

    if unmatched_chars <= 2:
        return 'character'
    if len(incorrect.split()) != len(correct.split()):
        return 'segmentation'
    if matcher.ratio() > 0.6:
        return 'phonetic'
    return 'word'

# Apply
# error_words holds (correct_word, incorrect_word) pairs per sentence.
df['error_words'] = df.apply(lambda row: get_differing_words(row['clean_correct'], row['clean_incorrect']), axis=1)
# categorize_error takes (incorrect, correct), so unpack each pair as
# (cor, inc) and pass the two words in the order the function expects.
df['error_categories'] = df['error_words'].apply(lambda x: [categorize_error(inc, cor) for cor, inc in x if inc and cor])

# Error DB
# Map each misrecognised word to the correction it was most often paired with.
# NOTE(review): the DB is built from every row of df, including rows that
# are later placed in the test split.
from collections import defaultdict
error_db = defaultdict(list)
for errors in df['error_words']:
    for cor, inc in errors:
        if inc and cor:
            error_db[inc].append(cor)
final_db = {inc: Counter(cors).most_common(1)[0][0] for inc, cors in error_db.items()}

print("Error DB built with", len(final_db), "entries.")

## 4. Baseline Model

In [None]:
from difflib import get_close_matches

def build_trigram_model(sentences):
    """Count word trigrams over an iterable of sentences.

    Each sentence is tokenized with ``word_tokenize``. The result is a
    three-level nested dict, ``model[first][second][third]``, holding how
    many times that token triple occurred consecutively.
    """
    counts = {}
    for sentence in sentences:
        tokens = word_tokenize(sentence)
        # Slide a window of three consecutive tokens over the sentence.
        for first, second, third in zip(tokens, tokens[1:], tokens[2:]):
            followers = counts.setdefault(first, {}).setdefault(second, {})
            followers[third] = followers.get(third, 0) + 1
    return counts

# Trigram counts over the cleaned reference sentences of the whole dataset.
trigram_model = build_trigram_model(df['clean_correct'])

def _trigram_vocabulary(trigram_model):
    """Collect every token stored at any of the three levels of the model."""
    vocabulary = set()
    for first, seconds in trigram_model.items():
        vocabulary.add(first)
        for second, thirds in seconds.items():
            vocabulary.add(second)
            vocabulary.update(thirds)
    return vocabulary


def correct_sentence(sentence, error_db, trigram_model):
    """Correct a sentence word by word.

    Args:
        sentence: Whitespace-separated input sentence.
        error_db: Mapping from a misrecognised word to its replacement.
        trigram_model: Nested ``{first: {second: {third: count}}}`` dict;
            its tokens serve as the vocabulary of known words.

    Returns:
        The corrected sentence, words joined by single spaces. For each
        word: an ``error_db`` entry wins; a word already in the vocabulary
        is kept; otherwise the closest vocabulary word according to
        ``difflib.get_close_matches`` (default cutoff) is used, and the
        word is left unchanged when nothing is close enough.
    """
    # Use tokens from all three levels. The top-level keys alone miss every
    # word that only occurs in the last two positions of a sentence, which
    # would cause such known words to be replaced by a near neighbour.
    vocabulary = _trigram_vocabulary(trigram_model)
    candidates_pool = sorted(vocabulary)  # built once, not once per word

    corrected = []
    for word in sentence.split():
        if word in error_db:
            corrected.append(error_db[word])
        elif word in vocabulary:
            corrected.append(word)
        else:
            candidates = get_close_matches(word, candidates_pool, n=1)
            corrected.append(candidates[0] if candidates else word)
    return ' '.join(corrected)

# Test on sample
# Run the baseline on the first transcription and show it next to its reference.
sample = df['clean_incorrect'].iloc[0]
corrected = correct_sentence(sample, final_db, trigram_model)
print(f"Original: {sample}")
print(f"Corrected: {corrected}")
print(f"Ground truth: {df['clean_correct'].iloc[0]}")

## 5. Advanced Model (T5)

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset
from sklearn.model_selection import train_test_split

# Prepare data
# 80/20 split with a fixed random_state so the held-out rows are reproducible.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# Model input is the cleaned ASR transcription; the target is the cleaned reference.
train_dataset = Dataset.from_dict({'input_text': train_df['clean_incorrect'].tolist(), 'target_text': train_df['clean_correct'].tolist()})

# Pretrained t5-small tokenizer and model; the model is moved to the selected device.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

def tokenize_function(examples):
    """Tokenize input/target text and attach loss-ready labels.

    Args:
        examples: Mapping with ``'input_text'`` and ``'target_text'``;
            either a batch (lists of strings, as passed by
            ``Dataset.map(..., batched=True)``) or a single example.

    Returns:
        The tokenized inputs (padded/truncated to 128 tokens) with an
        extra ``'labels'`` entry holding the target token ids, where every
        padding position is replaced by -100.
    """
    inputs = tokenizer(examples['input_text'], max_length=128, truncation=True, padding='max_length')
    targets = tokenizer(examples['target_text'], max_length=128, truncation=True, padding='max_length')

    pad_id = tokenizer.pad_token_id

    def _mask_padding(token_ids):
        # -100 is the index the loss ignores; without it the model is
        # also trained to predict the padding that fills each target.
        return [tok if tok != pad_id else -100 for tok in token_ids]

    label_ids = targets['input_ids']
    if label_ids and isinstance(label_ids[0], list):
        inputs['labels'] = [_mask_padding(ids) for ids in label_ids]
    else:
        inputs['labels'] = _mask_padding(label_ids)
    return inputs

# Tokenize the whole training set in batches.
train_dataset = train_dataset.map(tokenize_function, batched=True)

# Training (quick for demo)
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,  # a single pass over the training data
    per_device_train_batch_size=8,
    save_steps=500,
)

# No eval_dataset is passed to the Trainer, only the training set.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()

print("T5 training completed.")

## 6. Evaluation

In [None]:
def evaluate_model(df, model_func):
    """Score a correction function against the reference sentences.

    Args:
        df: DataFrame with ``'clean_incorrect'`` (model input) and
            ``'clean_correct'`` (reference) columns.
        model_func: Callable mapping one input sentence to its correction.

    Returns:
        A tuple ``(accuracy, bleu, noun_acc)``:

        * accuracy - fraction of predictions equal to the reference;
        * bleu - mean smoothed sentence-level BLEU;
        * noun_acc - mean fraction of the reference's POS-tagged nouns
          that also appear in the prediction. A reference without nouns
          contributes 0.

        All three are ``0.0`` when ``df`` has no rows.
    """
    total = len(df)
    if total == 0:
        # Nothing to score; every metric below divides by the row count.
        return 0.0, 0.0, 0.0

    references = df['clean_correct'].tolist()
    predictions = [model_func(sentence) for sentence in df['clean_incorrect']]

    # Accuracy
    accuracy = sum(p == r for p, r in zip(predictions, references)) / total

    # BLEU
    smoothie = SmoothingFunction().method4
    bleu = sum(
        sentence_bleu([r.split()], p.split(), smoothing_function=smoothie)
        for p, r in zip(predictions, references)
    ) / total

    # Noun accuracy
    noun_total = 0.0
    for reference, prediction in zip(references, predictions):
        reference_nouns = set(extract_nouns_pos(reference))
        if reference_nouns:
            predicted_nouns = set(extract_nouns_pos(prediction))
            noun_total += len(reference_nouns & predicted_nouns) / len(reference_nouns)
    noun_acc = noun_total / total

    return accuracy, bleu, noun_acc

# Baseline eval
# NOTE(review): final_db and trigram_model were built from every row of df,
# which includes the rows in test_df, so these scores are not measured on
# unseen data.
baseline_func = lambda x: correct_sentence(x, final_db, trigram_model)
acc_b, bleu_b, noun_b = evaluate_model(test_df, baseline_func)
print(f"Baseline: Acc {acc_b:.4f}, BLEU {bleu_b:.4f}, Noun {noun_b:.4f}")

# Advanced eval
def t5_correct(sentence):
    """Correct one sentence with the fine-tuned T5 model.

    The sentence is tokenized (truncated to 128 tokens), moved to the
    module-level ``device``, passed to ``model.generate`` with a
    128-token output limit, and the first generated sequence is decoded
    without special tokens.
    """
    # The model has just been through training; put it in evaluation mode
    # so dropout is disabled and generation does not vary between calls.
    model.eval()
    inputs = tokenizer(sentence, return_tensors='pt', max_length=128, truncation=True).to(device)
    outputs = model.generate(**inputs, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# NOTE(review): only the first 10 test rows are scored here, whereas the
# baseline is scored on all of test_df, so the two sets of metrics are not
# computed over the same sentences.
acc_a, bleu_a, noun_a = evaluate_model(test_df.head(10), t5_correct)  # sample for speed
print(f"Advanced: Acc {acc_a:.4f}, BLEU {bleu_a:.4f}, Noun {noun_a:.4f}")

## 7. Results and Comparison

In [None]:
# Comparison plot
# Grouped bar chart: one group per metric, baseline and advanced side by side.
metrics = ['Accuracy', 'BLEU', 'Noun Accuracy']
baseline_vals = [acc_b, bleu_b, noun_b]
advanced_vals = [acc_a, bleu_a, noun_a]

x = range(len(metrics))
plt.bar(x, baseline_vals, width=0.4, label='Baseline')
# Shift the second series by one bar width so the bars sit next to each other.
plt.bar([i+0.4 for i in x], advanced_vals, width=0.4, label='Advanced')
# Centre each tick label between the two bars of its group.
plt.xticks([i+0.2 for i in x], metrics)
plt.legend()
plt.title('Model Comparison')
plt.show()

print("Notebook completed. For full pipeline, run the scripts.")