In [1]:
import re
import numpy
import gzip
import pickle
from collections import Counter
from random import uniform
from collections import defaultdict
from math import log
from functools import reduce

In [2]:
class StandardTrie:
    """Character trie that maps strings to arbitrary values.

    Every node holds one character (``key``) and, when a stored string ends
    at that node, the associated ``value``.  Children are indexed by their
    character, so each step down the trie is a single dict lookup instead of
    a scan over all siblings.
    """

    def __init__(self, strings):
        """Build a trie in which every string of ``strings`` maps to itself."""
        self._root = self.Node()

        for s in strings:
            self.add(s, s)

    @property
    def root(self):
        """Root node; its ``key`` is ``None``."""
        return self._root

    def add(self, string, value):
        """Store ``value`` under ``string``, creating missing nodes on the way.

        Adding a string that is already present overwrites its value.
        """
        node = self._root
        for char in string:
            child = node._children.get(char)
            if child is None:
                child = self.Node()
                child.key = char
                node._children[char] = child
            node = child
        node._value = value

    def find(self, pattern):
        """Return the value stored under ``pattern``, or ``False`` if absent.

        ``pattern`` is absent when no path spells it, or when the node at the
        end of the path carries no value (``None``) because ``pattern`` is
        only a prefix of stored strings.  Stored falsy values such as ``0``
        or ``''`` are returned as they are rather than reported as missing.
        """
        node = self._root
        for char in pattern:
            node = node._children.get(char)
            if node is None:
                return False

        # Only ``None`` means "nothing stored here"; a value of ``None``
        # passed to add() is therefore indistinguishable from absence.
        if node._value is None:
            return False
        return node._value

    class Node:
        """Single trie node: one character, an optional value, and children."""

        __slots__ = '_children', '_key', '_value'

        def __init__(self):
            self._children = {}  # char -> Node, in insertion order
            self._key = None
            self._value = None

        @property
        def children(self):
            """Child nodes in insertion order; iterating yields ``Node`` objects."""
            return self._children.values()

        @property
        def key(self):
            """Character this node represents (``None`` until assigned)."""
            return self._key

        @key.setter
        def key(self, new_key):
            self._key = new_key

        @property
        def value(self):
            """Value stored at this node, or ``None`` if no string ends here."""
            return self._value

In [3]:
class TrieSpellChecker:
    """Fuzzy lexicon lookup: finds stored words within a fixed edit distance.

    The lexicon is held in a ``StandardTrie``.  A query walks the trie depth
    first and computes one Levenshtein-matrix row per trie node, so words
    sharing a prefix share the rows computed for that prefix.
    """

    def __init__(self, lexicon):
        """Index ``lexicon`` (an iterable of words) for fuzzy lookup."""
        self._lexicon = StandardTrie(lexicon)
        self._threshold = 2  # maximum edit distance accepted as a candidate

    def check(self, word):
        """Return candidate spellings for ``word``.

        The result lists every stored value whose key is within the distance
        threshold of ``word``, in trie traversal order.  If ``word`` itself is
        in the lexicon it is placed first and listed exactly once.  When
        nothing matches, ``[word]`` is returned so the list is never empty.
        """
        root_node = self._lexicon.root

        # Row for the empty prefix: distance from '' to word[:j] is j.
        first_row = range(len(word) + 1)

        matches = []
        for child in root_node.children:
            self._recursive_check(child, child.key, word, first_row, matches)

        if self._lexicon.find(word):
            # The walk above already reports an exact match (distance 0);
            # keep a single copy and move it to the front.
            matches = [word] + [m for m in matches if m != word]

        if not matches:
            return [word]

        return matches

    def _recursive_check(self, node, char, word, previous_row, spellings):
        """Extend the distance matrix by the row for ``node`` and recurse.

        ``previous_row`` is the row of the parent prefix; ``char`` is the
        character added by ``node``.  Values of nodes whose full distance to
        ``word`` is within the threshold are appended to ``spellings``.
        """
        # First column: distance from the current prefix to the empty string.
        current_row = [previous_row[0] + 1]

        # Levenshtein recurrence over the remaining columns.
        for col in range(1, len(word) + 1):
            left = current_row[col - 1] + 1
            up = previous_row[col] + 1
            diagonal = previous_row[col - 1]
            if word[col - 1] != char:
                diagonal += 1
            current_row.append(min(left, up, diagonal))

        if current_row[-1] <= self._threshold and node.value is not None:
            spellings.append(node.value)

        # Descend only while some cell is still within the threshold.
        if min(current_row) <= self._threshold:
            for child in node.children:
                self._recursive_check(child, child.key, word,
                                      current_row, spellings)

