# Doc2Vecの動作確認（一部品詞を落として検証）

マイオペのテストデータを使用した結果では、あまり性能が出ないようなので、急きょ accuracy を調査いたしました。

結果、accuracy は 0.071 (サンプル16,288件中、正解1,152件) と、かなり低い正解率であることを確認しました。

## (1) テストデータ／環境準備

In [1]:
'''
    Prepare the test environment: switch the working directory to the
    project's `learning` directory and make sure it is importable.
'''
import sys
import os
# Two levels above the notebook's working directory; the author expects this
# to resolve to donusagi-bot/learning.
learning_dir = os.path.abspath("../../") #<--- donusagi-bot/learning
os.chdir(learning_dir)

# Append only once so re-running the cell does not duplicate the entry.
if learning_dir not in sys.path:
    sys.path.append(learning_dir)

## (2) Doc2Vecの動作確認

### (2-1) コーパス生成

コーパス（単語が半角スペースで区切られた文字列）生成時、一部の品詞を落とすようにします。

（＝learning.core.nlang.Nlang クラスの仕様に従います）

In [2]:
import numpy as np

from learning.core.learn.learning_parameter import LearningParameter
from learning.core.datasource import Datasource

_bot_id = 9  # bot_id 9 is Septeni
# Settings handed to LearningParameter. In the cells visible here the
# resulting learning_parameter object is not referenced again.
attr = {
    'include_failed_data': False,
    'include_tag_vector': False,
    'classify_threshold': 0.5,
    'algorithm': LearningParameter.ALGORITHM_LOGISTIC_REGRESSION,
    'params_for_algorithm': {'C': 140},
    'excluded_labels_for_fitting': None
}

learning_parameter = LearningParameter(attr)

In [3]:
# Load the training messages for the selected bot from a CSV-type datasource.
_datasource = Datasource(type='csv')
learning_training_messages = _datasource.learning_training_messages(_bot_id)
# Pull the question text and its answer ID out as two arrays; later cells
# pair them up by position.
questions = np.array(learning_training_messages['question'])
answer_ids = np.array(learning_training_messages['answer_id'])

2017/05/17 PM 02:54:32 ['./fixtures/learning_training_messages/benefitone.csv', './fixtures/learning_training_messages/ptna.csv', './fixtures/learning_training_messages/septeni.csv', './fixtures/learning_training_messages/toyotsu_human.csv']
2017/05/17 PM 02:54:32 ['./fixtures/question_answers/toyotsu_human.csv']


In [4]:
from learning.core.nlang import Nlang

In [5]:
# Tokenise every question with Nlang.batch_split. According to the notebook
# text, the result is one space-separated word string per question with some
# parts of speech dropped (per the Nlang class specification).
_sentences = np.array(questions)
_separated_sentences = Nlang.batch_split(_sentences)

### (2-2) コーパスにタグ付け

models.doc2vecの仕様に従います。

In [6]:
from gensim import models
from gensim.models.doc2vec import Doc2Vec
from gensim.models.doc2vec import TaggedDocument

In [7]:
def doc_to_sentence(sentences, name):
    """Wrap one tokenised sentence in a gensim ``TaggedDocument``.

    Args:
        sentences: A single whitespace-separated string of words
            (despite the plural name, this is one sentence).
        name: The tag for the document; callers pass the answer ID.

    Returns:
        A ``TaggedDocument`` whose ``words`` are the tokens of ``sentences``
        and whose ``tags`` is the one-element list ``[name]``.
    """
    # split() without an argument collapses runs of whitespace and never
    # yields empty-string tokens. This matches how predict() and
    # predict_similarity() tokenise at inference time, so training and
    # inference see the same vocabulary.
    words = sentences.split()
    return TaggedDocument(words=words, tags=[name])

def corpus_to_sentences(separated_sentences, answer_ids):
    """Lazily yield one tagged document per (sentence, answer ID) pair.

    The two inputs are walked in lockstep and iteration stops as soon as
    either one is exhausted. Each pair is converted by ``doc_to_sentence``.
    """
    yield from map(doc_to_sentence, separated_sentences, answer_ids)

### (2-3) 学習処理／モデルのシリアライズ

In [8]:
# Lazy generator of tagged documents, one per training question, each tagged
# with that question's answer ID.
sentences = corpus_to_sentences(_separated_sentences, answer_ids)

In [9]:
# Materialise the generator: the list is iterated twice below (once by
# build_vocab, once by train), which a generator could not support.
sentence_list = list(sentences)

In [10]:
# Untrained Doc2Vec model. Presumably size = vector dimensionality,
# min_count=1 keeps words seen only once, iter = number of training
# iterations — TODO confirm against the gensim release in use.
model = Doc2Vec(size=500, min_count=1, iter=200)

