In [None]:
import os
import pickle
import re

from collections import Counter
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.util import ngrams
from tqdm.auto import tqdm

In [None]:
# --- Input / output locations ---
# Text file holding the whole corpus; read in one go and split into documents.
CORPUS_PATH = "path_to_text_corpus"
# Existing directory that receives one "<n>grams.pkl" file per n-gram length.
MODEL_PATH = "path_to_output_model"
# String that separates consecutive documents inside the corpus file.
CORPUS_DOCS_SEP = "\n\n"

# --- Tokenization options ---
# Lower-case every document before it is split into tokens.
LOWERCASE = True
# Language name handed to the NLTK stop-word list and the Snowball stemmer.
LANGUAGE = "english"
# Drop tokens found in the NLTK stop-word list for LANGUAGE.
REMOVE_STOPWORDS = False
# Replace every token by its Snowball stem.
STEM = False

# --- Model options ---
# N-grams of every length from 2 up to and including this value are collected.
MAX_NGRAM_LEN = 3
# Maximum number of candidate words kept for each context.
PROB_WORDS_LIMIT = 100

In [None]:
# Fail fast: the resulting tables are pickled into MODEL_PATH at the very end,
# so a missing output directory should surface before the corpus is processed.
model_dir_exists = os.path.isdir(MODEL_PATH)
if not model_dir_exists:
    raise FileNotFoundError(MODEL_PATH)

In [None]:
%%time

# Read the entire corpus at once and cut it into documents at every
# occurrence of CORPUS_DOCS_SEP.
with open(CORPUS_PATH, encoding="utf-8") as corpus_file:
    docs = corpus_file.read().split(CORPUS_DOCS_SEP)

In [None]:
# Replace this with a more sophisticated tokenizer if your use case requires one.

# Optional NLTK resources are created only when the matching flag is enabled;
# tokenize() reads them only under the same flags.
if REMOVE_STOPWORDS:
    # A set gives constant-time membership tests during filtering.
    STOPWORDS = {*stopwords.words(LANGUAGE)}

if STEM:
    STEMMER = SnowballStemmer(language=LANGUAGE)

def tokenize(doc):
    """Split a document into a list of word tokens.

    Behaviour is driven by the notebook-level flags:

    * ``LOWERCASE``        - lower-case the document before splitting.
    * ``REMOVE_STOPWORDS`` - drop tokens contained in ``STOPWORDS``.
    * ``STEM``             - replace each token by ``STEMMER.stem(token)``.

    Args:
        doc: The document text (``str``).

    Returns:
        A list of non-empty tokens, each a maximal run of word characters
        (``\\w``). An empty or punctuation-only document yields ``[]``.
    """
    if LOWERCASE:
        doc = doc.lower()

    # re.split(r"\W+", ...) emits an empty string whenever the text starts or
    # ends with a non-word character (and for an empty document). Those ""
    # tokens would then be counted as part of n-grams. Matching the word runs
    # directly returns the same tokens without the empty ones.
    tokens = re.findall(r"\w+", doc)

    if REMOVE_STOPWORDS:
        # Compare case-insensitively so that capitalised stop words ("The")
        # are also removed when LOWERCASE is disabled; this assumes the
        # entries of STOPWORDS are lower-case, as NLTK's lists are.
        tokens = [token for token in tokens if token.lower() not in STOPWORDS]

    if STEM:
        tokens = [STEMMER.stem(token) for token in tokens]

    return tokens

In [None]:
# An n-gram must occur at least this often to be kept:
# len(docs) // 100000, but never fewer than 5 occurrences.
min_ngram_freq = max(5, len(docs) // 10**5)

# Maps n-gram length -> list of (ngram, count) pairs that passed the threshold.
ngrams_stat = {}

for ngram_len in range(2, MAX_NGRAM_LEN + 1):
    ngrams_stat[ngram_len] = Counter()

    # Count every n-gram of the current length across the whole corpus.
    for doc in tqdm(docs):
        ngrams_stat[ngram_len].update(ngrams(tokenize(doc), ngram_len))

    # Replace the full counter by the frequent entries only, so the rare
    # n-grams are released before the next length is counted.
    ngrams_stat[ngram_len] = [
        (ngram, cnt)
        for ngram, cnt in ngrams_stat[ngram_len].items()
        if cnt >= min_ngram_freq
    ]

# The raw documents are not used after this point.
del docs

In [None]:
# For every n-gram length build and save a lookup of candidate words for a
# masked position:
#     prob_words[mask_pos][context_tokens] -> {word: relative frequency}
# where context_tokens is the n-gram with the token at mask_pos removed.
for ngram_len in range(2, MAX_NGRAM_LEN+1):
    # Pre-create one bucket per mask position. Without this, an empty
    # ngrams_stat[ngram_len] (no n-gram reached min_ngram_freq) left
    # prob_words empty and the post-processing loop below raised KeyError.
    prob_words = {mask_pos: {} for mask_pos in range(ngram_len)}

    for ngram, cnt in tqdm(ngrams_stat[ngram_len]):
        for mask_pos in range(ngram_len):
            context_tokens = tuple(token for i, token in enumerate(ngram) if i != mask_pos)
            prob_words[mask_pos].setdefault(context_tokens, {})[ngram[mask_pos]] = cnt

    for mask_pos in range(ngram_len):
        # Keep only contexts that have more than one candidate word.
        prob_words[mask_pos] = {k: v for k, v in prob_words[mask_pos].items() if len(v) > 1}

        for context_tokens in tqdm(prob_words[mask_pos]):
            words = prob_words[mask_pos][context_tokens]

            # The denominator covers all candidates and is computed before
            # truncation, so the kept values can sum to less than 1.
            cnt_sum = sum(words.values())

            # Most frequent candidates first; the sort is stable, so ties keep
            # their insertion order.
            top_words = sorted(words.items(), key=lambda kv: kv[1], reverse=True)[:PROB_WORDS_LIMIT]

            # Only the value of an existing key is replaced, which is safe
            # while iterating over the dictionary.
            prob_words[mask_pos][context_tokens] = {word: word_cnt / cnt_sum for word, word_cnt in top_words}

    with open(os.path.join(MODEL_PATH, f"{ngram_len}grams.pkl"), "wb") as fh:
        pickle.dump(prob_words, fh)