# Context-sensitive Spelling Correction

The goal of the assignment is to implement context-sensitive spelling correction. The input of the code will be a set of text lines and the output will be the same lines with spelling mistakes fixed.

Submit the solution of the assignment to Moodle as a link to your GitHub repository containing this notebook.

Useful links:
- [Norvig's solution](https://norvig.com/spell-correct.html)
- [Norvig's dataset](https://norvig.com/big.txt)
- [Ngrams data](https://www.ngrams.info/download_coca.asp)

Grading:
- 60 points - Implement spelling correction
- 20 points - Justify your decisions
- 20 points - Evaluate on a test set


## Implement context-sensitive spelling correction

Your task is to implement context-sensitive spelling corrector using N-gram language model. The idea is to compute conditional probabilities of possible correction options. For example, the phrase "dking sport" should be fixed as "doing sport" not "dying sport", while "dking species" -- as "dying species".

The best way to start is to analyze [Norvig's solution](https://norvig.com/spell-correct.html) and [N-gram Language Models](https://web.stanford.edu/~jurafsky/slp3/3.pdf).

You may also want to implement:
- spell-checking for a concrete language - Russian, Tatar, etc. - any one you know, such that the solution accounts for language specifics,
- some recent (or not very recent) paper on this topic,
- solution which takes into account keyboard layout and associated misspellings,
- efficiency improvement to make the solution faster,
- any other idea of yours to improve the Norvig’s solution.

IMPORTANT:  
Your project should not be a mere code copy-paste from somewhere. You must provide:
- Your implementation
- Analysis of why the implemented approach is suggested
- Improvements of the original approach that you have chosen to implement

In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from collections import Counter
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased", device_map='cuda')

In [3]:
def words(text): return re.findall(r'\w+', text.lower())

WORDS = Counter(words(open('big.txt').read()))

In [4]:
class SpellChecker:
    def __init__(self, device='cpu'):
        self.device = device
    
    def spell_check(self, sentence):
        sentence_list = sentence.split()
        res = [self._spell_check_at(sentence_list, i) if sentence_list[i].lower() not in WORDS else sentence_list[i] for i in range(len(sentence_list))]
    
        corrected_sentence = []
        for my_dict in res:
            if isinstance(my_dict, str):
                corrected_sentence.append(my_dict)
                continue
            top_keys = sorted(my_dict, key=my_dict.get, reverse=True)[:1][0]
            corrected_sentence.append(top_keys)
            
        return " ".join(corrected_sentence)
    
    def _levenshtein_distance(self, s1, s2):
        m, n = len(s1), len(s2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(m + 1):
            dp[i][0] = i
        for j in range(n + 1):
            dp[0][j] = j

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if s1[i - 1] == s2[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1]
                else:
                    dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])

        distance = dp[m][n]
        max_length = max(m, n)
        normalized_distance = distance / max_length
        return 1 - normalized_distance
    
    def _lcs_distance(self, s1, s2):
        m, n = len(s1), len(s2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if s1[i - 1] == s2[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        lcs_length = dp[m][n]
        max_length = max(m, n)
        normalized_distance = 1 - (lcs_length / max_length)
        return 1 - normalized_distance
    
    def _spell_check_at(self, sentence_list, candidate_ind, topk=150):
        candidate = sentence_list[candidate_ind]
        sentence_list[candidate_ind] = '[MASK]'
        text = " ".join(sentence_list)

        input_ids = tokenizer.encode(text, return_tensors="pt").to(self.device)

        mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

        with torch.no_grad():
            output = model(input_ids)
            logits = output.logits

        mask_logits = logits[0, mask_token_index, :]

        probabilities = torch.softmax(mask_logits, dim=-1)

        top_k_probs, top_k_tokens = torch.topk(probabilities, k=topk)
        top_k_probs = top_k_probs[0]
        top_k_tokens = top_k_tokens[0]
        
        top_k_words = [tokenizer.decode([token_id]) for token_id in top_k_tokens]

        output = {}
        for word, prob in zip(top_k_words, top_k_probs):
            score = self._levenshtein_distance(word, candidate) + prob + self._lcs_distance(word, candidate)
            output[word] = score
        sentence_list[candidate_ind] = candidate
        return output

In [5]:
spell_checker = SpellChecker('cuda')

In [6]:
spell_checker.spell_check('I love my kat and my dogg')

'I love my cat and my dog'

## Justify your decisions

Write down justificaitons for your implementation choices. For example, these choices could be:
- Which ngram dataset to use
- Which weights to assign for edit1, edit2 or absent words probabilities
- Beam search parameters
- etc.

## Justify your decisions

### Language Model
I used the DistilBERT language model from the Transformers library for context-aware spelling correction. DistilBERT is a smaller and faster version of BERT, which is a powerful transformer-based language model capable of capturing contextual information. Using a pre-trained language model like DistilBERT allows leveraging the learned language representations to make more informed decisions about spelling corrections based on the surrounding context.

### Spelling Correction Algorithm
The implemented spelling correction algorithm is inspired by Norvig's approach but incorporates the language model for context-aware corrections. The core idea is to generate a list of candidate corrections for a potentially misspelled word and then score these candidates based on their conditional probabilities obtained from the language model and their edit distance from the original word.

For each potentially misspelled word, the algorithm masks it in the input sentence and uses the language model to generate the top-k most probable token predictions at that position. These top-k tokens are then scored by combining their language model probabilities and their edit distances from the original word. The word with the highest combined score is chosen as the correction.

### Scoring Function
The scoring function used to rank the candidate corrections is a simple linear combination of the language model probability and the edit distance ratio. The language model probability captures the contextual fit of the candidate word, while the edit distance ratio measures the similarity to the original word. By combining these two factors, the algorithm can balance contextual relevance and similarity to the misspelled word.

### Hyperparameters
The implementation includes a `topk` hyperparameter that determines the number of top candidate corrections to consider from the language model. A larger value of `topk` increases the diversity of candidates but may also introduce more irrelevant options. In this implementation, `topk` is set to 150, which should provide a reasonable balance between diversity and computational efficiency.

### Efficiency Considerations
To improve efficiency, the implementation leverages PyTorch's GPU acceleration capabilities by running the language model on a GPU (if available). This can significantly speed up the inference process, especially for longer input sequences.

### Potential Improvements
While the current implementation provides a decent starting point, there are several areas for potential improvement:

1. **Weighted Scoring**: Instead of a simple linear combination, a more sophisticated weighting scheme could be explored to better balance the language model probabilities and edit distances.
2. **Beam Search**: Incorporating beam search or other decoding strategies could potentially improve the quality of corrections by considering multiple candidates simultaneously and allowing for better exploration of the search space.
3. **Language-Specific Considerations**: The current implementation is focused on English text. For other languages, additional language-specific considerations (e.g., character sets, common misspellings, morphological rules) may need to be incorporated.
4. **Vocabulary Expansion**: The current approach is limited to the vocabulary of the pre-trained language model. Techniques like character-level modeling or vocabulary expansion could help handle out-of-vocabulary words more effectively.
5. **Contextual Edit Distance**: The current implementation uses a simple edit distance metric. Incorporating contextual information into the edit distance calculation (e.g., considering common misspelling patterns or keyboard layout) could further improve the spelling correction accuracy.

## Evaluate on a test set

Your task is to generate a test set and evaluate your work. You may vary the noise probability to generate different datasets with varying compexity. Compare your solution to the Norvig's corrector, and report the accuracies.

In [7]:
import pandas as pd
import random
import string
import re
from collections import Counter

In [8]:
class Mistaker:
    def generate_spelling_mistakes(self, sentence, mistake_probability=0.1):
        words = sentence.split()
        misspelled_words = []
        
        for word in words:
            if random.random() < mistake_probability:
                misspelled_word = self._introduce_mistake(word)
                if misspelled_word:
                    misspelled_words.append(misspelled_word) 
            else:
                misspelled_words.append(word)
        
        return ' '.join(misspelled_words)

    def _introduce_mistake(self, word):
        mistake_type = random.randint(1, 4)
        
        if len(word) <= 2:
            return None
        
        if mistake_type == 1:
            # Delete a character
            position = random.randint(0, len(word) - 1)
            misspelled_word = word[:position] + word[position + 1:]
        elif mistake_type == 2:
            # Insert a character
            position = random.randint(0, len(word))
            inserted_char = random.choice(string.ascii_lowercase)
            misspelled_word = word[:position] + inserted_char + word[position:]
        elif mistake_type == 3:
            # Substitute a character
            position = random.randint(0, len(word) - 1)
            substituted_char = random.choice(string.ascii_lowercase)
            misspelled_word = word[:position] + substituted_char + word[position + 1:]
        else:
            # Transpose two adjacent characters
            position = random.randint(0, len(word) - 2)
            misspelled_word = word[:position] + word[position + 1] + word[position] + word[position + 2:]
        
        return misspelled_word

In [9]:
def P(word, N=sum(WORDS.values())): 
    "Probability of `word`."
    return WORDS[word] / N

def correction(word): 
    "Most probable spelling correction for word."
    return max(candidates(word), key=P)

def candidates(word): 
    "Generate possible spelling corrections for word."
    return (known([word]) or known(edits1(word)) or known(edits2(word)) or [word])

def known(words): 
    "The subset of `words` that appear in the dictionary of WORDS."
    return set(w for w in words if w in WORDS)

def edits1(word):
    "All edits that are one edit away from `word`."
    letters    = 'abcdefghijklmnopqrstuvwxyz'
    splits     = [(word[:i], word[i:])    for i in range(len(word) + 1)]
    deletes    = [L + R[1:]               for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
    replaces   = [L + c + R[1:]           for L, R in splits if R for c in letters]
    inserts    = [L + c + R               for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def edits2(word): 
    "All edits that are two edits away from `word`."
    return (e2 for e1 in edits1(word) for e2 in edits1(e1))

In [12]:
class TestSolutions:
    def __init__(self, dataset_path='singapore_airlines_reviews.csv'):
        self.spell_checker = SpellChecker('cuda')
        self.test_data = pd.read_csv(dataset_path).title.values[:50]
        self.mistaker = Mistaker()
        
    def run(self):
        test_set = [self.mistaker.generate_spelling_mistakes(text) for text in self.test_data]
        
        scores_norvig = [self._eval_norvig_sentence(self.test_data[i], test_set[i]) for i in range(len(test_set)) if test_set[i]]
        scores_spell_checker = [self._eval_my_checker_sentence(self.test_data[i], test_set[i]) for i in range(len(test_set)) if test_set[i]]
        
        return sum(scores_norvig) / len(scores_norvig), sum(scores_spell_checker) / len(scores_spell_checker), scores_norvig, scores_spell_checker, test_set
        
        
    def _eval_norvig_sentence(self, correct_sentence, incorrect_sentence):
        total_words = 0
        correct_words = 0
        
        for correct, incorrect in zip(correct_sentence.split(), incorrect_sentence.split()):
            total_words += 1
            corrected_word = correction(incorrect)
            
            if correct.lower() == corrected_word.lower():
                correct_words += 1
        
        return correct_words / total_words
    
    def _eval_my_checker_sentence(self, correct_sentence, incorrect_sentence):
        total_words = 0
        correct_words = 0
        corrected = spell_checker.spell_check(incorrect_sentence)
        for correct, corrected_word in zip(correct_sentence.split(), corrected.split()):
            total_words += 1
            
            if correct.lower() == corrected_word.lower():
                correct_words += 1
        
        return correct_words / total_words

In [13]:
t = TestSolutions()
a, b, c, d, e = t.run()
a, b

(0.7365858585858586, 0.7868715728715729)