In [11]:
# Build the model vocabulary from the tagged training documents.
model.build_vocab(sentence_list)

In [12]:
# Train on the tagged documents; the value printed under the cell is what
# train() returned.
# NOTE(review): later gensim releases appear to require explicit
# total_examples / epochs arguments here — confirm before upgrading.
model.train(sentence_list)

13682673

In [13]:
# Location of the serialized model, relative to the working directory set in
# the first cell.
model_path = 'prototype/better_algorithm/doc2vec.model'

# Persist the trained model so the prediction cells can reload it from disk.
model.save(model_path)

In [14]:
'''
    Get the number of document vectors held by the model.
    (Because the tags are answer IDs, this returns the number of
     distinct labels. The author's interpretation is that the vectors of
     samples sharing an answer ID all overwrite one another.)

    NOTE(review): it may instead be that samples with the same tag jointly
    train a single shared vector — confirm against the gensim documentation.
'''
len(model.docvecs)

173

In [15]:
# Display the container that holds the trained document vectors.
model.docvecs

<gensim.models.doc2vec.DocvecsArray at 0x10b7a57f0>

### (2-4) 予測処理

In [16]:
# Same path the model was saved to above; repeated so this cell can be run
# on its own.
model_path = 'prototype/better_algorithm/doc2vec.model'

def predict(word, model_path):
    """Tokenise a raw question and look up its most similar answer tags.

    Args:
        word: Raw (untokenised) question text.
        model_path: Path of a saved Doc2Vec model; it is loaded from disk on
            every call.

    Returns:
        A ``(tokens, neighbours)`` tuple: the token list that was fed to the
        model, and the result of ``docvecs.most_similar`` for the inferred
        vector (in the notebook output, a list of ``(answer_id, similarity)``
        pairs).
    """
    # Tokenise with Nlang, which the author describes as the same function
    # used when the training set was built.
    tokens = Nlang.split(word).split()

    # Infer a vector for the tokens and search the loaded model for the
    # document vectors closest to it.
    doc2vec = models.Doc2Vec.load(model_path)
    query_vector = doc2vec.infer_vector(tokens)
    neighbours = doc2vec.docvecs.most_similar([query_vector])

    return tokens, neighbours

In [17]:
'''
    Query: "broken mouse" (expected answer ID = 4458)
'''
predict('マウス破損', model_path)

(['マウス', '破損'],
 [(4458, 0.7610461711883545),
  (4459, 0.6904299259185791),
  (4530, 0.6668727397918701),
  (7068, 0.5458143949508667),
  (4566, 0.536504864692688),
  (4581, 0.5306761264801025),
  (4620, 0.5201034545898438),
  (4531, 0.5128851532936096),
  (4541, 0.503987193107605),
  (4599, 0.5033320188522339)])

In [18]:
'''
    Query: "I want to use the wireless network" (expected answer ID = 4516)
'''
predict('無線を使用したい', model_path)

(['無線', '使用', 'する'],
 [(4458, 0.6817476153373718),
  (7042, 0.6709047555923462),
  (4548, 0.6647758483886719),
  (7044, 0.6387978196144104),
  (7068, 0.6331000328063965),
  (4494, 0.6238508820533752),
  (4566, 0.6182248592376709),
  (4516, 0.6084833145141602),
  (4521, 0.597450852394104),
  (4627, 0.5923368334770203)])

In [19]:
'''
    Query: "address of the information systems department"
    (expected answer ID = 7040)
'''
predict('情報システムのアドレス', model_path)

(['情報', 'システム', 'アドレス'],
 [(7065, 0.5496551990509033),
  (7040, 0.4953303933143616),
  (4565, 0.48816725611686707),
  (7037, 0.48583632707595825),
  (7084, 0.4851142168045044),
  (4498, 0.48511412739753723),
  (4489, 0.48499345779418945),
  (4620, 0.4819757044315338),
  (4432, 0.4768473505973816),
  (4523, 0.47387853264808655)])

In [20]:
'''
    Query: "I want to use the mis-sending prevention system"
    (expected answer ID = 4432)
'''
predict('誤送信防止システムを使いたい', model_path)

(['誤る', '送信', '防止', 'システム', '使う'],
 [(4458, 0.5443743467330933),
  (4472, 0.542833685874939),
  (4603, 0.5288375616073608),
  (4525, 0.524386465549469),
  (4566, 0.5055374503135681),
  (4450, 0.4928273856639862),
  (4565, 0.49246060848236084),
  (4569, 0.4910944104194641),
  (4511, 0.4873133599758148),
  (4598, 0.4871697425842285)])