In [4]:
class LanguageModel:
    """Unigram word-frequency model built from a query log.

    ``flag`` selects where the counts come from:

    * ``'all'`` (default): count words read from ``queries_all.txt``;
    * ``'indexer'``: same, then pickle the counts to ``model.pkl``;
    * ``'spellchecker'``: load previously pickled counts from ``model.pkl``.

    Any other value leaves the model empty (``N == 0``).
    """

    def __init__(self, flag = 'all'):
        self._model = {}
        self.N = 0

        if flag == 'indexer' or flag == 'all':
            # Count line by line instead of concatenating the whole corpus
            # into one string first; lines are whitespace-separated, so the
            # resulting counts are the same.
            counts = Counter()
            for line in self._gen_lines('queries_all.txt'):
                parts = line.split('\t')
                if len(parts) == 1:
                    counts.update(self._words(line))
                elif len(parts) == 2:
                    # Only the second tab-separated field is counted
                    # (presumably the corrected query; verify against data).
                    counts.update(self._words(parts[1]))
                # Lines with more than two fields are ignored.
            self._model = counts

        if flag == 'indexer':
            self._save()

        if flag == 'spellchecker':
            self._load()

        # Total number of word occurrences; denominator of P().
        self.N = sum(self._model.values())

    def corpus(self):
        """Return the word -> count mapping backing the model."""
        return self._model

    def P(self, word):
        """Relative frequency of ``word``; unseen words get pseudo-count 0.1.

        Raises ``ZeroDivisionError`` if the model is empty (``N == 0``).
        """
        return self._model.get(word, 0.1) / self.N

    def proba(self, text):
        """Product of ``P`` over the whitespace-separated words of ``text``.

        Text without any words yields ``1.0`` (the empty product) rather than
        raising ``TypeError`` from ``reduce`` on an empty sequence.
        """
        return reduce((lambda x, y: x * y), map(self.P, text.split()), 1.0)

    def _words(self, text):
        """Lower-case ``text`` and return its ``\\w+`` tokens."""
        return re.findall(r'\w+', text.lower())

    def _gen_lines(self, fname):
        """Yield the lines of ``fname`` lower-cased, one at a time."""
        with open(fname) as data:
            for line in data:
                yield line.lower()

    def _save(self):
        """Pickle the counts to ``model.pkl``."""
        with open('model.pkl', 'wb') as f:
            pickle.dump(self._model, f, pickle.HIGHEST_PROTOCOL)

    def _load(self):
        """Replace the counts with the mapping pickled in ``model.pkl``."""
        # NOTE(security): unpickling can execute arbitrary code; only load a
        # model.pkl that was produced by a trusted run of _save().
        with open('model.pkl', 'rb') as f:
            self._model = pickle.load(f)

In [5]:
class SpellChecker:
    """Word-by-word spelling correction driven by a unigram language model."""

    def __init__(self, language_model):
        """Index the model's vocabulary in a trie for fuzzy lookup."""
        self._language = language_model
        self._words = self._language.corpus()
        self._trie = TrieSpellChecker(list(self._words.keys()))
        # A word that is itself a candidate is kept unchanged only if its
        # model probability exceeds this value.
        self._threshold = 1e-5

    def _correct_word(self, word):
        """Return ``word`` if it is frequent enough, else the likeliest candidate."""
        candidates = self._trie.check(word)
        if word in candidates and self._language.P(word) > self._threshold:
            return word
        # Reuse the candidate list instead of searching the trie a second time.
        return max(candidates, key=self._language.P)

    def correct(self, text):
        """Correct every whitespace-separated word of ``text``.

        Returns ``(changed, result)``.  ``result`` is the lower-cased,
        single-space-joined corrected text; ``changed`` is ``True`` whenever
        ``result`` differs from the original ``text``, which includes
        differences in case or spacing only.
        """
        words = text.lower().split()
        words = list(map(self._correct_word, words))
        result = ' '.join(words)
        if result != text:
            return True, result
        return False, result

