In [1]:
from collections import Counter, namedtuple
import json
import re

In [2]:
DATASET_DIR = './WebNews.json'
with open(DATASET_DIR, encoding = 'utf8') as f:
    dataset = json.load(f)

In [5]:
seg_list = list(map(lambda d: d['detailcontent'], dataset))
rule = re.compile(r"[^\u4e00-\u9fa5]")
seg_list = [rule.sub('', seg) for seg in seg_list]

In [6]:
def ngram(documents, N=2):
    ngram_prediction = dict()
    total_grams = list()
    words = list()
    Word = namedtuple('Word', ['word', 'prob'])

    for doc in documents:
        split_words = ['<s>'] + list(doc) + ['</s>']
        # 計算分子
        [total_grams.append(tuple(split_words[i:i+N])) for i in range(len(split_words)-N+1)]
        # 計算分母
        [words.append(tuple(split_words[i:i+N-1])) for i in range(len(split_words)-N+2)]
        
    total_word_counter = Counter(total_grams)
    word_counter = Counter(words)
    
    for key in total_word_counter:
        word = ''.join(key[:N-1])
        if word not in ngram_prediction:
            ngram_prediction.update({word: set()})
            
        next_word_prob = total_word_counter[key]/word_counter[key[:N-1]]
        w = Word(key[-1], '{:.3g}'.format(next_word_prob))
        ngram_prediction[word].add(w)
        
    return ngram_prediction

In [9]:
tri_prediction = ngram(seg_list, N=3)
for word, ng in tri_prediction.items():
    tri_prediction[word] = sorted(ng, key=lambda x: x.prob, reverse=True)

In [18]:
tri_prediction

{'<s>出': [Word(word='席', prob='1')],
 '出席': [Word(word='活', prob='0.23'),
  Word(word='桃', prob='0.151'),
  Word(word='</s>', prob='0.139'),
  Word(word='年', prob='0.0629'),
  Word(word='中', prob='0.0207'),
  Word(word='典', prob='0.019'),
  Word(word='國', prob='0.0141'),
  Word(word='大', prob='0.0132'),
  Word(word='八', prob='0.0124'),
  Word(word='新', prob='0.0116'),
  Word(word='財', prob='0.00993'),
  Word(word='記', prob='0.00993'),
  Word(word='龍', prob='0.00993'),
  Word(word='視', prob='0.00911'),
  Word(word='說', prob='0.00828'),
  Word(word='第', prob='0.00828'),
  Word(word='平', prob='0.00828'),
  Word(word='頒', prob='0.00745'),
  Word(word='開', prob='0.00662'),
  Word(word='市', prob='0.00579'),
  Word(word='台', prob='0.00579'),
  Word(word='蘆', prob='0.00497'),
  Word(word='觀', prob='0.00497'),
  Word(word='全', prob='0.00497'),
  Word(word='啟', prob='0.00414'),
  Word(word='表', prob='0.00414'),
  Word(word='農', prob='0.00414'),
  Word(word='南', prob='0.00414'),
  Word(word='音', 

In [16]:
text = '天下'
next_words = list(tri_prediction[text])[:5]
for next_word in next_words:
    print('next word: {}, probability: {}'.format(next_word.word, next_word.prob))
    
"""
next word: 仁, probability: 0.111
next word: 及, probability: 0.0741
next word: 地, probability: 0.0741
next word: 超, probability: 0.0741
next word: 文, probability: 0.0741
"""

next word: 雜, probability: 0.355
next word: 午, probability: 0.129
next word: 的, probability: 0.0968
next word: 第, probability: 0.0968
next word: 沒, probability: 0.0323


'\nnext word: 仁, probability: 0.111\nnext word: 及, probability: 0.0741\nnext word: 地, probability: 0.0741\nnext word: 超, probability: 0.0741\nnext word: 文, probability: 0.0741\n'