In [21]:
'''
    Query: "I'd like to use Cybozu from my mobile phone; how can I do that?"
    (expected answer ID = 4504)
'''
predict('携帯からサイボウズを使いたいのですが、どうしたら出来ますか？', model_path)

(['携帯', 'サイボウズ', '使う', 'どう', 'する', '出来る'],
 [(4611, 0.667762815952301),
  (4459, 0.6388451457023621),
  (4549, 0.6275416612625122),
  (4590, 0.5874701142311096),
  (4598, 0.5784797668457031),
  (4601, 0.5705738663673401),
  (4523, 0.569864809513092),
  (7037, 0.5617852210998535),
  (4507, 0.5613116025924683),
  (4625, 0.5555490851402283)])

## (3) accuracy 測定

ここからが本題です。

In [22]:
# Path of the serialized Doc2Vec model used for the accuracy measurement.
model_path = 'prototype/better_algorithm/doc2vec.model'

# Number of tokenised training sentences that will be evaluated.
len(_separated_sentences)

16288

In [23]:
def predict_similarity(separated_sentence, model_path):
    """Return the single best-matching answer tag for a tokenised sentence.

    Args:
        separated_sentence: One whitespace-separated string of words.
        model_path: Path of a saved Doc2Vec model; it is loaded from disk on
            every call.

    Returns:
        A ``(tokens, answer_id, similarity)`` tuple, where ``tokens`` is the
        word list given to the model and the other two values come from the
        first entry returned by ``docvecs.most_similar``.
    """
    tokens = separated_sentence.split()
    doc2vec = models.Doc2Vec.load(model_path)
    neighbours = doc2vec.docvecs.most_similar([doc2vec.infer_vector(tokens)])

    # Only the top-ranked neighbour is reported.
    best_answer_id, best_similarity = neighbours[0]
    return tokens, best_answer_id, best_similarity

In [24]:
def get_prediction_statistics(separated_sentences, model_path):
    """Run every training sentence back through the model and record the outcome.

    The training questions themselves are used as prediction input, so the
    result measures how well the model recalls its own training labels.

    Args:
        separated_sentences: Sequence of whitespace-separated word strings.
            Entry ``i`` is compared against the module-level
            ``answer_ids[i]``, so the two must be aligned by position.
        model_path: Path of the saved Doc2Vec model, passed through to
            ``predict_similarity``.

    Returns:
        A list with one tuple per sentence:
        ``(index, word_count, expected_answer_id, predicted_answer_id,
        similarity)``.

    Raises:
        IndexError: If ``separated_sentences`` is longer than ``answer_ids``.
    """
    # NOTE(review): predict_similarity reloads the model from disk on every
    # call. Loading once here would be faster, but might change the results
    # if inference depends on state restored by the load — confirm first.
    statistics = []
    # Iterate the argument itself, not the notebook-global
    # _separated_sentences, so the function evaluates what it was given.
    for i, sentence in enumerate(separated_sentences):
        preferred_answer_id = answer_ids[i]
        corpus, answer_id, similarity = predict_similarity(sentence, model_path)
        statistics.append(
            (i, len(corpus), preferred_answer_id, answer_id, similarity)
        )

    return statistics

In [25]:
# Predict an answer for every tokenised training sentence; each call to
# predict_similarity loads the model from disk.
prediction_statistics = get_prediction_statistics(_separated_sentences, model_path)

In [26]:
# Tallies keyed by the number of words in the question.
ncorrect_by_corpus_len = {}
nsample_by_corpus_len = {}

# Tallies keyed by the expected answer ID; corpus_len_by_answer_id holds the
# summed word counts so a later cell can derive the average.
ncorrect_by_answer_id = {}
nsample_by_answer_id = {}
corpus_len_by_answer_id = {}

# Overall totals.
ncorrect = 0
nsample = 0

# Aggregate the prediction results per question word count and per answer ID.
for statistics in prediction_statistics:
    i, corpus_len, preferred_answer_id, answer_id, similarity = statistics

    # Per word count: the "correct" counter is created together with the
    # "sample" counter so that both dicts always share the same keys.
    ncorrect_by_corpus_len.setdefault(corpus_len, 0)
    nsample_by_corpus_len.setdefault(corpus_len, 0)
    nsample_by_corpus_len[corpus_len] += 1

    # Per answer ID: same scheme, plus the running total of word counts.
    ncorrect_by_answer_id.setdefault(preferred_answer_id, 0)
    nsample_by_answer_id.setdefault(preferred_answer_id, 0)
    corpus_len_by_answer_id.setdefault(preferred_answer_id, 0)
    nsample_by_answer_id[preferred_answer_id] += 1
    corpus_len_by_answer_id[preferred_answer_id] += corpus_len

    # Count the sample, and count it as correct when the top prediction
    # equals the expected answer ID.
    nsample += 1
    if preferred_answer_id == answer_id:
        ncorrect += 1
        ncorrect_by_corpus_len[corpus_len] += 1
        ncorrect_by_answer_id[preferred_answer_id] += 1