In [6]:
class LayoutChecker:
    """Detects text typed with the wrong keyboard layout (English/Russian).

    Two position-aligned strings describe which character each key produces
    in either layout; the lookup tables in both directions are derived from
    them.
    """

    def __init__(self, language_model):
        self._language = language_model
        self._eng_chars = " ~!@#$%^&qwertyuiop[]asdfghjkl;'zxcvbnm,./QWERTYUIOP{}ASDFGHJKL:\"|ZXCVBNM<>?"
        self._rus_chars = " ё!\"№;%:?йцукенгшщзхъфывапролджэячсмитьбю.ЙЦУКЕНГШЩЗХЪФЫВАПРОЛДЖЭ/ЯЧСМИТЬБЮ,"
        self._eng_to_rus = {}
        self._rus_to_eng = {}
        # Characters at the same index sit on the same key.
        for eng, rus in zip(self._eng_chars, self._rus_chars):
            self._eng_to_rus[eng] = rus
            self._rus_to_eng[rus] = eng

    def correct(self, text):
        """Pick the likeliest of ``text`` and its two layout conversions.

        Returns ``(changed, best)`` where ``best`` is whichever of the
        original text, its English-to-Russian conversion, or its
        Russian-to-English conversion scores highest under the language
        model's ``proba`` (the original wins ties), and ``changed`` tells
        whether ``best`` differs from ``text``.
        """
        # Characters missing from a table pass through unchanged.
        as_russian = text.translate(str.maketrans(self._eng_to_rus))
        as_english = text.translate(str.maketrans(self._rus_to_eng))
        best = max((text, as_russian, as_english), key=self._language.proba)
        return best != text, best

In [7]:
class SplitChecker:
    """Re-inserts missing spaces by segmenting each token into likely words."""

    def __init__(self, language_model):
        self._language = language_model
        # No candidate piece longer than the longest corpus word is tried.
        self._maxword = max(len(w) for w in self._language.corpus())

    def correct(self, s):
        """Segment every whitespace-separated token of ``s`` and re-join with spaces."""
        pieces = []
        for token in s.split():
            pieces.extend(self._split(token))
        return ' '.join(pieces)

    def _split(self, s):
        """Return an iterator over the cheapest segmentation of ``s``.

        The cost of a piece is ``-log(P(piece.lower()))``; dynamic programming
        over prefix lengths finds the split with the smallest total cost, then
        the pieces are recovered by walking back from the end.
        """
        model = self._language
        window = self._maxword
        prefix_cost = [0]  # prefix_cost[i] = best cost of s[:i]

        def best_match(end):
            # Try every piece ending at `end`, shortest first; each entry is
            # (total cost, piece length) so min() prefers lower cost, then
            # the shorter piece.
            options = []
            recent = prefix_cost[max(0, end - window):end]
            for back, prior in enumerate(reversed(recent)):
                piece = s[end - back - 1:end].lower()
                options.append((prior - log(model.P(piece)), back + 1))
            return min(options)

        # Forward pass: fill in the best cost for every prefix.
        for end in range(1, len(s) + 1):
            best, _ = best_match(end)
            prefix_cost.append(best)

        # Backward pass: peel the best final piece off until nothing is left.
        out = []
        end = len(s)
        while end > 0:
            _, length = best_match(end)
            out.append(s[end - length:end])
            end -= length

        return reversed(out)

In [8]:
class JoinChecker:
    """Removes spaces whose removal makes the text likelier under the model."""

    def __init__(self, language_model):
        self._language = language_model

    def correct(self, s):
        """Scan the spaces of ``s`` left to right, dropping those that hurt.

        For each space found, the text with that space removed is scored
        against the current text with the language model's ``proba``; the
        space is dropped only when the score strictly improves.  The search
        starts at index 1, so a space at index 0 is never examined, and after
        a removal it resumes one position past the removed space.
        """
        score = self._language.proba
        at = s.find(' ', 1)
        while at != -1:
            merged = s[:at] + s[at + 1:]
            if score(merged) > score(s):
                s = merged
            at = s.find(' ', at + 1)
        return s

