# Context-sensitive Spelling Correction

The goal of the assignment is to implement context-sensitive spelling correction. The input of the code will be a set of text lines and the output will be the same lines with spelling mistakes fixed.

Submit the solution of the assignment to Moodle as a link to your GitHub repository containing this notebook.

Useful links:
- [Norvig's solution](https://norvig.com/spell-correct.html)
- [Norvig's dataset](https://norvig.com/big.txt)
- [Ngrams data](https://www.ngrams.info/download_coca.asp)

Grading:
- 60 points - Implement spelling correction
- 20 points - Justify your decisions
- 20 points - Evaluate on a test set


## Implement context-sensitive spelling correction

Your task is to implement context-sensitive spelling corrector using N-gram language model. The idea is to compute conditional probabilities of possible correction options. For example, the phrase "dking sport" should be fixed as "doing sport" not "dying sport", while "dking species" -- as "dying species".

The best way to start is to analyze [Norvig's solution](https://norvig.com/spell-correct.html) and [N-gram Language Models](https://web.stanford.edu/~jurafsky/slp3/3.pdf).

You may also want to implement:
- spell-checking for a concrete language - Russian, Tatar, etc. - any one you know, such that the solution accounts for language specifics,
- some recent (or not very recent) paper on this topic,
- solution which takes into account keyboard layout and associated misspellings,
- efficiency improvement to make the solution faster,
- any other idea of yours to improve the Norvig’s solution.

IMPORTANT:  
Your project should not be a mere code copy-paste from somewhere. You must provide:
- Your implementation
- Analysis of why the implemented approach is suggested
- Improvements of the original approach that you have chosen to implement

### Solution is train encoder-decoder transformer for seq2seq spelling correction task

Norvig's solution is quite simple without the use of neural approaches, in contrast to this solution, an idea emerged on how much more effective modern neural approaches are than old probabilistic approaches. 

The solution is not based on any article, all implementations are completely unique.


In [None]:
from datasets import load_dataset

dataset = load_dataset("ag_news")

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

In [4]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [5]:
show_random_elements(dataset['train'])

Unnamed: 0,text,label
0,"Tiger Put to the Ryder Cup Challenge (AP) AP - Tiger Woods has been chasing Jack Nicklaus in golf record books since he was a kid. When it comes to the Ryder Cup, though, Tiger doesn't mean Jack.",Sports
1,"Johansson Upsets Roddick at U.S. Open NEW YORK - Andy Roddick ran into a bold, bigger version of himself at the U.S. Open, and 6-foot-6 Joachim Johansson sent the defending champion home...",World
2,"Internet Turns 35, Still Work in Progress (AP) AP - Thirty-five years after computer scientists at UCLA linked two bulky computers using a 15-foot gray cable, testing a new way for exchanging data over networks, what would ultimately become the Internet remains a work in progress.",Sci/Tech
3,"Disconnected PDAs are dead, according to RIM International wireless solutions manufacturer Research in Motion (RIM) believes the days of disconnected PDAs are gone. The BlackBerry-maker said that users #39; information is changing too rapidly for disconnected",Business
4,Doctors told to let baby Charlotte die The parents of a premature baby have lost their battle to force doctors to keep tiny 11-month-old Charlotte alive if she stops breathing a fourth time.,World


In [6]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
max_number_of_tokens = 64

In [7]:
from nltk.tokenize import RegexpTokenizer
import numpy as np

basicTokenizer = RegexpTokenizer(r'\w+')

def random_augmented(word, probability=0.4):
    if np.random.choice([False, True], p=[1 - probability, probability]):
        augmentation_type = np.random.choice(['delete', 'transpose', 'replace', 'insert'], p=[0.25, 0.25, 0.25, 0.25])
        
        if augmentation_type == 'delete':
            
            if len(word) < 2:
                return word
            
            ind = np.random.randint(0, len(word) - 1)
            return word[:ind] + word[ind + 1:]
        
        elif augmentation_type == 'transpose':
            
            if len(word) < 3:
                return word
            
            ind = np.random.randint(1, len(word) - 1)
            L, R = word[:ind], word[ind:]
            return L + R[1] + R[0] + R[2:]
        
        elif augmentation_type == 'replace':
            c = chr(ord('a') + np.random.randint(0, 26))
            ind = np.random.randint(0, len(word))
            return word[:ind] + c + word[ind + 1:]
        else:
            c = chr(ord('a') + np.random.randint(0, 26))
            ind = np.random.randint(0, len(word) + 1)
            return word[:ind] + c + word[ind:]

    else:
        return word

In [8]:
def custom_collate(batch):

    batch_correct, batch_augmented = [], []

    for sentence in batch:
        tokenized_sentence = basicTokenizer.tokenize(sentence.lower())
        augmented_sentence = list(map(random_augmented, tokenized_sentence))
        batch_correct.append(' '.join(tokenized_sentence))
        batch_augmented.append(' '.join(augmented_sentence))
    
    batch_correct = tokenizer(batch_correct, padding='max_length', truncation=True, max_length=max_number_of_tokens, return_tensors='pt',)
    batch_augmented = tokenizer(batch_augmented, padding='max_length', truncation=True, max_length=max_number_of_tokens, return_tensors='pt')
    
    batch_correct['input_ids'][batch_correct['attention_mask'] == 0] = -100

    batch_augmented['labels'] = batch_correct['input_ids']
    return batch_augmented

In [154]:
from torch.utils.data import DataLoader
batch_size = 20

train_dataloader = DataLoader(dataset['train']['text'], shuffle=True, batch_size=batch_size, collate_fn=custom_collate)
eval_dataloader = DataLoader(dataset['test']['text'], batch_size=batch_size, collate_fn=custom_collate)

In [10]:
for batch in train_dataloader:
    print({k: v.shape for k, v in batch.items()})
    # print(model(**batch_correct).logits.shape)
    break

{'input_ids': torch.Size([20, 64]), 'attention_mask': torch.Size([20, 64]), 'labels': torch.Size([20, 64])}


In [301]:
def acc_score(preds, labels):

    max_len = max(len(preds), len(labels))
    preds.extend([''] * (max_len - len(preds)))
    labels.extend([''] * (max_len - len(labels)))

    acc = 0
    for pred, label in zip(preds, labels):
        acc += (pred == label)
    
    return acc / len(preds)

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    acc = []

    for pred, label in zip(decoded_preds, decoded_labels):
        pred = pred.split()
        label = label.split()

        acc.append(acc_score(pred, label))


    return acc

In [13]:
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
;

''

In [14]:
from tqdm.auto import tqdm
from torch.optim import AdamW
from transformers import get_scheduler

def train(train_dataloader, model, num_epochs = 10):
    
    num_training_steps = num_epochs * len(train_dataloader)

    optimizer = AdamW(model.parameters(), lr=5e-5)
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    progress_bar = tqdm(range(num_training_steps))

    model.train()
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
    
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, f'checkpoint-epoch-{epoch}.pt')
train(train_dataloader, model)

100%|██████████| 60000/60000 [1:47:02<00:00,  9.30it/s]

In [18]:
#checkpoint = torch.load('checkpoint.pt')

In [155]:
def eval(model, eval_dataloader):
    model.eval()

    mean_acc = []

    for batch in tqdm(eval_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model.generate(**batch, max_new_tokens=max_number_of_tokens, num_beams=4, do_sample=True, temperature=1.1)

        labels = batch['labels'].to('cpu')
        mean_acc.extend(compute_metrics((outputs, labels)))
    print(f'Accuracy:{sum(mean_acc) / len(mean_acc)}')
eval(model, eval_dataloader)

100%|██████████| 380/380 [09:06<00:00,  1.44s/it]

Accuracy:0.6805173401695055





### Fine-tune on google_wellformed_query dataset

In [311]:
google_query_dataset = load_dataset("google_wellformed_query")

In [118]:
train(DataLoader(google_query_dataset['train']['content'], shuffle=True, batch_size=batch_size, collate_fn=custom_collate), model)

100%|██████████| 8750/8750 [14:56<00:00,  9.76it/s]


In [157]:
eval(model, DataLoader(google_query_dataset['validation']['content'], batch_size=batch_size, collate_fn=custom_collate))

100%|██████████| 188/188 [02:43<00:00,  1.15it/s]

Accuracy:0.7454319263128661





## Justify your decisions

Write down justificaitons for your implementation choices. For example, these choices could be:
- Which ngram dataset to use
- Which weights to assign for edit1, edit2 or absent words probabilities
- Beam search parameters
- etc.

### The model was trained on two datasets:

1. [Ag_news](https://huggingface.co/datasets/ag_news) 

    AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2000 news sources.

    ### Data Splits

    | name  |train |test|
    |-------|-----:|---:|
    |default|120000|7600|
    
&nbsp;

2. [Google wellformed query](https://huggingface.co/datasets/google_wellformed_query)

    Google's query wellformedness dataset was created by crowdsourcing well-formedness annotations for 25,100 queries from the Paralex corpus

    ### Data Splits

    | name  |train |valid|
    |-------|-----:|---:|
    |default|17500|3750|

### How to create spelling mistake?

The approach is to generate some modificationsof word with an optional probability parameter (default: 0.4) which determines the likelihood of applying a modification to the word. The types of modifications(all edit distance 1) include deletion, transposition, replacement, and insertion. The probabilities for each type of modification are equal.

Here's a brief overview of what each modification does:

1. **Deletion**: Randomly deletes one character from the word, unless the word has only one character.
2. **Transposition**: Randomly selects a character in the word (except the first and last characters) and swaps it with its adjacent character.
3. **Replacement**: Randomly selects a character in the word and replaces it with a randomly chosen alphabet character.
4. **Insertion**: Randomly inserts a randomly chosen alphabet character at a random position in the word.

The reason to use only edit 1 is because edit distance 2 can cause huge deviations from the original word and edit distance 1 is the most common type of misspelling.

### Further improvement

1. Increase the edit distance depending on the length of the word.
2. Train on more data.


## Evaluate on a test set

Your task is to generate a test set and evaluate your work. You may vary the noise probability to generate different datasets with varying compexity. Compare your solution to the Norvig's corrector, and report the accuracies.

### Norvig solution

In [134]:
import re
from collections import Counter

def words(text): return re.findall(r'\w+', text.lower())

WORDS = Counter(words(open('big.txt').read()))

def P(word, N=sum(WORDS.values())): 
    "Probability of `word`."
    return WORDS[word] / N

def correction(word): 
    "Most probable spelling correction for word."
    return max(candidates(word), key=P)

def candidates(word): 
    "Generate possible spelling corrections for word."
    return (known([word]) or known(edits1(word)) or known(edits2(word)) or [word])

def known(words): 
    "The subset of `words` that appear in the dictionary of WORDS."
    return set(w for w in words if w in WORDS)

def edits1(word):
    "All edits that are one edit away from `word`."
    letters    = 'abcdefghijklmnopqrstuvwxyz'
    splits     = [(word[:i], word[i:])    for i in range(len(word) + 1)]
    deletes    = [L + R[1:]               for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
    replaces   = [L + c + R[1:]           for L, R in splits if R for c in letters]
    inserts    = [L + c + R               for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def edits2(word): 
    "All edits that are two edits away from `word`."
    return (e2 for e1 in edits1(word) for e2 in edits1(e1))

### Test and compare two methods

In [314]:
def preprocess(example):
    
    lower = list(map(str.lower, example['content']))
    tokenized = list(map(basicTokenizer.tokenize, lower))
    augmented = [list(map(random_augmented, sentence)) for sentence in tokenized]
    
    mask = []
    for sentence1, sentence2 in zip(tokenized, augmented):
        mask.append([False if word1 == word2 else True for word1, word2 in zip(sentence1, sentence2)])

    augmented = [' '.join(sentence) for sentence in tokenized]
    tokenized = [' '.join(sentence) for sentence in tokenized]

    example = example.add_column('augmented', augmented)
    example = example.add_column('label', tokenized)
    example = example.add_column('mask', mask)
    
    example = example.remove_columns(['content', 'rating'])
    return example

In [315]:
google_query_dataset['test'] = preprocess(google_query_dataset['test'])

In [316]:
google_query_dataset['test']

Dataset({
    features: ['augmented', 'label', 'mask'],
    num_rows: 3850
})

In [319]:
def test_collate(batch):

    batch = {
        'label': [dct['label'] for dct in batch],
        'augmented': [dct['augmented'] for dct in batch],
        'mask': [dct['mask'] for dct in batch],
    }
    
    batch_augmented = tokenizer(batch['augmented'], padding='max_length', truncation=True, max_length=max_number_of_tokens, return_tensors='pt')
    batch_augmented['labels'] = batch['label']
    batch_augmented['mask'] = batch['mask']
    return batch_augmented

In [320]:
testLoader = DataLoader(google_query_dataset['test'], batch_size=1, collate_fn=test_collate)

In [337]:
def precision_recall_score(preds, labels, mask):
    
    max_len = max(len(preds), len(labels))
    preds.extend([''] * (max_len - len(preds)))
    labels.extend([''] * (max_len - len(labels)))
    mask.extend([False] * (max_len - len(mask)))
    
    # print(f'{preds=}')
    # print(f'{labels=}')
    # print(f'{mask=}')
    
    TP, FP, FN, TN = 0, 0, 0, 0

    for pred, label, mask in zip(preds, labels, mask):
        if pred == label:
            if mask:
                TP += 1
            else:
                TN += 1
        else:
            if mask:
                FP += 1
            else:
                #mask False means input = label
                FN += 1
    return TP, FP, FN, TN

In [338]:
model.eval()
transformer_score, norvig_score = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0}, {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0}

for batch in tqdm(testLoader):
    
    label = basicTokenizer.tokenize(batch['labels'][0])

    with torch.no_grad():
        input_ids = batch['input_ids'].to('cuda')
        attention_mask = batch['attention_mask'].to('cuda')
        outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_number_of_tokens, num_beams=4, do_sample=True, temperature=1.1)
        
        decoded = basicTokenizer.tokenize(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
        TP, FP, FN, TN = precision_recall_score(decoded, label, batch['mask'][0])
        
        transformer_score['TP'] += TP
        transformer_score['FP'] += FP
        transformer_score['FN'] += FN
        transformer_score['TN'] += TN

        tokenized = basicTokenizer.tokenize(tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)[0])
        norvig_output = list(map(correction, tokenized))

        TP, FP, FN, TN = precision_recall_score(norvig_output, label, batch['mask'][0])
        norvig_score['TP'] += TP
        norvig_score['FP'] += FP
        norvig_score['FN'] += FN
        norvig_score['TN'] += TN


print(f"Transformer solution\nRecall:{transformer_score['TP'] / (transformer_score['TP'] + transformer_score['FN'])}\n\
      Precision:{transformer_score['TP'] / (transformer_score['TP'] + transformer_score['FP'])}")
print(f"Norvig \nRecall:{norvig_score['TP'] / (norvig_score['TP'] + norvig_score['FN'])}\n\
      Precision:{norvig_score['TP'] / (norvig_score['TP'] + norvig_score['FP'])}")

100%|██████████| 3850/3850 [10:44<00:00,  5.97it/s]

Transformer solution
Recall:0.8646024665442141
      Precision:0.9522204026587034
Norvig 
Recall:0.8440992095938948
      Precision:0.8950004816491668





## Example

You can also try your example using [HuggingFace Inference API](https://huggingface.co/the-hir0/google-t5-small-spellchecker)

In [96]:
model.eval()
promt = 'Thsi is a Japanees dull'

input = tokenizer(promt, return_tensors='pt').to('cuda')

output = model.generate(**input, max_new_tokens=max_number_of_tokens, num_beams=4, do_sample=True, temperature=1.1)
tokenizer.batch_decode(output, skip_special_tokens=True)

['this is a japanese dull']