In [27]:
# Assemble the per-word-count statistics as
# (word_count, accuracy, n_correct, n_samples) tuples.
info_by_corpus_len = []
for k, v in ncorrect_by_corpus_len.items():
    # v is the number of correct predictions for word count k.
    info_by_corpus_len.append(
        (k, v / nsample_by_corpus_len[k], v, nsample_by_corpus_len[k])
    )

In [28]:
# Assemble the per-answer-ID statistics as
# (answer_id, accuracy, n_correct, n_samples, average_word_count) tuples.
info_by_answer_id = []
for k, v in ncorrect_by_answer_id.items():
    # v is the number of correct predictions for answer ID k.
    info_by_answer_id.append(
        (
            k,
            v / nsample_by_answer_id[k],
            v,
            nsample_by_answer_id[k],
            corpus_len_by_answer_id[k] / nsample_by_answer_id[k],
        )
    )

In [29]:
'''
    Overall accuracy: correct predictions divided by the number of samples.
'''
print("accuracy=%0.3f (%d/%d)" % (
    ncorrect/nsample, ncorrect, nsample))

accuracy=0.071 (1152/16288)


In [30]:
'''
    List the accuracy for each question word count.
'''
# Each info tuple is (word_count, accuracy, n_correct, n_samples).
for info in info_by_corpus_len:
    print("word_count=%2d: accuracy=%0.3f (%d/%d)" % (
        info[0], info[1], info[2], info[3]
    ))

word_count= 1: accuracy=0.273 (6/22)
word_count= 2: accuracy=0.121 (156/1292)
word_count= 3: accuracy=0.085 (310/3626)
word_count= 4: accuracy=0.071 (306/4287)
word_count= 5: accuracy=0.063 (180/2871)
word_count= 6: accuracy=0.057 (90/1587)
word_count= 7: accuracy=0.045 (44/974)
word_count= 8: accuracy=0.046 (32/689)
word_count= 9: accuracy=0.044 (16/366)
word_count=10: accuracy=0.033 (8/242)
word_count=11: accuracy=0.014 (2/144)
word_count=12: accuracy=0.000 (0/74)
word_count=13: accuracy=0.059 (2/34)
word_count=14: accuracy=0.000 (0/44)
word_count=15: accuracy=0.000 (0/20)
word_count=16: accuracy=0.000 (0/8)
word_count=17: accuracy=0.000 (0/4)
word_count=18: accuracy=0.000 (0/4)


In [31]:
'''
    List the accuracy for each answer label
    (only the top 50, in descending order of accuracy).

    For reference, also show the average number of words
    ("word average count") of the questions carrying that label.
'''
# Sort by the accuracy field (index 1), highest first.
sorted_info_by_answer_id = sorted(info_by_answer_id, key=lambda x:x[1], reverse=True)
for info in sorted_info_by_answer_id[0:50]: # show only the top 50
    print("answer_id=%2d: accuracy=%0.3f (%d/%d) word average count=%0.1f" % (
        info[0], info[1], info[2], info[3], info[4]
    ))

answer_id=4423: accuracy=0.807 (92/114) word average count=3.6
answer_id=7068: accuracy=0.667 (8/12) word average count=3.0
answer_id=4517: accuracy=0.608 (62/102) word average count=4.0
answer_id=4458: accuracy=0.600 (12/20) word average count=2.7
answer_id=4494: accuracy=0.600 (12/20) word average count=3.1
answer_id=4609: accuracy=0.569 (66/116) word average count=6.0
answer_id=4537: accuracy=0.522 (48/92) word average count=5.6
answer_id=4576: accuracy=0.457 (42/92) word average count=5.5
answer_id=4568: accuracy=0.429 (30/70) word average count=3.8
answer_id=4500: accuracy=0.378 (34/90) word average count=3.3
answer_id=4613: accuracy=0.326 (28/86) word average count=3.0
answer_id=4620: accuracy=0.321 (36/112) word average count=5.4
answer_id=7044: accuracy=0.300 (6/20) word average count=4.6
answer_id=4496: accuracy=0.286 (8/28) word average count=5.9
answer_id=4548: accuracy=0.286 (16/56) word average count=4.0
answer_id=4519: accuracy=0.281 (18/64) word average count=4.9
answer_