In [14]:
from Levenshtein import distance
from collections import Counter
from nltk import bigrams
from tqdm import tqdm
from math import log2
from razdel import tokenize
from itertools import product
from math import log, exp
import string
import operator
import re

In [2]:
class Rating:
    def __init__(self, l=[], size=5):
        self.size = size
        self.r = sorted(l, key=lambda x: x[1], reverse=True)
        self.is_change = True
        if len(self.r) > self.size:
            self.r = self.r[len(self.r) - self.size:]
        
    def push(self, item):
        for i, elem in enumerate(self.r):
            if item == elem:
                return
            if item[1] > elem[1]:
                if i != 0:
                    self.is_change = True
                    self.r.insert(i, item)
                    if len(self.r) > self.size:
                        del self.r[i - 1]
                elif len(self.r) < self.size:
                    self.is_change = True
                    self.r.insert(0, item)
                return
        self.is_change = True
        self.r.append(item)
        if len(self.r) > self.size:
            del self.r[0]
    
    def show(self):
        for item in list(reversed(self.r)):
            print(item[0], item[1])

In [3]:
with open('lenta_words.txt', 'r', encoding='utf-8') as file:
    text = file.read().lower().split('\n')

In [4]:
class Tree:
    def __init__(self, key=None, is_word=False, prev=None):
        self.letters = {}
        self.key = key
        self.is_word = is_word
        self.amount = 0
    
    def __getitem__(self, key):
        return self.letters[key]
    
    def __setitem__(self, key, value):
        self.letters[key] = value

In [5]:
def word_insert(word, tree):
    supp = tree
    prefix_count = 0
    for let in word:
        if not re.findall(r'[a-zA-Zа-яА-ЯёЁ]', let):
            continue
        if let not in supp.letters:
            supp[let] = Tree(let, False, supp)
        supp = supp[let]
    supp.amount += 1
    tree.amount += 1
    supp.is_word = True

In [6]:
tree = Tree()
for i in tqdm(text):
    word_insert(i, tree) 

100%|██████████| 19232114/19232114 [02:17<00:00, 139915.45it/s]


In [7]:
def find_word(word, tree):
    supp = tree
    for i in word:
        supp = supp[i]
    return supp

In [62]:
def weight(word_true, word_current, freq=1, alpha1=0, alpha2=10):
    return -alpha1 * log2(freq) - log2(alpha2 ** (-distance(word_true, word_current)))

In [66]:
def find_true_word(word, tree, rating_size=5):
    result = [('', 0, tree, 0)]
    rating = Rating(size=rating_size)
    while rating.is_change:
        prefixes = set()
        for i in result:
            supp = i[2]
            for let in supp.letters:
                cur_prefix = i[0] + word[i[3]] if i[3] < len(word) else i[0]
                w = weight(i[0] + let, cur_prefix)
                if w + i[1] < 10 and i[3] - len(i[0]) < 3:
                    prefixes.add((i[0] + let, w + i[1], supp[let], i[3] + 2))
                    prefixes.add((i[0] + let, w + i[1],  supp[let], i[3] + 1))
                    prefixes.add((i[0] + let, w + i[1], supp[let], i[3]))
        result = sorted(list(prefixes), key=operator.itemgetter(1))[:500]
        rating.is_change = False
        for item in result:
            if item[2].is_word:
                rating.push((item[0], weight(item[0], word, item[2].amount / tree.amount, alpha1=0.2)))
    return list(reversed(rating.r))

In [67]:
bigrams_amount = len(list(bigrams(text)))
bigrams_count = Counter((bigrams(text)))
words_count = Counter(text)

In [172]:
def prob_sent(request):
    words = [item.text.lower() for item in tokenize(request) if item.text not in string.punctuation]
    bigr = list(bigrams(words))
    v = len(text)
    p = 1
    for i in bigr:
        p *= (bigrams_count[i] + 1) / (words_count[i[0]] + v)
    return p

In [169]:
request = 'путн оцинил роботу новвых самалетав и виртолтов в сирийи'

In [170]:
result = []
for word in request.lower().split(' '):
    result.append(find_true_word(word, tree, rating_size=200)[:3])

In [173]:
result

[[('путин', 5.680382706507288),
  ('пути', 6.036149768929111),
  ('путь', 6.193959314988637)],
 [('оценил', 6.42193981865721),
  ('оценили', 9.688196097554579),
  ('оценила', 10.087109320344453)],
 [('роботу', 4.072931498548263),
  ('работу', 5.715310407795874),
  ('робота', 6.6199657698523975)],
 [('новых', 5.777154942112299),
  ('новый', 8.97124896874595),
  ('новые', 9.09774345214872)],
 [('самолета', 9.188497749909462),
  ('самолетов', 9.223375276370653),
  ('самолетах', 9.912018339932114)],
 [('и', 1.130388844140091),
  ('в', 4.2338613068869115),
  ('с', 4.664363788300127)],
 [('вертолетов', 9.531587804611743),
  ('вертолётов', 11.29816580744469),
  ('вертолетом', 13.531505203162308)],
 [('в', 0.9119332119995489),
  ('и', 4.452316939027453),
  ('с', 4.664363788300127)],
 [('сирии', 5.961141882659725),
  ('серии', 9.220132016204655),
  ('сибири', 9.843873201863815)]]

In [171]:
probs_d = {}
for sent in tqdm(list(product(*result))):
    words = ' '.join([s[0] for s in sent])
    probs_d[words] = prob_sent(words)
print(max(probs_d, key=probs_d.get))

100%|██████████| 19683/19683 [00:00<00:00, 19755.94it/s]

путин оценил работу новых самолетов и вертолетов в сирии



