In [6]:
import json
import os
import pandas as pd
import numpy as np
from pathlib import Path
import collections
from sklearn.model_selection import train_test_split
from sklearn import metrics

import sys
sys.path.append("../")
from datatools.analyzer import *
from utterance.error_tools import *

from datatools.maneger import DataManager
from datatools.preproc import Preprocessor

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import loss
import torch.optim as optim

In [8]:
class CaseModel(nn.Module):
    def __init__(self, embedding_dim, tagset_size):
        # 親クラスのコンストラクタ。決まり文句
        super(CaseModel, self).__init__()    
        self.embedding_dim = embedding_dim
        self.hid1= embedding_dim*2
        self.hid2 = embedding_dim//2
        self.fc1 = nn.Linear(self.embedding_dim, self.hid1)
        self.fc2 = nn.Linear(self.hid1, self.hid2)
        self.hidden2tag = nn.Linear(self.hid2, tagset_size)
    
    def forward(self, x):
        y = F.relu(self.fc1(x))
        y = F.relu(self.fc2(y))
        y = self.hidden2tag( y )
        y = F.log_softmax(y, dim=1)
        return y

In [9]:
border=0.8

In [10]:
dict_path = "../models/utterance/"
dict_name = "group_border={0}_2.pickle".format(border)
dictM = DataManager(dict_path)
group_dict = dictM.load_data(dict_name)

success load : ../models/utterance/group_border=0.8_2.pickle


In [11]:
def search_word(word):
    word = clean_text(word)
    found_list = []
    for group_key in group_dict.keys():
        if word in group_dict[group_key]:
            found_list.append(group_key)
    return found_list

In [12]:
from gensim.models import KeyedVectors
w2v_path = "../../corpus/w2v/"
# fasttext
# https://qiita.com/Hironsan/items/513b9f93752ecee9e670
# w2v_name =  "dep-ja-300dim"
w2v_name =  "model.vec"
w2v_model = KeyedVectors.load_word2vec_format(w2v_path+w2v_name)

[2777] 2022-01-06 14:04:02,775 Info gensim.models.keyedvectors :loading projection weights from ../../corpus/w2v/model.vec
[2777] 2022-01-06 14:05:05,971 Info gensim.utils :KeyedVectors lifecycle event {'msg': 'loaded (351122, 300) matrix of type float32 from ../../corpus/w2v/model.vec', 'binary': False, 'encoding': 'utf8', 'datetime': '2022-01-06T14:05:05.971525', 'gensim': '4.0.1', 'python': '3.6.9 (default, Jan 26 2021, 15:33:00) \n[GCC 8.4.0]', 'platform': 'Linux-5.4.72-microsoft-standard-WSL2-x86_64-with-Ubuntu-18.04-bionic', 'event': 'load_word2vec_format'}


In [13]:
group2index = dict(zip( group_dict.keys(), range(len(group_dict.keys())) ))
index2group = dict(zip(range(len(group_dict.keys())), group_dict.keys() ))

In [14]:
model_path = "../models/utterance/"
model_name = "case_frame_{0}_2.pickle".format(border)
modelM = DataManager(model_path)
cmodel = modelM.load_data(model_name)

success load : ../models/utterance/case_frame_0.8_2.pickle


In [15]:
def classify_word(word):
    word = clean_text(word)
    if word not in w2v_model:
        return ""
    with torch.no_grad():
        vector = torch.tensor([w2v_model[word]]).cuda()
        pred = np.array(cmodel(vector).cpu()).argmax()
    
    return index2group[pred]

In [16]:
def register_triple(case_frame:dict, V, C, Noun):
    # 動詞の確認
    if V not in case_frame:
        case_frame[V] = dict()

    # 格の確認
    if C not in case_frame[V]:
        case_frame[V][C] = set()
    
    # 名詞の登録
    case_frame[V][C].add(Noun)

In [17]:
def compound_noun(token):
    base = token.lemma_
    for child in token.children:
        if child.dep_ == "compound":
           base = child.lemma_ + base
    return base

In [25]:
def extract_RDF_triple(text):
    text = clean_text(text)
    doc = nlp(text)
    triple_list = []
    for i, token in enumerate( doc ):
        if token.pos_ in ["VERB", "ADJ"]:
        # if token.pos_=="VERB":
        # if is_declinable(token):
            for c in token.children:
                if c.dep_ in ["nsubj", "obj", "obl"]:
                    noun = c.lemma_
                    for c2 in c.children:
                        # if c2.dep_ == "case" and c2.orth_ in case_set:
                        if c2.dep_ == "case":
                            case = c2.orth_
                            if case == "は":
                                case = "が"
                            triple_list.append( (token.lemma_, case, noun) )
    return triple_list

In [26]:
path = "../hand_labeled/"
datalist = ['DCM', 'DIT', 'IRS']
convs = read_conv(path, datalist)

In [27]:
error = "Semantic error"
sys_utt = []
y = []
for conv in convs:
    for ut in conv:
        # if ut.is_system() and ut.is_exist_error():
        if not ut.is_exist_error():
            sys_utt.append(ut.utt)
            if ut.is_error_included(error):
                y.append(1)
            else:
                y.append(0)

In [66]:
with open("../../corpus/NTT/persona.json", "r", encoding="utf-8") as f:
    convs = json.load(f)
all_utt = []
for did in tqdm( convs["convs"] ) :
    dids = list( did.keys() )[0]
    all_utt += did[dids]

