# Context-sensitive Spelling Correction

The goal of the assignment is to implement context-sensitive spelling correction. The input of the code will be a set of text lines and the output will be the same lines with spelling mistakes fixed.

Submit the solution of the assignment to Moodle as a link to your GitHub repository containing this notebook.

Useful links:
- [Norvig's solution](https://norvig.com/spell-correct.html)
- [Norvig's dataset](https://norvig.com/big.txt)
- [Ngrams data](https://www.ngrams.info/download_coca.asp)

Grading:
- 60 points - Implement spelling correction
- 20 points - Justify your decisions
- 20 points - Evaluate on a test set


## Implement context-sensitive spelling correction

Your task is to implement a context-sensitive spelling corrector using an N-gram language model. The idea is to compute the conditional probabilities of the possible corrections given the surrounding words. For example, the phrase "dking sport" should be corrected to "doing sport", not "dying sport", while "dking species" should become "dying species".
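
As a rough illustration of the example above, the sketch below picks between "doing" and "dying" with a smoothed bigram probability. The counts, the add-alpha smoothing, and the names `cond_prob` and `pick_correction` are all assumptions made up for this illustration; real counts would come from a corpus such as the n-gram data linked above.

```python
from collections import Counter

# Toy counts, invented purely for illustration (not taken from any corpus).
UNIGRAMS = Counter({"doing": 100, "dying": 50, "sport": 40, "species": 30})
BIGRAMS = Counter({
    ("doing", "sport"): 20,
    ("dying", "sport"): 1,
    ("doing", "species"): 1,
    ("dying", "species"): 15,
})


def cond_prob(next_word, word, alpha=1.0):
    """P(next_word | word) with add-alpha smoothing over the toy vocabulary."""
    vocab_size = len(UNIGRAMS)
    return (BIGRAMS[(word, next_word)] + alpha) / (UNIGRAMS[word] + alpha * vocab_size)


def pick_correction(candidates, next_word):
    """Choose the candidate that makes the following word most likely."""
    return max(candidates, key=lambda c: cond_prob(next_word, c))


print(pick_correction(["doing", "dying"], "sport"))    # doing
print(pick_correction(["doing", "dying"], "species"))  # dying
```

Only the right-hand neighbour is used here to keep the sketch short; a fuller model would also condition on the preceding word.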

The best way to start is to analyze [Norvig's solution](https://norvig.com/spell-correct.html) and [N-gram Language Models](https://web.stanford.edu/~jurafsky/slp3/3.pdf).

You may also want to implement:
- spell-checking for a concrete language - Russian, Tatar, etc. - any one you know, such that the solution accounts for language specifics,
- some recent (or not very recent) paper on this topic,
- solution which takes into account keyboard layout and associated misspellings,
- efficiency improvement to make the solution faster,
- any other idea of yours to improve Norvig's solution.

IMPORTANT:  
Your project should not be a mere code copy-paste from somewhere. You must provide:
- Your implementation
- Analysis of why the implemented approach is suggested
- Improvements of the original approach that you have chosen to implement

In [19]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from collections import Counter
import re
import pandas as pd
import random
import time
import string

In [20]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased", device_map='cuda')

In [21]:
def words(text): return re.findall(r'\w+', text.lower())

WORDS = Counter(words(open('big.txt').read()))

In [22]:
class SpellChecker:
    def __init__(self, device='cpu'):
        self.device = device
    
    def spell_check(self, sentence: str):
        """Spell check a sentence using a dictionary of words.
        Parameters:
            - sentence (str): The sentence to be spell checked.
        Returns:
            - str: The corrected sentence with any misspelled words replaced.
        Processing Logic:
            - Split the sentence into a list of words.
            - Check each word in the list for spelling errors.
            - If a word is misspelled, replace it with the most likely correct word.
            - Join the corrected words back into a sentence."""
        sentence_list = sentence.split()
        res = [self._spell_check_at(sentence_list, i) if sentence_list[i].lower() not in WORDS else sentence_list[i] for i in range(len(sentence_list))]
    
        corrected_sentence = []
        for my_dict in res:
            if isinstance(my_dict, str):
                corrected_sentence.append(my_dict)
                continue
            # Pick the candidate with the highest combined score.
            best_word = max(my_dict, key=my_dict.get)
            corrected_sentence.append(best_word)
            
        return " ".join(corrected_sentence)
    
    def _levenshtein_distance(self, s1: str, s2: str):
        """Calculates the normalized Levenshtein similarity between two strings.
        Parameters:
            - s1 (str): First string.
            - s2 (str): Second string.
        Returns:
            - float: Similarity in [0, 1], where 1 means the strings are identical.
        Processing Logic:
            - Creates a matrix of size (m+1) x (n+1).
            - Fills the first row and column with increasing numbers.
            - Calculates the edit distance between the two strings using dynamic programming.
            - Normalizes the distance by dividing it by the maximum length of the two strings.
            - Returns 1 minus the normalized distance."""
            
        m, n = len(s1), len(s2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(m + 1):
            dp[i][0] = i
        for j in range(n + 1):
            dp[0][j] = j

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if s1[i - 1] == s2[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1]
                else:
                    dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])

        distance = dp[m][n]
        max_length = max(m, n)
        normalized_distance = distance / max_length
        return 1 - normalized_distance
    
    def _lcs_distance(self, s1: str, s2: str):
        """Calculates the normalized longest common subsequence (LCS) similarity between two strings.
        Parameters:
            - s1 (str): First string.
            - s2 (str): Second string.
        Returns:
            - float: LCS length divided by the length of the longer string, in [0, 1].
        Processing Logic:
            - Uses dynamic programming.
            - Calculates LCS length and max length.
            - Divides the LCS length by the max length and returns the result."""
        m, n = len(s1), len(s2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if s1[i - 1] == s2[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        lcs_length = dp[m][n]
        max_length = max(m, n)
        # LCS length relative to the longer string: 1 means identical strings.
        return lcs_length / max_length
    
    def _spell_check_at(self, sentence_list: list[str], candidate_ind: int, topk: int = 150):
        """Spell checks a word in a sentence by replacing it with a [MASK] token and predicting the most probable words using a pre-trained language model.
        Parameters:
            - sentence_list (list[str]): List of words in the sentence.
            - candidate_ind (int): Index of the word to be spell checked.
            - topk (int): Number of top predicted words to be returned. Default is 150.
        Returns:
            - output (dict): Dictionary containing the top predicted words and their corresponding scores.
        Processing Logic:
            - Replaces the word at candidate_ind with a [MASK] token.
            - Encodes the sentence using a pre-trained tokenizer.
            - Uses a pre-trained language model to predict the most probable words.
            - Calculates a score for each predicted word by adding its probability, its normalized Levenshtein similarity, and its normalized longest common subsequence similarity to the original word.
            - Returns a dictionary of the top predicted words and their scores."""
        candidate = sentence_list[candidate_ind]
        sentence_list[candidate_ind] = '[MASK]'
        text = " ".join(sentence_list)

        input_ids = tokenizer.encode(text, return_tensors="pt").to(self.device)

        mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

        with torch.no_grad():
            output = model(input_ids)
            logits = output.logits

        mask_logits = logits[0, mask_token_index, :]

        probabilities = torch.softmax(mask_logits, dim=-1)

        top_k_probs, top_k_tokens = torch.topk(probabilities, k=topk)
        top_k_probs = top_k_probs[0]
        top_k_tokens = top_k_tokens[0]
        
        top_k_words = [tokenizer.decode([token_id]) for token_id in top_k_tokens]

        output = {}
        for word, prob in zip(top_k_words, top_k_probs):
            score = self._levenshtein_distance(word, candidate) + prob + self._lcs_distance(word, candidate)
            output[word] = score
        sentence_list[candidate_ind] = candidate
        return output

## Justify your decisions

### Idea

My idea is to examine each word `w` of the input sentence that is missing from the dictionary. A masked language model proposes the top k words for the position of `w`, each with a probability. Every proposal is then compared with `w` using string-similarity metrics, and the sum of the model probability and the similarity scores becomes the proposal's final score. The proposal with the highest score replaces `w`.
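
The scoring can be illustrated without loading a model. In the sketch below the language-model probabilities are invented numbers standing in for the DistilBERT output, and `rank_candidates`, `TOY_LM_PROBS`, and the two compact similarity helpers are hypothetical names written for this example only.

```python
def levenshtein_similarity(a, b):
    """1 - edit_distance / max_length, computed row by row."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return 1 - prev[-1] / max(len(a), len(b), 1)


def lcs_similarity(a, b):
    """Longest common subsequence length divided by the longer string's length."""
    prev = [0] * (len(b) + 1)
    for ca in a:
        cur = [0]
        for j, cb in enumerate(b, 1):
            cur.append(prev[j - 1] + 1 if ca == cb else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1] / max(len(a), len(b), 1)


def rank_candidates(typo, lm_probs):
    """Sort candidates by LM probability plus both similarities, best first."""
    scored = {
        word: prob + levenshtein_similarity(word, typo) + lcs_similarity(word, typo)
        for word, prob in lm_probs.items()
    }
    return sorted(scored.items(), key=lambda item: item[1], reverse=True)


# Made-up probabilities for the masked position in "[MASK] sport".
TOY_LM_PROBS = {"going": 0.35, "doing": 0.30, "playing": 0.20}

for word, score in rank_candidates("dking", TOY_LM_PROBS):
    print(f"{word}: {score:.3f}")
```

In this toy case "going" has the highest model probability, but "doing" wins overall because it is only one edit away from the typo.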

### Concrete implementation

* DistilBERT as the language model - it is small enough to run on a PC and works quite well for this task, because it was trained on a fairly large amount of data and takes the surrounding context into account
* Levenshtein distance and longest common subsequence as metrics - I chose them because I know the ideas behind them and how to implement them. Many other algorithms can compare strings, and I do not have a strong rationale for which ones are best. Both are normalized to a similarity between 0 and 1.
* easy addition of new metrics - implement a function that returns a score from 0 to 1 and add its result to the total score
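
As one possible example of such an extra metric, the sketch below computes a Jaccard similarity over character bigrams. The function names are hypothetical and this metric is not part of the implementation above; it only shows the required shape, two strings in and a score from 0 to 1 out.

```python
def char_bigrams(word):
    """Set of adjacent character pairs, e.g. 'doing' -> {'do', 'oi', 'in', 'ng'}."""
    return {word[i:i + 2] for i in range(len(word) - 1)}


def bigram_jaccard(s1, s2):
    """Shared bigrams divided by all distinct bigrams of both strings."""
    a, b = char_bigrams(s1), char_bigrams(s2)
    if not a and not b:
        # Strings shorter than two characters have no bigrams to compare.
        return 1.0 if s1 == s2 else 0.0
    return len(a & b) / len(a | b)


print(bigram_jaccard("doing", "doing"))
print(bigram_jaccard("doing", "dking"))
```

Plugging it in would mean adding one more term, such as `bigram_jaccard(word, candidate)`, to the score computed in `_spell_check_at`.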

### Possible updates

* weight the individual scores - collect a dataset that records each feature score for every candidate together with whether that candidate is the target word, then train a simple logistic regression on it to obtain the weights
* add new algorithms - for example, phonetic algorithms or n-gram similarity
* use a better language model, or a model trained on text similar to the target text
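
The weighting idea could look roughly like the sketch below, which fits a logistic regression with plain gradient descent. The feature rows are invented toy values in the order (LM probability, Levenshtein similarity, LCS similarity), and `fit_logistic` and `weighted_score` are hypothetical helpers; a real experiment would collect the rows from the spell checker and would probably use an existing library.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit_logistic(rows, labels, lr=0.5, epochs=2000):
    """Batch gradient descent on the logistic loss; returns (weights, bias)."""
    n_features = len(rows[0])
    weights, bias = [0.0] * n_features, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * n_features, 0.0
        for x, y in zip(rows, labels):
            err = sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias) - y
            grad_w = [g + err * xi for g, xi in zip(grad_w, x)]
            grad_b += err
        weights = [w - lr * g / len(rows) for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / len(rows)
    return weights, bias


# Toy rows: (lm_probability, levenshtein_similarity, lcs_similarity).
# Label 1 means the candidate was the intended word. All values are made up.
ROWS = [
    (0.60, 0.80, 0.80), (0.30, 0.75, 0.80), (0.50, 0.90, 0.90), (0.20, 0.80, 0.85),
    (0.70, 0.20, 0.30), (0.10, 0.40, 0.50), (0.40, 0.10, 0.20), (0.05, 0.30, 0.30),
]
LABELS = [1, 1, 1, 1, 0, 0, 0, 0]

WEIGHTS, BIAS = fit_logistic(ROWS, LABELS)


def weighted_score(features):
    """Probability-like score that replaces the unweighted sum of features."""
    return sigmoid(sum(w * f for w, f in zip(WEIGHTS, features)) + BIAS)


print(weighted_score((0.4, 0.9, 0.9)), weighted_score((0.4, 0.1, 0.2)))
```

With learned weights the three features no longer contribute equally, so a candidate with a high model probability but little resemblance to the typo can be ranked lower.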

## Evaluate on a test set

Your task is to generate a test set and evaluate your work. You may vary the noise probability to generate datasets of varying complexity. Compare your solution to Norvig's corrector and report the accuracies.
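
Varying the noise probability could be organised as in the sketch below. It is self-contained on purpose: `corrupt` is a simplified stand-in for a mistake generator, the sentence is an invented example, and the `corrector` argument is a placeholder for any function that maps a noisy sentence to a corrected one.

```python
import random
import string


def corrupt(word, rng):
    """Replace one random character; words of two letters or fewer are left alone."""
    if len(word) <= 2:
        return word
    pos = rng.randrange(len(word))
    return word[:pos] + rng.choice(string.ascii_lowercase) + word[pos + 1:]


def add_noise(sentence, p, rng):
    """Corrupt each word independently with probability p, keeping the word count."""
    return " ".join(corrupt(w, rng) if rng.random() < p else w for w in sentence.split())


def word_accuracy(reference, hypothesis):
    """Share of positions where the hypothesis word matches the reference word."""
    ref, hyp = reference.split(), hypothesis.split()
    return sum(r.lower() == h.lower() for r, h in zip(ref, hyp)) / len(ref)


def sweep(sentences, corrector, noise_levels=(0.0, 0.1, 0.3), seed=0):
    """Average word accuracy of `corrector` at each noise level."""
    rng = random.Random(seed)
    results = {}
    for p in noise_levels:
        scores = [word_accuracy(s, corrector(add_noise(s, p, rng))) for s in sentences]
        results[p] = sum(scores) / len(scores)
    return results


# An identity "corrector" gives the accuracy of the uncorrected noisy text.
baseline = sweep(["great flight and friendly crew"], corrector=lambda s: s)
print(baseline)
```

Running the same sweep with each corrector and comparing against the identity baseline shows how much of the injected noise each one actually removes.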

In [23]:
class Mistaker:
    def generate_spelling_mistakes(self, sentence: str, mistake_probability: float = 0.1):
        """Generates spelling mistakes in a sentence.
        Parameters:
            - sentence (str): The sentence to generate spelling mistakes in.
            - mistake_probability (float): The probability of a word being misspelled, defaults to 0.1.
        Returns:
            - str: The sentence with spelling mistakes.
        Processing Logic:
            - Splits the sentence into words.
            - Generates a list of misspelled words.
            - Joins the misspelled words into a sentence."""
        words = sentence.split()
        misspelled_words = []
        
        for word in words:
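            if random.random() < mistake_probability:
                # Fall back to the original word when it is too short to corrupt,
                # so the noisy sentence keeps the same number of words.
                misspelled_word = self._introduce_mistake(word)
                misspelled_words.append(misspelled_word or word)
            else:
                misspelled_words.append(word)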
        
        return ' '.join(misspelled_words)

    def _introduce_mistake(self, word: str):
        mistake_type = random.randint(1, 4)
        
        if len(word) <= 2:
            return None
        
        if mistake_type == 1:
            # Delete a character
            position = random.randint(0, len(word) - 1)
            misspelled_word = word[:position] + word[position + 1:]
        elif mistake_type == 2:
            # Insert a character
            position = random.randint(0, len(word))
            inserted_char = random.choice(string.ascii_lowercase)
            misspelled_word = word[:position] + inserted_char + word[position:]
        elif mistake_type == 3:
            # Substitute a character
            position = random.randint(0, len(word) - 1)
            substituted_char = random.choice(string.ascii_lowercase)
            misspelled_word = word[:position] + substituted_char + word[position + 1:]
        else:
            # Transpose two adjacent characters
            position = random.randint(0, len(word) - 2)
            misspelled_word = word[:position] + word[position + 1] + word[position] + word[position + 2:]
        
        return misspelled_word

### Norvig Solution

In [24]:
def P(word, N=sum(WORDS.values())): 
    "Probability of `word`."
    return WORDS[word] / N

def correction(word): 
    "Most probable spelling correction for word."
    return max(candidates(word), key=P)

def candidates(word): 
    "Generate possible spelling corrections for word."
    return (known([word]) or known(edits1(word)) or known(edits2(word)) or [word])

def known(words): 
    "The subset of `words` that appear in the dictionary of WORDS."
    return set(w for w in words if w in WORDS)

def edits1(word):
    "All edits that are one edit away from `word`."
    letters    = 'abcdefghijklmnopqrstuvwxyz'
    splits     = [(word[:i], word[i:])    for i in range(len(word) + 1)]
    deletes    = [L + R[1:]               for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
    replaces   = [L + c + R[1:]           for L, R in splits if R for c in letters]
    inserts    = [L + c + R               for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def edits2(word): 
    "All edits that are two edits away from `word`."
    return (e2 for e1 in edits1(word) for e2 in edits1(e1))

### Tests

In [25]:
class TestSolutions:
    def __init__(self, dataset_path='singapore_airlines_reviews.csv', device='cpu'):
        self.spell_checker = SpellChecker(device)
        self.test_data = pd.read_csv(dataset_path).title.values[:500]
        self.mistaker = Mistaker()
        
    def run(self):
        test_set = [self.mistaker.generate_spelling_mistakes(text) for text in self.test_data]
        
        start_norvig = time.time()
        scores_norvig = [self._eval_norvig_sentence(self.test_data[i], test_set[i]) for i in range(len(test_set)) if test_set[i]]
        end_norvig = time.time()
        
        start_spell_checker = time.time()
        scores_spell_checker = [self._eval_my_checker_sentence(self.test_data[i], test_set[i]) for i in range(len(test_set)) if test_set[i]]
        end_spell_checker = time.time()
        
        print('Norvig execution time: ', end_norvig - start_norvig)
        print('My spell checker execution time: ', end_spell_checker - start_spell_checker)
        
        return sum(scores_norvig) / len(scores_norvig), sum(scores_spell_checker) / len(scores_spell_checker)
        
        
    def _eval_norvig_sentence(self, correct_sentence, incorrect_sentence):
        total_words = 0
        correct_words = 0
        
        for correct, incorrect in zip(correct_sentence.split(), incorrect_sentence.split()):
            total_words += 1
            corrected_word = correction(incorrect)
            
            if correct.lower() == corrected_word.lower():
                correct_words += 1
        
        return correct_words / total_words
    
    def _eval_my_checker_sentence(self, correct_sentence, incorrect_sentence):
        total_words = 0
        correct_words = 0
        corrected = self.spell_checker.spell_check(incorrect_sentence)
        for correct, corrected_word in zip(correct_sentence.split(), corrected.split()):
            total_words += 1
            
            if correct.lower() == corrected_word.lower():
                correct_words += 1
        
        return correct_words / total_words

In [27]:
tester = TestSolutions(device='cuda')
norvig_score, spell_checker_score = tester.run()

Norvig execution time:  34.195584297180176
My spell checker execution time:  91.62865161895752


In [28]:
print("Norvig accuracy: ", norvig_score)
print('My spell checker accuracy: ', spell_checker_score)

Norvig accuracy:  0.7464436674522592
My spell checker accuracy:  0.7781652153488069
