# Задание 1. Реализуйте алгоритм Symspell

Он похож на алгоритм Норвига, но проще и быстрее: к словам в словаре применяется только одна операция — удаление символа. Чтобы найти исправление, из проверяемого слова тоже удаляются символы, и полученные варианты сравниваются с теми, что хранятся в словаре. Оцените качество полученного алгоритма теми же тремя метриками.

https://medium.com/@wolfgarbe/1000x-faster-spelling-correction-algorithm-2012-8701fcd87a5f


In [1]:
!pip install razdel
from razdel import sentenize
from razdel import tokenize as razdel_tokenize
import re
from string import punctuation
from collections import Counter
punctuation += "«»—…“”"
punct = set(punctuation)

# Noisy sentences and their gold corrections, aligned line by line.
# Context managers close the files instead of leaking the handles.
with open('sents_with_mistakes.txt', encoding='utf8') as bad_file:
    bad = bad_file.read().splitlines()
with open('correct_sents.txt', encoding='utf8') as true_file:
    true = true_file.read().splitlines()



In [2]:
def align_words(sent_1, sent_2):
    """Pair up the tokens of two sentences position by position.

    Each sentence is lowercased and split on whitespace. Tokens made only of
    characters from ``punct`` are dropped. Leading and trailing non-word
    characters are stripped from the remaining tokens. A token such as '№'
    (not in ``punct``, but non-word) therefore becomes an empty string.

    The two token lists are zipped, so surplus tokens of the longer sentence
    are silently ignored.

    Args:
        sent_1: reference sentence (the gold sentence in this notebook).
        sent_2: sentence aligned against it (the noisy sentence).

    Returns:
        List of ``(token_from_sent_1, token_from_sent_2)`` tuples.
    """
    def clean(sentence):
        # Raw string: '\W' in a plain literal is an invalid escape sequence
        # (SyntaxWarning since Python 3.12).
        return [
            re.sub(r'(^\W+|\W+$)', '', token)
            for token in sentence.lower().split()
            if set(token) - punct  # skip punctuation-only tokens
        ]

    return list(zip(clean(sent_1), clean(sent_2)))


# Load the Wikipedia sample; the context manager closes the file handle.
with open('wiki_data.txt', encoding='utf8') as corpus_file:
    corpus = corpus_file.read()

# Unigram frequency table over lowercase \w+ tokens (raw-string regex).
WORDS = Counter(re.findall(r'\w+', corpus.lower()))

# Word probability: relative frequency in the corpus.
N = sum(WORDS.values())
def P(word, N=N):
    """Return the relative corpus frequency of ``word``.

    Unseen words give 0.0, because a Counter returns 0 for missing keys.
    ``N`` defaults to the total token count at the time P was defined.
    """
    return WORDS[word] / N

# From here on ``corpus`` is the set of distinct lowercase words. The keys of
# WORDS are exactly that set, so the corpus-wide regex is not run a second time.
corpus = set(WORDS)

In [3]:
def collect_deleted(dictionary, word, depth=1):
    """Index ``word`` under every string obtained by deleting ``depth`` of its characters.

    For each such string ``s``, ``word`` is appended to ``dictionary[s]``
    (the list is created on first use). ``word`` is added at most once per
    key, even when repeated letters make several deletions produce the same
    string. Words no longer than ``depth`` are ignored, and so is
    ``depth < 1``.

    Args:
        dictionary: mapping delete-string -> list of original words; updated in place.
        word: the dictionary word to index.
        depth: number of characters to delete (any positive integer).
    """
    if depth < 1 or len(word) <= depth:
        return

    # Delete one character per round. Each variant remembers the position of
    # its last deletion, and later rounds only delete at or after it. That
    # enumerates every set of positions exactly once, in the same order as
    # nested ``i < k`` loops.
    variants = [(word, 0)]
    for _ in range(depth):
        variants = [
            (current[:i] + current[i + 1:], i)
            for current, start in variants
            for i in range(start, len(current))
        ]

    # dict.fromkeys drops duplicate strings (e.g. from doubled letters) while
    # keeping generation order.
    for deleted in dict.fromkeys(text for text, _ in variants):
        bucket = dictionary.get(deleted)
        if bucket is None:
            dictionary[deleted] = [word]
        else:
            bucket.append(word)

In [4]:
%%time

# Symmetric-delete index. Exact words are stored as-is, and every word is also
# filed under each string obtained by deleting one or two of its characters.
vocab_correct = {}
vocab_1_letter_deleted = {}
vocab_2_letters_deleted = {}

# Iterate in sorted order. ``corpus`` is a set of str, and its iteration order
# changes between interpreter runs because of hash randomisation. That order
# determines the order of words inside each delete bucket, which decides ties
# in ``max`` when candidates have equal probability. Without sorting, the
# evaluation metrics are not reproducible from run to run.
for word in sorted(corpus):
    vocab_correct[word] = [word]
    collect_deleted(vocab_1_letter_deleted, word, depth=1)
    collect_deleted(vocab_2_letters_deleted, word, depth=2)

# vocabs[i] holds the dictionary words with exactly i characters deleted.
vocabs = [vocab_correct, vocab_1_letter_deleted, vocab_2_letters_deleted]

CPU times: user 29.8 s, sys: 1.58 s, total: 31.4 s
Wall time: 32 s


In [5]:
def choose_candidate(c_candidates):
    """Return the most frequent word among ``c_candidates``, or None if it is empty.

    Frequency is measured by ``P``. On ties the earliest candidate wins,
    because ``max`` keeps the first maximal element.
    """
    if not c_candidates:
        return None
    return max(c_candidates, key=P)

In [6]:
def check_vocabs(word, vocabs, candidates, start_weight):
    """File every dictionary word reachable from ``word`` into a weighted bucket.

    ``vocabs[i]`` maps a string to the dictionary words it is obtained from by
    deleting ``i`` characters. ``start_weight`` is the number of characters
    already deleted from the user's word to produce ``word``.

    Each hit is appended to ``candidates[max(i, start_weight)]``. That number
    is used as a cheap stand-in for the edit distance; it is not checked
    against a real edit-distance computation here.

    ``candidates`` is modified in place; nothing is returned.
    """
    for deletions, vocab in enumerate(vocabs):
        if word not in vocab:
            continue
        bucket_index = max(deletions, start_weight)
        candidates[bucket_index].extend(vocab[word])

In [7]:
def correction(word):
    """Return the most probable dictionary word for ``word`` (SymSpell style).

    ``word`` itself and every string obtained from it by deleting one or two
    characters are looked up in ``vocabs`` through ``check_vocabs``, which
    files each hit into a bucket by weight. The lowest non-empty bucket wins,
    and ``choose_candidate`` picks the most frequent word in it. If nothing
    is found, ``word`` is returned unchanged.
    """
    # One bucket per weight; with the three vocabularies only weights 0-2 occur.
    candidates = [[], [], [], [], []]
    check_vocabs(word, vocabs, candidates, 0)

    # Repeated letters make different positions yield the same delete string
    # ("класс" -> "клас" twice). Looking that string up again would only
    # append the same words to the same bucket again, so repeats are skipped.
    # The first-occurrence order of candidates, which decides ties, is unchanged.
    seen = set()
    for i in range(len(word)):
        one_deleted = word[:i] + word[i + 1:]
        if one_deleted not in seen:
            seen.add(one_deleted)
            check_vocabs(one_deleted, vocabs, candidates, 1)
        for k in range(i + 1, len(word)):
            two_deleted = word[:i] + word[i + 1:k] + word[k + 1:]
            if two_deleted not in seen:
                seen.add(two_deleted)
                check_vocabs(two_deleted, vocabs, candidates, 2)

    for bucket in candidates:
        if bucket:
            return choose_candidate(bucket)
    return word

In [8]:
# Sanity check: a word with its first letter dropped ('ондон') should come back
# as 'лондон'. Kept as a bare expression so the notebook displays the result.

correction('ондон')

'лондон'

In [9]:
# Stress test: mostly made-up words mixed with real short words ('с', 'по',
# 'и'). Prints the correction chosen for each token.
text_fun = [
    'сяпала', 'калуша', 'с', 'калушатами', 'по', 'напушке', 'и',
    'увазила', 'бутявку', 'волит', 'калушата', 'калушаточки', 'бутявка',
]

for token in text_fun:
    print(correction(token))


спасла
карлуш
с
калатафими
по
науке
и
выразил
путёвку
вошли
карлуша
калушаточки
утковка


In [10]:
# Token-level evaluation of the context-free corrector against the gold text.

correct = 0
total = 0

total_mistaken = 0
mistaken_fixed = 0

total_correct = 0
correct_broken = 0

for i, true_sent in enumerate(true):
    word_pairs = align_words(true_sent, bad[i])
    for gold, noisy in word_pairs:
        pred = correction(noisy)
        is_right = pred == gold

        total += 1
        if is_right:
            correct += 1

        if gold == noisy:
            # The token was already correct: did the corrector break it?
            total_correct += 1
            if not is_right:
                correct_broken += 1
        else:
            # The token was misspelled: did the corrector fix it?
            total_mistaken += 1
            if is_right:
                mistaken_fixed += 1

    # Progress indicator every 100 sentences.
    if i % 100 == 0:
        print(i)

0
100
200
300
400
500
600
700
800
900


In [11]:
# Overall accuracy, share of misspellings fixed, share of correct words broken.
for label, numerator, denominator in (
    ('correct/total:', correct, total),
    ('mistaken_fixed/total_mistaken: ', mistaken_fixed, total_mistaken),
    ('correct_broken/total_correct: ', correct_broken, total_correct),
):
    print(label, numerator / denominator)

correct/total: 0.8579420579420579
mistaken_fixed/total_mistaken:  0.4405218726016884
correct_broken/total_correct:  0.07959113357069025



# Задание 2. Добавьте к полученному алгоритму исправления (Symspell) триграммную модель и проверьте, улучшает ли она качество. 
Триграммную модель нужно вставить туда, где у вас выбирается один из нескольких кандидатов на исправление.


In [12]:
# Re-read the raw Wikipedia text. Earlier in the notebook ``corpus`` was
# replaced by a set of words, but sentence splitting needs the full text.
with open('wiki_data.txt', encoding='utf8') as corpus_file:
    corpus = corpus_file.read()

def normalize(text):
    """Tokenise ``text`` with razdel and return cleaned lowercase tokens.

    Characters from ``punctuation`` are stripped from both ends of every
    token. Tokens that end up empty, or that are 20 characters or longer
    (measured before lowercasing), are dropped.
    """
    stripped = (token.text.strip(punctuation) for token in razdel_tokenize(text))
    return [word.lower() for word in stripped if word and len(word) < 20]

def preprocess(text):
    """Split ``text`` into sentences with razdel and normalise each into a token list."""
    return [normalize(sentence.text) for sentence in sentenize(text)]

def ngrammer(tokens, n):
    """Return the contiguous n-grams of ``tokens`` as a list of tuples.

    Args:
        tokens: any iterable of tokens; it is materialised into a list once.
        n: n-gram length; must be at least 1.

    Returns:
        ``len(tokens) - n + 1`` tuples, or an empty list when there are fewer
        than ``n`` tokens.

    Raises:
        ValueError: if ``n < 1``; such lengths have no meaningful n-grams.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    tokens = list(tokens)
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

# Wrap every sentence in boundary markers so its first and last real tokens
# also have a left and a right neighbour.
corpus_wiki = [['<start>', *tokens, '<end>'] for tokens in preprocess(corpus)]
                

In [13]:
# Trigram counts keyed by the outer tokens:
# trigrams[(left, right)][middle] = number of times "left middle right"
# occurs in corpus_wiki. Boundary markers are never used as the middle token.
trigrams = {}

for sent in corpus_wiki:
    for left, token, right in zip(sent, sent[1:], sent[2:]):
        counts = trigrams.setdefault((left, right), {})
        counts[token] = counts.get(token, 0) + 1

In [14]:
def align_words(sent_1, sent_2):
    """Pair up the tokens of two sentences position by position, with boundary padding.

    Each sentence is lowercased and split on whitespace. Tokens made only of
    characters from ``punct`` are dropped, and leading and trailing non-word
    characters are stripped from the rest. Each token list is then wrapped in
    ``'<start>'`` / ``'<end>'`` so every real token has a left and a right
    neighbour.

    The lists are zipped. If the token counts differ, the shorter sentence's
    ``'<end>'`` is paired with a real token and the surplus is dropped.

    Args:
        sent_1: reference sentence (the gold sentence in this notebook).
        sent_2: sentence aligned against it (the noisy sentence).

    Returns:
        List of ``(token_from_sent_1, token_from_sent_2)`` tuples, starting
        with ``('<start>', '<start>')``.
    """
    def clean(sentence):
        # Raw string: '\W' in a plain literal is an invalid escape sequence
        # (SyntaxWarning since Python 3.12).
        tokens = [
            re.sub(r'(^\W+|\W+$)', '', token)
            for token in sentence.lower().split()
            if set(token) - punct  # skip punctuation-only tokens
        ]
        return ['<start>'] + tokens + ['<end>']

    return list(zip(clean(sent_1), clean(sent_2)))

In [15]:
# Probability of a word given its left and right neighbours.
def Prob_in_context(word, context):
    """Return P(word | previous, next) estimated from ``trigrams``.

    Args:
        word: candidate token.
        context: ``(previous_token, next_token)`` tuple, the key shape used
            when ``trigrams`` was built.

    Returns:
        * ``1`` when the context was never seen. The score is then neutral,
          and the unigram probability multiplied in by ``choose_candidate``
          decides alone.
        * ``count(word in context) / count(context)`` when ``word`` was seen
          between these neighbours.
        * ``1 / (100 * N)`` otherwise, a small floor so the product with the
          unigram probability is not zeroed out.
    """
    context_counts = trigrams.get(context)
    if context_counts is None:
        return 1
    if word not in context_counts:
        return 1/(100*N)
    # The context total is only needed here. sum() runs in C, replacing the
    # old per-call Python loop that ran even when the total was unused.
    return context_counts[word] / sum(context_counts.values())

In [16]:
def choose_candidate(c_candidates, context):
    """Pick the candidate with the best unigram-times-context score.

    The score is ``P(x) * Prob_in_context(x, context)``. Ties go to the
    earliest candidate, following ``max`` semantics. Returns None for an
    empty list.
    """
    if not c_candidates:
        return None

    def score(candidate):
        return P(candidate) * Prob_in_context(candidate, context)

    return max(c_candidates, key=score)

In [17]:
# Correction function that also takes the surrounding context into account.
def correction(word, context):
    """Return the most probable correction of ``word`` given its neighbours.

    ``word`` itself and every string obtained from it by deleting one or two
    characters are looked up in ``vocabs`` through ``check_vocabs``, which
    files each hit into a bucket by weight. The lowest non-empty bucket wins,
    and ``choose_candidate`` picks from it using ``context``.

    Args:
        word: the (possibly misspelled) lowercase token.
        context: ``(previous_token, next_token)`` tuple passed to the context model.

    Returns:
        The chosen candidate, or ``word`` unchanged when nothing matched.
    """
    # One bucket per weight; with the three vocabularies only weights 0-2 occur.
    candidates = [[], [], [], [], []]
    check_vocabs(word, vocabs, candidates, 0)

    # Repeated letters make different positions yield the same delete string
    # ("класс" -> "клас" twice). Looking that string up again would only
    # append the same words to the same bucket again, so repeats are skipped.
    # The first-occurrence order of candidates, which decides ties, is unchanged.
    seen = set()
    for i in range(len(word)):
        one_deleted = word[:i] + word[i + 1:]
        if one_deleted not in seen:
            seen.add(one_deleted)
            check_vocabs(one_deleted, vocabs, candidates, 1)
        for k in range(i + 1, len(word)):
            two_deleted = word[:i] + word[i + 1:k] + word[k + 1:]
            if two_deleted not in seen:
                seen.add(two_deleted)
                check_vocabs(two_deleted, vocabs, candidates, 2)

    for bucket in candidates:
        if bucket:
            return choose_candidate(bucket, context)
    return word

In [18]:
# Token-level evaluation of the context-aware corrector against the gold text.

correct = 0
total = 0

total_mistaken = 0
mistaken_fixed = 0

total_correct = 0
correct_broken = 0

for i, true_sent in enumerate(true):
    # align_words pads both token lists with <start>/<end>, so a sliding
    # window of three pairs visits every real token with both neighbours.
    word_pairs = align_words(true_sent, bad[i])
    for previous, pair, following in zip(word_pairs, word_pairs[1:], word_pairs[2:]):
        gold, noisy = pair
        # Neighbours come from the noisy sentence, not from the gold text.
        context = (previous[1], following[1])
        pred = correction(noisy, context)
        is_right = pred == gold

        total += 1
        if is_right:
            correct += 1

        if gold == noisy:
            total_correct += 1
            if not is_right:
                correct_broken += 1
        else:
            total_mistaken += 1
            if is_right:
                mistaken_fixed += 1

    # Progress indicator every 100 sentences.
    if i % 100 == 0:
        print(i)

0
100
200
300
400
500
600
700
800
900


In [19]:
# Overall accuracy, share of misspellings fixed, share of correct words broken.
for label, numerator, denominator in (
    ('correct/total: ', correct, total),
    ('mistaken_fixed/total_mistaken: ', mistaken_fixed, total_mistaken),
    ('correct_broken/total_correct: ', correct_broken, total_correct),
):
    print(label, numerator / denominator)

correct/total:  0.859040959040959
mistaken_fixed/total_mistaken:  0.44896392939370683
correct_broken/total_correct:  0.07959113357069025


По сравнению с предыдущим результатом:

correct/total: 0.8579420579420579  
mistaken_fixed/total_mistaken: 0.4405218726016884  
correct_broken/total_correct:  0.07959113357069025  

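...качество выросло по первым двум метрикам, но незначительно (примерно на 0,1 и 0,8 процентного пункта соответственно); третья метрика (correct_broken/total_correct) не изменилась.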