100%|██████████| 5016/5016 [00:00<00:00, 1263353.68it/s]


In [50]:
# all_utt = sys_utt + all_utt
# len(all_utt)

63167

In [67]:
case_frame = dict()
for utt in tqdm( all_utt ):
    is_valid = True

    triples = extract_RDF_triple(utt)
    if len(triples)>0:
        for triple in triples:
            V = triple[0]
            C = triple[1]
            noun = triple[2]
            group_ = search_word(noun)
            if len(group_) > 0:
                for group in group_:
                    register_triple(case_frame, V, C, group)
            else:
                group = classify_word(noun)
                register_triple(case_frame, V, C, group)

            register_triple(case_frame, V, C, noun)


100%|██████████| 61781/61781 [19:26<00:00, 52.98it/s]


In [68]:
case_frame_name = "case_frame_V2_ntt_uni.pickle"
data_path = "../X_y_data/utterance/"

In [69]:
dataM = DataManager(data_path)
dataM.save_data(case_frame_name, case_frame)

success save : ../X_y_data/utterance/case_frame_V2_ntt.pickle


In [70]:
def search_triple(V, C, N):
    if N in case_frame[V][C]:
        return True
    else:
        return False

# True : 用法は問題ない
# False : 問題あり
def judge_triple(triple):
    V = triple[0]
    C = triple[1]
    noun = triple[2]

    # Vが登録されていないならパスしておくかな(一旦)
    if V not in case_frame:
        # print("not registered V :", V)
        return True
    
    # C が登録されていないなら，アウト
    if C not in case_frame[V]:
        # print("not registered V, C :", V, C)
        return False
    
    # そのままの名詞で検索 -> 発見なら正例
    if search_triple(V, C, noun):
        return True

    # 名詞がどこかのグループに属する
    group_ = search_word(noun)
    if len(group_) > 0:
        # 1つでも引っかかればOK
        for group in group_:
            if search_triple(V, C, group):
                return True
        return False
    # グループを推定
    else:
        group = classify_word(noun) 
        return search_triple(V, C, group)

In [71]:
path = "../eval_labeled/"
datalist = ['DCM', 'DIT', 'IRS']
convs = read_conv(path, datalist)

error = "Semantic error"
sys_utt = []
y = []
for conv in convs:
    for ut in conv:
        if ut.is_system() and ut.is_exist_error():
            sys_utt.append(ut.utt)
            if ut.is_error_included(error):
                y.append(1)
            else:
                y.append(0)

In [72]:
y_pred = []
for utt in tqdm(sys_utt):
    is_valid = True

    triples = extract_RDF_triple(utt)
    for triple in triples:
        if not judge_triple(triple):
            is_valid = False
            break
    if is_valid:
        y_pred.append(0)
    else:
        y_pred.append(1)

100%|██████████| 1386/1386 [00:22<00:00, 60.61it/s]


In [76]:
score(y, y_pred)

confusion matrix = 
 [[836 542]
 [  6   2]]
accuracy =  0.6046176046176046
precision =  0.003676470588235294
recall =  0.25
f1 score =  0.007246376811594203


- 学習データのみでやってみた

        confusion matrix = 
        [[716 662]
        [  1   7]]
        accuracy =  0.5216450216450217
        precision =  0.01046337817638266
        recall =  0.875
        f1 score =  0.020679468242245203

    - 再現率は良い！

In [74]:
def judge_triple(triple):
    V = triple[0]
    C = triple[1]
    noun = triple[2]

    # Vが登録されていないならパスしておくかな(一旦)
    if V not in case_frame:
        # print("not registered V :", V)
        return True
    
    # C が登録されていないなら，アウト
    if C not in case_frame[V]:
        # print("not registered V, C :", V, C)
        return False
    
    # そのままの名詞で検索 -> 発見なら正例
    if search_triple(V, C, noun):
        print("\tfound", V, C, noun)
        return True

    # 名詞がどこかのグループに属する
    group_ = search_word(noun)
    if len(group_) > 0:
        # 1つでも引っかかればOK
        for group in group_:
            if search_triple(V, C, group):
                return True
        return False
    # グループを推定
    else:
        group = classify_word(noun) 
        return search_triple(V, C, group)

In [75]:
for t, p, utt in zip(y, y_pred, sys_utt):
    if t==1:
        print(utt)
        triples = extract_RDF_triple(utt)
        for triple in triples:
            print(triple)
            print( judge_triple(triple) )
        print("--------")

元気ですかは元気です
--------
好きだを見ますよねー
('見る', 'を', '好き')
False
--------
病院は治療を受けましょう
('受ける', 'が', '病院')
True
('受ける', 'を', '治療')
True
--------
好きだは好きですか。お寿司はエンガワが好きですね
('好き', 'が', '縁側')
True
--------
時期から資格を取りますねぇ
('取る', 'から', '時期')
False
('取る', 'を', '資格')
	found 取る を 資格
True
--------
手を貯金に出しますねぇ
('出す', 'を', '手')
	found 出す を 手
True
('出す', 'に', '貯金')
True
--------
ところで、テレビでテレビあるって言ってましたが、テレビは民主党支持が多いですね
('ある', 'で', 'テレビ')
	found ある で テレビ
True
('言う', 'で', '所')
	found 言う で 所
True
('多い', 'が', '支持')
True
--------
旬ですねぇ。自分もオリンピック書いたし。
('書く', 'も', '自分')
	found 書く も 自分
True
--------
