# Context-sensitive Spelling Correction

The goal of the assignment is to implement context-sensitive spelling correction. The input of the code will be a set of text lines and the output will be the same lines with spelling mistakes fixed.

Submit the solution of the assignment to Moodle as a link to your GitHub repository containing this notebook.

Useful links:
- [Norvig's solution](https://norvig.com/spell-correct.html)
- [Norvig's dataset](https://norvig.com/big.txt)
- [Ngrams data](https://www.ngrams.info/download_coca.asp)

Grading:
- 60 points - Implement spelling correction
- 20 points - Justify your decisions
- 20 points - Evaluate on a test set


## Implement context-sensitive spelling correction

Your task is to implement a context-sensitive spelling corrector using an N-gram language model. The idea is to compute the conditional probabilities of the possible corrections given the surrounding words. For example, the phrase "dking sport" should be corrected to "doing sport" rather than "dying sport", while "dking species" should be corrected to "dying species".
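
As a minimal illustration of the idea, the sketch below picks between two candidate corrections using the word that follows. The bigram counts are invented for the example and `pick_correction` is a hypothetical helper, not part of the solution in this notebook.

```python
from collections import Counter

# Hypothetical bigram counts, invented purely for illustration.
bigram_counts = Counter({
    ("doing", "sport"): 50,
    ("dying", "sport"): 1,
    ("doing", "species"): 1,
    ("dying", "species"): 40,
})


def pick_correction(candidates: list[str], next_word: str) -> str:
    """Pick the candidate that most often precedes `next_word`."""
    return max(candidates, key=lambda c: bigram_counts[(c, next_word)])


print(pick_correction(["doing", "dying"], "sport"))    # doing
print(pick_correction(["doing", "dying"], "species"))  # dying
```

Both candidates are equally close to "dking" in spelling, so only the context separates them.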

The best way to start is to analyze [Norvig's solution](https://norvig.com/spell-correct.html) and [N-gram Language Models](https://web.stanford.edu/~jurafsky/slp3/3.pdf).

You may also want to implement:
- spell-checking for a concrete language - Russian, Tatar, etc. - any one you know, such that the solution accounts for language specifics,
- some recent (or not very recent) paper on this topic,
- solution which takes into account keyboard layout and associated misspellings,
- efficiency improvement to make the solution faster,
- any other idea of yours to improve Norvig's solution.

IMPORTANT:  
Your project should not be a mere code copy-paste from somewhere. You must provide:
- Your implementation
- Analysis of why the implemented approach is suggested
- Improvements of the original approach that you have chosen to implement

In [1]:
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
import itertools
import random
import re

import tqdm.notebook as tqdm

### Language models

In [2]:
class LanguageModel(ABC):
    """Abstract context-aware language model.
    """

    @abstractmethod
    def __call__(self, candidates: list[str], pretext: list[str],
                 posttext: list[str]) -> dict[str, float]:
        """Predict the probability of the word given the context.
        
        Args:
            candidates (list[str]): Set of candidate replacements.
            pretext (list[str]): Context preceding the word.
            posttext (list[str]): Context after the word.
        
        Returns:
            dict[str, float]: Mapping from candidate set to the probability of
                the candidate in the context.
        """
        pass

In [3]:
class NGramLanguageModel(LanguageModel):
    """N-gram language model.

    Attributes:
        total (int): Sum of the counts of all n-grams in the training set.
        prefix (dict[tuple[str, ...], int]): Mapping from a k-gram (k <= n) to
            the total count of training n-grams that start with it.
        suffix (dict[tuple[str, ...], int]): Same counts as `prefix`, keyed by
            the reversed k-gram. E.g. if suffix[('a', 'b')] = 3, then training
            n-grams starting with ('b', 'a') occur 3 times in total.
        n (int): Number of words in n-grams.
    """

    def __init__(self, counts: dict[tuple[str], int], n: int = 3):
        """Initializes a n-gram language model.

        Args:
            counts (dict[tuple[str], int]): Mapping from n-gram to the number of
                times it occurs in the training set.
            n (int): Number of words in n-grams. Defaults to 3.
        """
        self.n = n
        self.total = sum(counts.values())
        self.prefix = defaultdict(lambda: 0)
        self.suffix = defaultdict(lambda: 0)
        # Fill n-grams dictionaries with all k-grams with k <= n
        for gram, count in counts.items():
            for k in range(1, n + 1):
                if len(gram) < k:
                    continue
                kgram = gram[:k]
                self.prefix[kgram] += count
                kgram = kgram[::-1]
                self.suffix[kgram] += count
        self.suffix[tuple()] = self.prefix[tuple()] = self.total


    def __call__(self, candidates: list[str], pretext: list[str],
                 posttext: list[str]) -> dict[str, float]:
        """Predict the probability of the word using n-gram data.

        Args:
            candidates (list[str]): Set of candidate replacements.
            pretext (list[str]): Context preceding the word.
            posttext (list[str]): Context after the word.
        
        Returns:
            dict[str, float]: Mapping from candidate set to the probability of
                the candidate in the context.
        """
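        # Keep the n - 1 words closest to the candidate. For n == 1 the slice
        # [-0:] would keep the whole list, so that case is handled explicitly.
        pretext = tuple(pretext[-(self.n - 1):]) if self.n > 1 else ()
        # Posttext is reversed to match the keys of `self.suffix`
        posttext = tuple(posttext[:self.n - 1][::-1])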
        result = {}
        for candidate in candidates:
            # Use beta distribution mean to smooth the probabilities for
            # infrequent words
            prev_prob = (self.prefix[pretext + (candidate, )] +
                         1) / (self.prefix[pretext] + 2)
            next_prob = (self.suffix[posttext + (candidate, )] +
                         1) / (self.suffix[posttext] + 2)
            result[candidate] = prev_prob * next_prob
        return result

### Error model

In [4]:
class ErrorModel(ABC):
    """Abstract model for typo probabilities.
    """
    @abstractmethod
    def __call__(self, word: str, candidate: str) -> float:
        """Returns the probability of word being a misspelled candidate word.
        
        Args:
            word (str): Word that is typed.
            candidate (str): Misspelling correction candidate.

        Returns:
            float: Probability of word being a misspelled candidate word.
        """
        pass

In [5]:
class LevensteinErrorModel(ErrorModel):
    """Error model based on an insertion/deletion edit distance.

    The distance is a variant of the Levenshtein distance without
    substitutions, so replacing one letter costs 2.

    Attributes:
        base (float): Probability of word being a misspelling of candidate word
            is calculated as base^(-d), where d is the edit distance between
            the word and the candidate.
    """
    def __init__(self, base: float = 4):
        """Initializes the error model.
        
        Args:
            base (float): Probability of word being a misspelling of candidate word
                is calculated as base^(-d), where d is the edit distance between
                the word and the candidate.
        """
        self.base = base

    def __call__(self, word: str, candidate: str) -> float:
        """Calculate the probability of the word being a misspelling of the
        candidate.

        Args:
            word (str): Word that is typed.
            candidate (str): Candidate word.
        
        Returns:
            float: Probability of word being a misspelling of candidate.
        """
        word = word.lower()
        candidate = candidate.lower()
        # dp[i][j] is the distance between word[:i] and candidate[:j]
        dp = [[0] * (len(candidate) + 1) for _ in range(len(word) + 1)]
        for i in range(len(candidate)):
            dp[0][i + 1] = i + 1
        for i in range(len(word)):
            dp[i + 1][0] = i + 1
        for i in range(1, len(word) + 1):
            for j in range(1, len(candidate) + 1):
                # Delete a letter from the word or insert one from the candidate
                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1]) + 1
                # Matching letters cost nothing
                if word[i - 1] == candidate[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1]
        return pow(self.base, -dp[-1][-1])

### Corrector class

In [6]:
class Corrector:
    """Context-aware text corrector.

    Attributes:
        candidates (set[str]): Set of candidate words to choose from.
        language_model (LanguageModel): Context-aware language model to
            predict probabilities of words occurring in a given context.
        error_model (ErrorModel): Model for predicting the probability
            of misspelling one word as another.
        context_size (int): Size of the context given to the language
            model. Does not include the word itself. Counted individually
            for each direction (in the text "a b c d e", context of "c" is
            ("b", "d"), given `context_size` = 1.) Defaults to 10.
    """

    def __init__(
        self,
        candidates: set[str],
        language_model: LanguageModel,
        error_model: ErrorModel,
        context_size: int = 10,
    ):
        """Context-aware text corrector.

        Args:
            candidates (set[str]): Set of candidate words to choose from.
            language_model (LanguageModel): Context-aware language model to
                predict probabilities of words occurring in a given context.
            error_model (ErrorModel): Model for predicting the probability
                of misspelling one word as another.
            context_size (int): Size of the context given to the language
                model. Does not include the word itself. Counted individually
                for each direction (in the text "a b c d e", context of "c" is
                ("b", "d"), given `context_size` = 1.) Defaults to 10.
        """
        self.candidates = candidates
        self.language_model = language_model
        self.error_model = error_model
        self.context_size = context_size

    @staticmethod
    def get_words(text: str) -> list[str]:
        """Returns words in the string.
        
        Args:
            text (str): Text to extract the word from.
        
        Returns:
            list[str]: Words in the given text.
        """
        return re.findall('\\w+', text)

    @staticmethod
    def replace_words(text: str, words: list[str]) -> str:
        """Replaces words in the text by the list of given words.

        Args:
            text (str): Text to replace the words in.
            words (list[str]): List of replacements words.
        
        Returns:
            str: Text with words replaced.
        """

        def copy_capitalization(src: str, dst: str) -> str:
            if src[0].isupper():
                return dst.capitalize()
            else:
                return dst

        filler = re.split('\\w+', text)
        original_words = Corrector.get_words(text)
        assert len(original_words) == len(words)
        words = list(
            map(lambda x: copy_capitalization(*x), zip(original_words, words)))
        assert len(filler) == len(words) + 1
        return ''.join(itertools.chain(*zip(filler, words + [''])))

    def correct_words(self, words: list[str]) -> list[str]:
        """Corrects the words, and returns the list of words after correction.
        Expects all words to be in lowercase.

        Args:
            words (list[str]): Words to correct.
        
        Returns:
            list[str]: Corrected words.
        """
        result = words.copy()
        for i, word in enumerate(words):
            # Clamp at 0: a negative start index would slice from the end
            pretext = words[max(0, i - self.context_size):i]
            posttext = words[i + 1:i + self.context_size + 1]
            candidates = self.candidates - {word}
            candidate_probs = self.language_model(candidates, pretext,
                                                  posttext)
            word_prob = self.language_model({word}, pretext, posttext)[word]

            def get_prob(candidate: str) -> float:
                return candidate_probs[candidate] * self.error_model(
                    word, candidate)
            
            correction = max(candidates, key=get_prob)
            if get_prob(correction) >  1e4 * word_prob:
                result[i] = correction
        return result

    def __call__(
        self,
        lines: list[str],
    ) -> list[str]:
        """Corrects lines and returns the corrected version.

        Args:
            lines (list[str]): Text lines to correct.    
        
        Returns:
            list[str]: Corrected lines.
        """
        result = []
        for line in lines:
            words = self.get_words(line.lower())
            corrected_words = self.correct_words(words)
            corrected_line = self.replace_words(line, corrected_words)
            result.append(corrected_line)
        return result

    def get_candidates(self, pretext: list[str], word: str,
                       posttext: list[str]) -> dict[str, float]:
        """Returns probabilities associated with each candidate for a given
        context. Used for testing.

        Args:
            pretext (list[str]): Context preceding the word.
            word (str): Typed word.
            posttext (list[str]): Context after the word.
        
        Returns:
            dict[str, float]: Mapping from each candidate to the probability
                of it being the word that was meant to be typed.
        """
        candidates = self.candidates.copy() - {word}
        candidate_probs = self.language_model(candidates, pretext, posttext)

        def get_prob(candidate):
            return self.error_model(word, candidate) *\
                   candidate_probs[candidate]

        return {candidate: get_prob(candidate) for candidate in candidates}

### Loading the text data

In [7]:
fivegrams = dict()
with open('fivegrams.txt') as file:
    for line in file:
        count, *gram = line.split()
        gram = tuple(gram)
        count = int(count)
        fivegrams[gram] = count

In [8]:
bigrams = dict()
with open('bigrams.txt', encoding='iso-8859-1') as file:
    for line in file:
        count, *gram = line.split()
        gram = tuple(gram)
        count = int(count)
        bigrams[gram] = count

In [9]:
with open('wikipedia.txt') as file:
    spelling_mistakes = {}
    for line in file:
        line = line.replace(':', '')
        word, *misspellings = line.split()
        spelling_mistakes[word] = misspellings

In [10]:
# "The hound of the Baskervilles" by Arthur Conan Doyle
with open('pg3070.txt') as file:
    val_text = file.read().lower()

### Validation functions

In [11]:
Dataset = list[tuple[tuple[list[str], str, list[str]], str]]

In [12]:
def generate_test_dataset(
    text: str,
    spelling_mistakes: dict[str, list[str]],
    context_size: int = 10
) -> Dataset:
    """Generates a test/validation dataset from a text. First, extracts words
    using a simple regular expression, then, for each word, generates a test
    case by replacing the word by corresponding misspelled versions from
    `spelling_mistakes`.

    Args:
        text (str): Text to generate the dataset from.
        spelling_mistakes (dict[str, list[str]]): Mapping from a word to possible
            misspellings.
        context_size (int): Number of context words to include on each side of
            the word. The word itself is not included. Defaults to 10.

    Returns:
        Dataset: List of (input, target) pairs. Input is a tuple of pretext,
            misspelled word, and posttext; target is the original word.
    """
    text = text.lower()
    words = re.findall('\\w+', text)
    result = []
    for i, word in enumerate(words):
        if word not in spelling_mistakes:
            continue
        # Clamp at 0: a negative start index would slice from the end
        pretext = words[max(0, i - context_size):i]
        posttext = words[i + 1:i + context_size + 1]
        for misspelling in spelling_mistakes[word]:
            result.append(((pretext, misspelling, posttext), word))
    return result

In [13]:
def get_accuracy(
    dataset: Dataset,
    corrector: Corrector,
    verbose: bool = False,
) -> float:
    """Calculates the accuracy of the corrector on the dataset.

    Args:
        dataset: Testing dataset.
        corrector (Corrector): Corrector to test.
        verbose: Whether to display the progress bar or not. Defaults to `False`.
    
    Returns:
        float: Accuracy on the dataset. Between 0 and 1 inclusive.
    """
    total = len(dataset)
    accurate = 0
    for (pretext, word, posttext), target in tqdm.tqdm(dataset,
                                                       disable=not verbose):
        probs = corrector.get_candidates(pretext, word, posttext)
        prediction = max(probs, key=probs.__getitem__)
        if prediction == target:
            accurate += 1
    return accurate / total

In [14]:
def test_corrector(
    n: int,
    training_set: dict[tuple[str], int],
    levenstein_base: float,
    candidates_size: int,
    val_set: Dataset,
) -> tuple[float, Corrector]:
    """Creates, trains, and evaluates a corrector according to the given parameters.

    Args:
        n (int): Number of words in n-grams.
        training_set (dict[tuple[str], int]): Counts of n-grams in the training
            set.
        levenstein_base (float): Base of the error model (see
            `LevensteinErrorModel`).
        candidates_size (int): Size of the candidates set.
        val_set (Dataset): Set to evaluate corrector on.

    Returns:
        tuple[float, Corrector]: Accuracy of the model and the trained model.
    """
    lm = NGramLanguageModel(training_set, n=n)
    em = LevensteinErrorModel(base=levenstein_base)

    def get_count(word):
        # Total count of training n-grams that start with `word`
        return lm.prefix[(word,)]

    candidates = sorted({gram[0]
                         for gram in training_set.keys()},
                        key=get_count,
                        reverse=True)
    candidates = candidates[:candidates_size]
    candidates = set(candidates)
    corrector = Corrector(candidates, lm, em)
    return get_accuracy(val_set, corrector), corrector

In [15]:
def test_matrix(matrix: dict[str, list]) -> tuple[dict, float, Corrector]:
    """Tests all combinations of hyperparameters in the matrix, and returns
    best parameters, score, and model.

    Args:
        matrix (dict[str, list]): Matrix with hyperparameters that should be
            passed into `test_corrector`.
    
    Returns:
        tuple[dict, float, Corrector]: Best hyperparameters, accuracy, and
            model respectively.
    """
    keys = sorted(matrix.keys())
    iterables = [matrix[key] for key in keys]
    best_values = None
    best_score = -float('inf')
    best_model = None
    cases = list(itertools.product(*iterables))
    for values in tqdm.tqdm(cases):
        kwargs = dict(zip(keys, values))
        accuracy, model = test_corrector(**kwargs)
        if accuracy > best_score:
            best_values = dict(zip(keys, values))
            best_score = accuracy
            best_model = model
    return best_values, best_score, best_model

### Finding the best hyperparameters

In [27]:
random.seed(42)
test_dataset = generate_test_dataset(val_text, spelling_mistakes)
cases_by_target = defaultdict(list)
for case, target in test_dataset:
    cases_by_target[target].append(case)
test_dataset = []
for _ in range(10):
    for target, cases in cases_by_target.items():
        test_dataset.append((random.choice(cases), target))
random.shuffle(test_dataset)
val_dataset, test_dataset = test_dataset[:100], test_dataset[100:200]

In [17]:
matrix = {
    'n': [2],
    'training_set': [bigrams],
    'levenstein_base': [2, 8, 32, 128, 512, 2048],
    'candidates_size': [100000],
    'val_set': [val_dataset],
}

In [18]:
_values, score, model = test_matrix(matrix)


In [19]:
print(f'Accuracy on the validation set: {score * 100:.1f}%')

Accuracy on the validation set: 75.0%


## Justify your decisions

Write down justifications for your implementation choices. For example, these choices could be:
- Which ngram dataset to use
- Which weights to assign for edit1, edit2 or absent words probabilities
- Beam search parameters
- etc.

Most of the hyperparameters, as well as the training dataset, were chosen experimentally (see "Finding the best hyperparameters" for details), so they are not discussed here.

### Language model

The language model is an $n$-gram model that smooths the probabilities with the mean of a beta distribution:
$$
\begin{aligned}
\hat{P}(w_n|w_1..w_{n-1}) &= \mu\left[\mathrm{Beta}(\#(w_1..w_n) + 1, \#(w_1..w_{n-1}) - \#(w_1..w_n) + 1)\right] \\
                    &= \frac{\#(w_1..w_n) + 1}{\#(w_1..w_{n-1}) + 2}
\end{aligned}
$$
where $\#(w_1..w_n)$ is the number of times the n-gram $w_1..w_n$ occurs in the training set.
Given a preceding context $U = u_1..u_{n-1}$ and a succeeding context $V = v_1..v_{n-1}$, the probability of word $w$ occurring between them is estimated as:
$$\hat{P}(w|U,V) = \hat{P}(w|U)\hat{P}(w|V)$$
The beta distribution is a natural choice when estimating probabilities, as it is the conjugate prior of the Bernoulli distribution: starting from the uniform prior $\mathrm{Beta}(1, 1)$, the posterior after observing the counts above is exactly the beta distribution whose mean is used here.
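
A small worked sketch of the smoothed estimate may help. The helper names `smoothed` and `two_sided_score` are hypothetical and the counts are made up; the arithmetic is the same as in `NGramLanguageModel.__call__`.

```python
def smoothed(ngram_count: int, context_count: int) -> float:
    """Mean of Beta(ngram_count + 1, context_count - ngram_count + 1)."""
    return (ngram_count + 1) / (context_count + 2)


def two_sided_score(prev_count: int, prev_context: int,
                    next_count: int, next_context: int) -> float:
    """Product of the estimates for the preceding and succeeding context."""
    return smoothed(prev_count, prev_context) * smoothed(
        next_count, next_context)


# An n-gram seen 3 times in a context seen 8 times: (3 + 1) / (8 + 2)
print(smoothed(3, 8))  # 0.4
# A context that never occurs in the training set falls back to 1 / 2
print(smoothed(0, 0))  # 0.5
```

Because of the added pseudo-counts, an unseen n-gram never receives probability zero, so a single missing count cannot rule a candidate out on its own.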

### Error model

The error model estimates the probability $\mathcal{P}(c|w)$ that the typed word $w$ is a misspelling of the candidate word $c$. To define this probability, let us first introduce the edit distance used in the implementation, a variant of the Levenshtein distance that allows only insertions and deletions (so a substitution costs 2):
$$d(u, v) = f(|u|, |v|)$$
$$f(i, j) = \begin{cases}
0, &i = 0 \land j = 0 \\
i, &j = 0 \\
j, &i = 0 \\
f(i - 1, j - 1), &u_i = v_j \\
1 + \min \left\{ f(i - 1, j), f(i, j - 1) \right\}, &\text{otherwise} \\
\end{cases}$$
The probability $\mathcal{P}(c|w)$ is then:
$$\mathcal{P}(c|w) = b^{-d(c, w)}$$
where $b$ is a hyperparameter.
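
The following standalone sketch mirrors the distance as computed in `LevensteinErrorModel.__call__`, where each insertion or deletion costs 1. The function names `indel_distance` and `error_probability` are hypothetical and chosen only for this illustration.

```python
def indel_distance(u: str, v: str) -> int:
    """Edit distance between u and v using insertions and deletions only."""
    # f[i][j] is the distance between u[:i] and v[:j]
    f = [[0] * (len(v) + 1) for _ in range(len(u) + 1)]
    for i in range(len(u) + 1):
        f[i][0] = i
    for j in range(len(v) + 1):
        f[0][j] = j
    for i in range(1, len(u) + 1):
        for j in range(1, len(v) + 1):
            if u[i - 1] == v[j - 1]:
                f[i][j] = f[i - 1][j - 1]
            else:
                f[i][j] = 1 + min(f[i - 1][j], f[i][j - 1])
    return f[-1][-1]


def error_probability(word: str, candidate: str, base: float = 4.0) -> float:
    """Score base^(-d) for the typed word and a candidate correction."""
    return base ** -indel_distance(word.lower(), candidate.lower())


print(indel_distance("dking", "doing"))     # 2
print(indel_distance("dking", "dying"))     # 2
print(error_probability("dking", "doing"))  # 0.0625
```

Both "doing" and "dying" are at distance 2 from "dking", so the error model alone cannot choose between them; the language model has to break the tie.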

### Corrector

Given the context pair $(U, V)$ and the typed word $w$, the probability that $w$ is a misspelling of the candidate word $c \in C$ is simply:
$$P(c|w,U,V) = \hat{P}(c|U,V) \mathcal{P}(c|w)$$
The most likely candidate $c'$ is selected purely on this probability:
$$c' = \mathop{\text{argmax}}\limits_{c \in C \setminus \{w\}} P(c|w,U,V)$$

The word is then replaced with $c'$ if:
$$\lambda\hat{P}(w|U,V) < P(c'|w,U,V)$$
where $\lambda$ is a threshold constant, chosen empirically to be $10^4$.
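
The decision rule can be sketched in a few lines. This is an illustration under stated assumptions rather than the `Corrector` class itself: `choose_correction` is a hypothetical helper, the language-model scores are invented, and the error model is replaced by a constant.

```python
from typing import Callable


def choose_correction(
    word: str,
    lm_scores: dict[str, float],
    error_model: Callable[[str, str], float],
    threshold: float = 1e4,
) -> str:
    """Return the best candidate if it beats the typed word by `threshold`.

    `lm_scores` maps every word, including `word` itself, to its
    language-model score in the current context.
    """
    scores = {
        c: lm_scores[c] * error_model(word, c)
        for c in lm_scores if c != word
    }
    if not scores:
        return word
    best = max(scores, key=scores.get)
    if scores[best] > threshold * lm_scores[word]:
        return best
    return word


# Invented scores: every candidate is assumed to be at distance 2, base 4.
constant_error = lambda word, candidate: 4.0 ** -2

print(choose_correction(
    "dking", {"dking": 1e-9, "doing": 1e-2, "dying": 1e-4}, constant_error))
# doing
print(choose_correction(
    "doing", {"doing": 1e-2, "dying": 1e-4}, constant_error))
# doing
```

The large threshold makes the corrector conservative: a word that is already plausible in its context is left unchanged even when another candidate is close in spelling.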

## Evaluate on a test set

Your task is to generate a test set and evaluate your work. You may vary the noise probability to generate different datasets with varying complexity. Compare your solution to Norvig's corrector, and report the accuracies.

*Note: the evaluation functions are defined in the "Validation functions" section, and the test dataset is generated in the "Finding the best hyperparameters" section.*
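
The evaluation in this notebook draws misspellings from a list of known mistakes. For the noise-probability variant mentioned in the task, a synthetic generator could look like the sketch below. `random_edit`, `add_noise`, `noise_prob`, and `seed` are hypothetical names, and this generator is not used for the results reported here.

```python
import random
import string

LETTERS = string.ascii_lowercase


def random_edit(word: str, rng: random.Random) -> str:
    """Apply one random single-character edit to a word."""
    if not word:
        return word
    ops = ["insert", "replace"]
    if len(word) > 1:
        # Avoid deleting the only letter of a one-letter word
        ops += ["delete", "transpose"]
    op = rng.choice(ops)
    if op == "insert":
        i = rng.randrange(len(word) + 1)
        return word[:i] + rng.choice(LETTERS) + word[i:]
    if op == "transpose":
        i = rng.randrange(len(word) - 1)
        return word[:i] + word[i + 1] + word[i] + word[i + 2:]
    i = rng.randrange(len(word))
    if op == "replace":
        return word[:i] + rng.choice(LETTERS) + word[i + 1:]
    return word[:i] + word[i + 1:]  # delete


def add_noise(words: list[str], noise_prob: float, seed: int = 0) -> list[str]:
    """Corrupt each word independently with probability `noise_prob`."""
    rng = random.Random(seed)
    return [
        random_edit(word, rng) if rng.random() < noise_prob else word
        for word in words
    ]


clean = "the hound of the baskervilles".split()
print(add_noise(clean, noise_prob=0.3, seed=42))
```

Raising `noise_prob` yields a harder dataset. Note that a replacement or transposition can occasionally reproduce the original word, so the realised error rate can be slightly below `noise_prob`.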

In [28]:
test_accuracy = get_accuracy(test_dataset, model, verbose=True)


In [29]:
print(f'Accuracy on the test set: {test_accuracy * 100:.1f}%')

Accuracy on the test set: 85.0%


### Norvig's solution

In [22]:
"""Spelling Corrector in Python 3; see http://norvig.com/spell-correct.html

Copyright (c) 2007-2016 Peter Norvig
MIT license: www.opensource.org/licenses/mit-license.php
"""


class Norvig:

    @staticmethod
    def get_words(text):
        return re.findall(r'\w+', text.lower())

    def __init__(self, corpus_file: str) -> None:
        with open(corpus_file) as file:
            words = self.get_words(file.read())
            self.n = len(words)
            self.words = Counter(words)

    def probability(self, word):
        return self.words[word] / self.n

    def correction(self, word):
        return max(self.candidates(word), key=self.probability)

    def candidates(self, word):
        return (self.known([word]) or self.known(self.edits1(word))
                or self.known(self.edits2(word)) or [word])

    def known(self, words):
        return set(words) & set(self.words)

    @staticmethod
    def edits1(word):
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    @staticmethod
    def edits2(word):
        return {e2 for e1 in Norvig.edits1(word) for e2 in Norvig.edits1(e1)}

    def get_candidates(self, pretext: list[str], word: str,
                       posttext: list[str]) -> dict[str, float]:
        """Returns probabilities associated with each candidate for a given
        context. Used for testing.

        Args:
            pretext (list[str]): Context preceding the word.
            word (str): Typed word.
            posttext (list[str]): Context after the word.
        
        Returns:
            dict[str, float]: Mapping from each candidate to the probability
                of it being the word that was meant to be typed.
        """
        candidates = set(self.candidates(word)) - {word}
        correction = self.correction(word)
        result = dict.fromkeys(candidates, 0)
        result[correction] = 1
        return result

In [30]:
norvig = Norvig('big.txt')

In [31]:
norvig_accuracy = get_accuracy(test_dataset, norvig, verbose=True)


In [32]:
print(f'Accuracy on the test set: {norvig_accuracy * 100:.1f}%')

Accuracy on the test set: 84.0%


Final accuracies on the 100-case test set are:

| Norvig | N-gram |
| --- | --- |
| 84.0% | 85.0% |