In [26]:
class SplitJoinChecker:
    """Repairs one wrongly joined or wrongly split word in a token list.

    Token lists may mix words with separators; only tokens for which
    ``str.isalpha()`` is true are treated as words.  Neither method modifies
    the list it is given.
    """

    def __init__(self, language_model):
        self._language = language_model
        self._words = self._language.corpus().keys()

    def split(self, word_parts):
        """Split the first unknown word that is two known words glued together.

        A word is a candidate when it is longer than two characters and not
        in the vocabulary; the leftmost cut that leaves two known words is
        used.  Returns ``(new_parts, True)`` where ``new_parts`` is a copy of
        ``word_parts`` with that word replaced by ``head, ' ', tail``, or
        ``([], False)`` when nothing could be split.
        """
        for i in self._word_positions(word_parts):
            token = word_parts[i]

            # Known words are never split, so skip them before trying cuts.
            if len(token) <= 2 or token in self._words:
                continue

            for pos in range(1, len(token)):
                head, tail = token[:pos], token[pos:]
                if head in self._words and tail in self._words:
                    splitted_tok = list(word_parts)
                    splitted_tok[i:i + 1] = [head, ' ', tail]
                    return splitted_tok, True

        return [], False

    def join(self, word_parts):
        """Join the first pair of neighbouring words that forms a known word.

        Two consecutive words are merged when at least one of them is not in
        the vocabulary and their concatenation is.  Every token between them
        is dropped as well.  Returns ``(new_parts, joined)``; ``new_parts`` is
        always a new list, so the caller's ``word_parts`` is left untouched.
        """
        parts = list(word_parts)
        positions = self._word_positions(parts)

        for first, second in zip(positions, positions[1:]):
            left, right = parts[first], parts[second]
            if left in self._words and right in self._words:
                continue
            if left + right in self._words:
                # Replace left word, separators and right word in one go.
                parts[first:second + 1] = [left + right]
                return parts, True

        return parts, False

    def _word_positions(self, word_parts):
        """Return the indices of the purely alphabetic tokens in ``word_parts``."""
        return [i for i, token in enumerate(word_parts) if token.isalpha()]

In [27]:
# 'indexer' mode: count words from queries_all.txt and pickle them to model.pkl.
%time language_model = LanguageModel('indexer')
# 'spellchecker' mode: load the counts back from model.pkl; this rebinds
# language_model, so later cells use the instance loaded from disk.
%time language_model = LanguageModel('spellchecker')

CPU times: user 6.56 s, sys: 919 ms, total: 7.48 s
Wall time: 7.73 s
CPU times: user 197 ms, sys: 11.4 ms, total: 209 ms
Wall time: 209 ms


In [28]:
# One instance of every checker, all sharing the same language model.
# SpellChecker indexes the whole vocabulary in a trie on construction.
spellchecker = SpellChecker(language_model)
layoutchecker = LayoutChecker(language_model)
splitchecker = SplitChecker(language_model)
joinchecker = JoinChecker(language_model)
# NOTE(review): `test` holds the SplitJoinChecker used by the demo cell.
test = SplitJoinChecker(language_model)

In [29]:
# Smoke test: one example per checker.
# SpellChecker.correct and LayoutChecker.correct return (changed, text) tuples;
# SplitChecker.correct and JoinChecker.correct return plain strings.
# NOTE(review): the recorded output shows bare strings for the first two
# calls, which does not match a printed tuple - re-run to refresh it.
print(spellchecker.correct('Cкачат бесплатно онлан'))
print(layoutchecker.correct('ghbdtn'))
print(splitchecker.correct('скачать бесплатноонлайнсмотреть'))
print(joinchecker.correct('не навидеть'))
# join() takes a list of tokens and returns (tokens, joined_flag).
print(test.join('бесп латно'.split()))

скачать бесплатно онлайн
привет
скачать бесплатно онлайн смотреть
ненавидеть
(['бесплатно'], True)
