## 概率n阶语言模型测试
```
概率语言模型

处理N阶语言模型的统计信息
N-gram = 2，二元组 [x1, x2] 的统计信息
N-gram = 3，三元组 [x1, x2, x3] 的统计信息
... ...

TODO：N 阶统计信息向下兼容低阶统计（设计思路待补充）

TODO ngram 的问题：
    n 越大：
        需要存储的 n-gram 索引越多，内存吃紧；
        同时统计数据越稀疏，miss（未命中）率越高。
```

In [1]:
# 环境配置
%cd /playground/sgd_deep_learning/sgd_nlp/
import sys 
sys.path.append('./python')

/playground/sgd_deep_learning/sgd_nlp


In [2]:
from sgd_nlp.common import LanguageModel, DefaultToken
from math import log

In [3]:
# simple test code
def test_language_model(tokens_list, flag_test_all=True):
    """Smoke-test a trigram ``LanguageModel``: print each queried value, then assert it.

    Sections, each guarded by its own ``flag_test_*`` switch (or run
    unconditionally when ``flag_test_all`` is True):
      * dump the per-order n-gram counters,
      * show the ``(x)`` vs ``(x,)`` tuple quirk,
      * ``count_all_token()`` and ``count()``,
      * ``prob()`` and ``log_prob()``.

    NOTE: the expected values are hard-coded for the 4-sentence sample corpus
    built in the notebook's main cell; another ``tokens_list`` will generally
    fail the assertions.

    Args:
        tokens_list: sentences to build the model from, each a list of string tokens.
        flag_test_all: run every section (default; identical to the previous
            hard-coded behavior). When False, only sections whose own flag is
            True run (currently just the ``log_prob`` section).

    Raises:
        AssertionError: if a queried value differs from its expected value.
    """
    n_gram = 3   # model order; also the number of counters dumped below
    tol = 1e-5   # tolerance for floating-point comparisons
    lan_model = LanguageModel(tokens_list=tokens_list, N_gram=n_gram)

    flag_test_build_ngram_counters = False
    if flag_test_build_ngram_counters or flag_test_all:
        print('\n********** build ngram counters **********')
        # ngram_counter[k] holds counts keyed by (k+1)-tuples (unigrams, bigrams, trigrams).
        for order in range(n_gram):
            print(lan_model.ngram_counter[order])

    flag_test_tuple = False
    if flag_test_tuple or flag_test_all:
        # Python quirk: parentheses alone do not build a tuple -- (x) is just x,
        # whereas (x,) is a 1-tuple.
        print('\n********** python 稀奇的小问题 **********',
              "\ncase1:\t", (DefaultToken.unk_token), ' type is ', type((DefaultToken.unk_token)),
              "\ncase2:\t", (DefaultToken.unk_token,), ' type is ', type((DefaultToken.unk_token,)), )

    flag_test_count_all_token = False
    if flag_test_count_all_token or flag_test_all:
        print("\n********** test: count_all_token() **********",
              "\nall token num:\t", lan_model.count_all_token(),
              "\ntoken dict:\t", lan_model.ngram_counter[0])

    flag_test_count = False
    if flag_test_count or flag_test_all:
        bos_count = lan_model.count(DefaultToken.bos)
        print("\n********** test: count()_1 **********",
              "\ncount bos_token:\t", bos_count, )
        assert bos_count == 8

        abc_count = lan_model.count(['a', 'b', 'c'])
        print("\n********** test: count()_2 **********",
              "\ncount (a, b, c):\t", abc_count, )
        assert abc_count == 1

    flag_test_prob = False
    if flag_test_prob or flag_test_all:
        prob_c_given_ab = lan_model.prob(['a', 'b', 'c'])
        print("\n********** test: prob()_1 **********",
              "\nprob c|(a, b):\t", prob_c_given_ab, )
        # (a, b) is followed once by c and once by end-of-sentence -> 1/2
        assert abs(prob_c_given_ab - 0.5) < tol

        prob_f_given_e = lan_model.prob(['e', 'f'])
        print("\n********** test: prob()_2 **********",
              "\nprob f|e:\t", prob_f_given_e, )
        # neither 'e' nor 'f' occurs in the sample corpus
        assert abs(prob_f_given_e - 0) < tol

        prob_b = lan_model.prob('b')
        count_b = lan_model.count('b')
        total_tokens = lan_model.count_all_token()
        print("\n********** test: prob()_3 **********",
              "\nprob b:\t", prob_b,
              "\ncount b:\t", count_b,
              "\ncount all token:\t", total_tokens,
              "\nvalid prob:\t", count_b / total_tokens, )
        # Pinned to 2 / 14: count('b') == 2 and count_all_token() == 14 for the sample corpus.
        # NOTE(review): 'b' appears 3 times in the raw tokens and the unigram counter
        # omits 'd'; this value pins current LanguageModel counting -- confirm intended.
        assert abs(prob_b - 2 / 14) < tol

    flag_test_log_prob = True
    if flag_test_log_prob or flag_test_all:
        expected_log_prob = log(0.5)
        log_prob_c_given_ab = lan_model.log_prob(['a', 'b', 'c'])
        print("\n********** test: log_prob()_1 **********",
              "\nlog_prob c|(a, b):\t", log_prob_c_given_ab,
              "\nvalid log prob:\t", expected_log_prob, )
        # Compare against log(0.5) itself (same case as prob()_1) rather than a
        # hand-typed decimal approximation.
        assert abs(log_prob_c_given_ab - expected_log_prob) < tol

In [4]:
# simple test code main loop: build a tiny corpus, tokenize it, run the tests.
doc = "a b c d\n c b a \n a b\n a"

# One token list per line; no-arg split() already discards surrounding spaces.
tokens_list = [sentence.split() for sentence in doc.split('\n')]

# Flattened view of the whole corpus, printed for reference.
tokens = [word for sentence_tokens in tokens_list for word in sentence_tokens]

print(tokens_list)
print(tokens)

test_language_model(tokens_list)

[['a', 'b', 'c', 'd'], ['c', 'b', 'a'], ['a', 'b'], ['a']]
['a', 'b', 'c', 'd', 'c', 'b', 'a', 'a', 'b', 'a']

********** build ngram counters **********
Counter({('\x01',): 8, ('a',): 2, ('b',): 2, ('c',): 2})
Counter({('\x01', '\x01'): 4, ('\x01', 'a'): 3, ('a', 'b'): 2, ('b', 'c'): 1, ('c', 'd'): 1, ('\x01', 'c'): 1, ('c', 'b'): 1, ('b', 'a'): 1})
Counter({('\x01', '\x01', 'a'): 3, ('\x01', 'a', 'b'): 2, ('a', 'b', 'c'): 1, ('b', 'c', 'd'): 1, ('c', 'd', '\x02'): 1, ('\x01', '\x01', 'c'): 1, ('\x01', 'c', 'b'): 1, ('c', 'b', 'a'): 1, ('b', 'a', '\x02'): 1, ('a', 'b', '\x02'): 1, ('\x01', 'a', '\x02'): 1})

********** python 稀奇的小问题 ********** 
case1:	   type is  <class 'str'> 
case2:	 ('\x04',)  type is  <class 'tuple'>

********** test: count_all_token() ********** 
all token num:	 14 
token dict:	 Counter({('\x01',): 8, ('a',): 2, ('b',): 2, ('c',): 2})

********** test: count()_1 ********** 
count bos_token:	 8

********** test: count()_2 ********** 
count (a, b, c):	 1

********