In [91]:
import numpy as np
import os
import math

# define N-gram
N = 3
# define hyperparameter of knears smoothing
d = 0.75
lambda_unk = 0.05
def _total_num_words(data):
    num = 0
    for _, seq in enumerate(data):
        words = seq.split(" ")
        num += len(words)
    return num
    
# function of read data
def _read_txt_(url):
    file = open(url, 'r', encoding='utf-8')
    seg_list = []
    lines = file.readlines()
    for i in range(len(lines)):
        seg_list.append("<s> " + lines[i].rstrip("\n") + " </s>")
        pass
    file.close()
    return seg_list

train_list = _read_txt_("wiki-en-train.word")

# basic function to count n-gram (windows) frequency
def _count_N_(data, n=1):
    tmp_dict = {}
    for _, seq in enumerate(data):
        words = seq.split(" ")
        for i in range(len(words)):
            if (i + n) >= len(words) + 1: continue
            windows = words[i:i+n]
            combination = ""
            for j in range(n):
                if j == 0:
                    combination += windows[j]
                else:
                    combination += "|" + windows[j]
            if combination not in tmp_dict:
                tmp_dict[combination] = 1
            else:
                num = tmp_dict[combination]
                tmp_dict[combination] = num + 1
    np.save("{}_count_dict".format(str(n)), tmp_dict)
    return tmp_dict

# count unique word frequency
count_set = {}
for i in range(N):
    n = i + 1
    file_name = "{}_count_dict.npy".format(str(n))
    if os.path.exists(file_name):
        count_set[str(n)] = np.load(file_name,allow_pickle=True).item()
    else:
        count_set[str(n)] = _count_N_(train_list,n)

        
# Witten-Bell Smoothing to confirm the lambda of each element
lambda_dict = {}
file_name = "{}_lambda_dict.npy".format(str(N))
if os.path.exists(file_name):
    lambda_dict = np.load(file_name,allow_pickle=True).item()
else:
    tmp_count = list(count_set[str(N-1)].keys())
    tmp_count_current = list(count_set[str(N)].keys())
    unique_set = {}
    for _, x in enumerate(tmp_count):
        num = 0
        for _,y in enumerate(tmp_count_current):
            if y.startswith(x+"|"):
                num += 1
        unique_set[x] = num
    for idx, x in enumerate(tmp_count):
        lambda_dict[x] = 1 - (unique_set[x]/(unique_set[x] + count_set[str(N-1)][x]))
    np.save(file_name, lambda_dict)

# p_continuation
# key: N-1 gram ; Value: list[N-gram with same N-1 gram]
count_continuation = {}
file_name = "{}_continuation_dict.npy".format(str(N))
if os.path.exists(file_name):
    count_continuation = np.load(file_name,allow_pickle=True).item()
else:
    tmp_count = list(count_set[str(N-1)].keys())
    tmp_count_current = list(count_set[str(N)].keys())
    for idx, x in enumerate(tmp_count):
        list_tmp = []
        for _, y in enumerate(tmp_count_current):
            if y.startswith(x+"|"):
                list_tmp.append(y)
        count_continuation[x] = list_tmp
    np.save("{}_continuation_dict.npy".format(str(N)), count_continuation)

# P_kn = p_kn + P_continuation
# p_kn = max(count(N_gram) - d, 0)/count(N-1_gram)
# P_continuation = lambda * p_continuation
# lambda: d/ count(N-1_gram) * count(N_gram)
# p_continuation: count(N_gram) / sum of count(N_gram) with same N-1 gram, that is sum (vw') by w'
def _compute_p_kn_(words):
    
    prefix = words[0]
    for idx, x in enumerate(words[1:-1]):
        prefix =  prefix + "|" + x
    windows = prefix + "|" + words[len(words) - 1]
    tmp_count = count_set[str(N-1)]
    tmp_count_current = count_set[str(N)]
    same_prefix_list = count_continuation[prefix]
    
    value = tmp_count_current[windows] - d
    denominator = 0
    for x in same_prefix_list:
        denominator = denominator + tmp_count_current[x]
    numerator = tmp_count_current[windows]
    if numerator > 0:
        return value/tmp_count[prefix] + (d / tmp_count[prefix]) * tmp_count_current[windows] * (numerator/denominator)
    else:
        return (d / tmp_count[prefix]) * tmp_count_current[windows] * (numerator/denominator)
def _p_(sentence, n=1, T=1):
    words = sentence.split(" ")
    keys_count_current = list(count_set[str(N)].keys())
    lambda_unk = 0.05
    p = 0
    for i in range(len(words)):
#         print(p)
        if (i + n) >= len(words) + 1: continue
        windows = words[i:i+n]
        combination = ""
        for j in range(n):
            if j == 0:
                combination += windows[j]
            else:
                combination += "|" + windows[j]
        if combination not in keys_count_current:
            # this word is unknown
            if p == 0:
                    # the first window
                p = lambda_unk * (1 / T)
            else:
                p = p * lambda_unk * (1 / T)
        else:
            if p == 0:
                p = _compute_p_kn_(windows)
            else:
                p = p * _compute_p_kn_(windows)
    return -math.log(p)

test_list = _read_txt_("wiki-en-test.word")
P_total = 0
T_test = _total_num_words(test_list)
for idx, x in enumerate(test_list):
    p = _p_(x, N, T_test)
    if p > 500:
        print(idx)
    P_total = P_total + p
print(P_total/T_test)

3
7
10
11
13
38
46
74
113
139
145
169
9.491503869427241
