In [154]:
import json
import pypinyin
import numpy as np
from gensim.models import Word2Vec

## Load

In [155]:
print('Loading...')

characterFrequency = json.load(open('singleCharacterFrequency.json', 'r'))
characters = list(characterFrequency.keys())
char_finals = [ pypinyin.pinyin(char, style=pypinyin.Style.FINALS)[0][0] for char in characters]
char_initials = [ pypinyin.pinyin(char, style=pypinyin.Style.INITIALS)[0][0] for char in characters]
tupleCharacterFrequency = json.load(open('tupleCharacterFrequency.json', 'r'))
characterFrequency[''] = 0
for key in characterFrequency:
    characterFrequency[''] += characterFrequency[key]
characterFrequency.update(tupleCharacterFrequency)

model = Word2Vec.load('word2vec_model')

print('Loading complete!')

Loading dictionary...
Loading complete!


In [196]:
np.cos(model.wv['你'])

array([ 0.98771191,  0.75084817,  0.80930418,  0.98708963,  0.99385154,
        0.51219028,  0.44752839,  0.99900562,  0.62777478,  0.58852053,
       -0.6615302 , -0.99959993,  0.90712661,  0.3797313 , -0.10601971,
        0.95225739, -0.27716762,  0.6210534 ,  0.59067506,  0.30496737,
        0.23779927,  0.9789629 ,  0.9394905 ,  0.1825501 , -0.22118084,
        0.98841017,  0.99573839,  0.99968958,  0.73323852, -0.52361488,
        0.842529  ,  0.95947987, -0.09842757,  0.40489838,  0.3452858 ,
        0.30990127,  0.71232855,  0.68370593,  0.99010342,  0.9936341 ,
        0.99388856,  0.97015762,  0.70044303,  0.40831292,  0.8876943 ,
        0.99680346, -0.09522749,  0.71229827,  0.34720078, -0.04126502,
        0.60583615,  0.43080217,  0.94020194,  0.98551464,  0.69989645,
        0.48613837,  0.99932337,  0.8640241 ,  0.07694994,  0.97163868,
        0.29716748,  0.54222894,  0.81688821,  0.22546697, -0.04843049,
        0.13188387,  0.77423269,  0.71766788, -0.06749327,  0.51

In [201]:
def similarity(char_x, char_y):
    try:
        return model.similarity(char_x, char_y)
    except Exception:
        return 0
    
def similarity_to_input(line):
    return [ np.sum([similarity(char, line_char) for line_char in line]) for char in characters]

In [202]:
def predict_phrase(line):
    '''Predict the most frequent phrases that satisfiy the rhythm
    
    line: the previous line, ex: ['i', 'ao']
    num_view: the number of phrases that returned
    '''
    finals = [li[0] for li in pypinyin.pinyin(line, style=pypinyin.Style.FINALS)]
    initials = [li[0] for li in pypinyin.pinyin(line, style=pypinyin.Style.INITIALS)]
    char_similarity = similarity_to_input(line)
    print(char_similarity)
    
    phrase_length = len(finals)
    vocab_size = len(characters)
    
    probability = np.zeros([phrase_length, vocab_size], dtype='float')
    path = np.zeros([phrase_length, vocab_size], dtype='int') - 1
    for idx, char in enumerate(characters):
        if char_finals[idx] == finals[0] and char_initials[idx] != initials[0]:
            probability[0][idx] = characterFrequency[char] * char_similarity[idx]

    for k in range(1, phrase_length):
        for idx, char in enumerate(characters):
            if char_finals[idx] == finals[k] and char_initials[idx] != initials[k]:
                with_prev_freq = np.array(\
                        [ float(characterFrequency.get(prev + char, 0)) / characterFrequency.get(prev, 1) * char_similarity[idx] for prev in characters ])
                
                probability[k][idx] = np.max(probability[k - 1] * with_prev_freq)
                if probability[k][idx] > 0:
                    path[k][idx] = np.argmax(probability[k - 1] * with_prev_freq)
    
    def path2phrase(k, idx):
        phrase = ''
        while k >= 0:
            if idx == -1: return None
            phrase = characters[idx] + phrase
            idx = path[k][idx]
            k -= 1
        return phrase
    
    return [path2phrase(phrase_length - 1, idx)\
            for idx in np.argsort(probability[phrase_length - 1])[::-1]\
            if probability[phrase_length - 1][idx] > 0]

In [203]:
while True:
    line = input("Please input pinyin:\n")
    result = predict_phrase(line)

    print(result)

Please input pinyin:
你好
[0.053554769845141556, -0.17663253776686438, 0.066475642717516867, 0, 0.24480652124904526, 0, 0.018865798127344457, 0, 0, 0, 0, 0, -0.22905830748575759, 0, 0, 0, 0, 0, 0, -0.23369684441659988, -0.20314029450644819, 0, 0.16839569530814008, 0, 0, -0.15928423099966346, 0, 0, 0, -0.11631225497836742, 0.08736440103080724, 0, -0.0707613424328771, 0, 0, 0, 0, 0, -0.059225087386132531, 0.064196683632551654, -0.10131667574565537, 0, 0, 0, 0, 0.072652224899237172, 0, 0, -0.035566678791055989, 0, 0, -0.26839295484988063, 0, 0.12166221686679769, 0, 0, 0, -0.10684310081891063, -0.12372573989732064, 0, 0.062211117874163703, 0, -0.16703757281293402, 0, 0, 0, 0, 0, -0.089390211389770519, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0.046842610137105181, 0, 0, -0.072906212242303203, -0.14891053604801668, 0, -0.03009396231347336, 0, 0, 0.016751359436276609, 0, 0.18779928979291327, 0, 0.085922199510583802, -0.26642194703653099, -0.021344299674394857, 0, 0, 0, -0.16669192800619659, -

TypeError: 'list' object is not callable