In [1]:
import re
import sys
import json


import sys
# sys.dont_write_bytecode = True
# sys.path.append('../')
from datatools.analyzer import *

from datatools.maneger import DataManager
from datatools.preproc import Preprocessor

In [2]:
from sklearn import metrics
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score

class Group:
    """One detector's predictions for a single error group, together with its scores.

    Parameters
    ----------
    y : array-like
        Gold labels for the detector's target (presumably 0/1, since the sklearn
        scorers are called with their default binary averaging).
    y_pred : array-like
        The detector's predictions, aligned with ``y``.
    level : str
        Group label such as "u-c" or "r-f"; read by ResultManeger.relabel_data.
    eval_ : str
        Score used for ordering: "pre" -> precision, "rec" -> recall,
        any other value -> F1.
    name : str
        Identifier of the prediction file (e.g. "ignore.pickle").

    Groups compare with ``<`` on the selected score, so a list of them can be
    sorted by detector quality.
    """

    def __init__(self, y, y_pred, level, eval_, name) -> None:
        self.true, self.pred, self.level = y, y_pred, level
        # Scores are computed eagerly; scoring() only needs true/pred.
        self.scoring()
        self.eval_, self.name = eval_, name

    def scoring(self):
        """Compute precision, recall and F1 of ``pred`` against ``true``."""
        gold, guess = self.true, self.pred
        self.pre = precision_score(gold, guess)
        self.rec = recall_score(gold, guess)
        self.f1 = f1_score(gold, guess)

    def get_score(self):
        """Return the score chosen by ``eval_`` (F1 for anything but "pre"/"rec")."""
        metric = self.eval_
        if metric == "pre":
            return self.pre
        if metric == "rec":
            return self.rec
        return self.f1

    def __lt__(self, other):
        # Ordering on the selected score lets list.sort() rank detectors.
        mine, theirs = self.get_score(), other.get_score()
        return mine < theirs

    def __str__(self) -> str:
        return "{} : {}".format(self.name, self.level)
        

In [3]:
import numpy as np
class ResultManeger:
    """Collect per-detector predictions and merge them into one multi-label prediction.

    Workflow: ``extract_y_true_group`` builds the gold matrix ``y_all``,
    ``set_data`` loads and ranks the detector outputs, and ``relabel_data`` merges
    them into a prediction matrix with the same shape as ``y_all``.
    """

    def __init__(self, result_path="./X_y_data/y_pred/") -> None:
        self.path = result_path
        # Project-local loader; set_data() uses it to obtain (true, pred) per file.
        # NOTE(review): judging by the load log these are pickle files - only load trusted data.
        self.dataM = DataManager(result_path)

    def set_data(self, data_name_dict, group2idx, eval_="pre"):
        """Load each detector's (true, pred) pair and rank the detectors.

        Args:
            data_name_dict: maps a file name understood by ``DataManager.load_data``
                to the group level (e.g. "u-c") that the detector predicts.
            group2idx: unused here; kept so existing callers keep working.
            eval_: score used for ranking ("pre", "rec", anything else -> F1).

        After the call ``self.groups`` holds one ``Group`` per entry, sorted from the
        highest to the lowest selected score.
        """
        self.groups = []
        for name, level in data_name_dict.items():
            true, pred = self.dataM.load_data(name)
            print(len(pred), name)
            self.groups.append(Group(true, pred, level, eval_, name))
        self.groups.sort(reverse=True)

    def relabel_data(self, group2idx):
        """Merge the ranked detector outputs into one multi-label prediction matrix.

        Each row is handled independently. Within a row, groups are visited in the
        order of ``self.groups`` (best score first after ``set_data``). A detector
        that predicts 1 sets its column subject to these rules:

        * at most one "r-" label and at most one "c-" label per row (first one wins);
        * a "u-" label is skipped if an "r-" or "c-" label was already set,
          otherwise it is set and no further groups are considered for the row;
        * any other level (e.g. "s-c") is set and ends the row.

        Args:
            group2idx: maps each level string to its column index.

        Returns:
            np.ndarray of int (0/1) with the same shape as ``self.y_all``.

        Raises:
            AttributeError: if ``extract_y_true_group`` has not been called yet.
            KeyError: if a group's level is missing from ``group2idx``.
        """
        if not hasattr(self, "y_all"):
            raise AttributeError(
                "y_all is not set; call extract_y_true_group() before relabel_data()")
        # builtin int: the np.int alias was removed in NumPy 1.24.
        y_pred = np.zeros(self.y_all.shape, dtype=int)
        for i in range(len(y_pred)):
            is_RES = False
            is_CON = False
            is_SOC = False  # kept for clarity; a social label always ends the row
            for group in self.groups:
                level = group.level
                if group.pred[i] == 1:
                    # A positive prediction still has to respect the per-row rules.
                    if "u-" in level:
                        if is_SOC or is_RES or is_CON:
                            continue
                    elif "r-" in level:
                        # Only one response-level label per row.
                        if is_RES:
                            continue
                    elif "c-" in level:
                        # Only one context-level label per row.
                        if is_CON:
                            continue
                    y_pred[i, group2idx[level]] = 1
                    if "u-" in level:
                        # Utterance-level label closes the row.
                        break
                    elif "r-" in level:
                        is_RES = True
                    elif "c-" in level:
                        is_CON = True
                    else:
                        is_SOC = True
                        break
        return y_pred

    def extract_y_true_group(self, error_sets, path="./eval_labeled/"):
        """Build the gold multi-label matrix ``self.y_all``.

        One row is produced for every system utterance that has at least one error
        annotation; column ``i`` is 1 when ``ut.is_error_included(error_sets[i])``.

        Args:
            error_sets: list of error-name lists, one per output column.
            path: directory passed to ``read_conv`` together with the dataset names.
        """
        datalist = ['DCM', 'DIT', 'IRS']
        convs = read_conv(path, datalist)
        rows = []
        for conv in convs:
            for ut in conv:
                if ut.is_system() and ut.is_exist_error():
                    row = np.zeros(len(error_sets), dtype=int)
                    for i, errors in enumerate(error_sets):
                        if ut.is_error_included(errors):
                            row[i] = 1
                    rows.append(row)
        # reshape keeps a 2-D (0, n_columns) matrix even when nothing matched.
        self.y_all = np.array(rows, dtype=int).reshape(len(rows), len(error_sets))


In [43]:
# Directory holding the detector prediction pickles evaluated below.
# Also evaluated previously: "./X_y_data/y_pred/" (ResultManeger's default path).
result_path = "./X_y_data/base_y_pred/"

In [44]:
import os

# os.listdir returns entries in arbitrary order; sort for a reproducible listing.
sorted(os.listdir(result_path))

['context_content.pickle',
 'utt_form.pickle',
 'common.pickle',
 'impolite.pickle',
 'context_form.pickle',
 'ignore.pickle',
 'utt_content.pickle']

In [45]:
# Alternative mapping: prediction file -> group level it predicts.
# NOTE(review): the next cell overwrites this mapping before it is used, and several of
# these files (e.g. wrong.pickle, LM.pickle) are not in the base_y_pred directory listing.
data_name_dict = dict([
    ('wrong.pickle', "u-c"),
    ('LM.pickle', "u-f"),
    ('caseFrame.pickle', "u-c"),
    ('common.pickle', "s-c"),
    ('repeat.pickle', "c-c"),
    ('impolite.pickle', "s-f"),
    ('context_form.pickle', "c-f"),
    ('ignore.pickle', "r-f"),
    ('contradict.pickle', "c-c"),
])

In [46]:
# Mapping used for the evaluation: prediction file -> group level it predicts.
# No file predicts "r-c", so that column of the merged prediction stays all-zero.
data_name_dict = dict((
    ('utt_content.pickle', "u-c"),
    ('utt_form.pickle', "u-f"),
    ('common.pickle', "s-c"),
    ('impolite.pickle', "s-f"),
    ('context_form.pickle', "c-f"),
    ('ignore.pickle', "r-f"),
    ('context_content.pickle', "c-c"),
))

In [47]:
# Column order of the label matrices: {utterance, response, context, social} x {form, content}.
group = ["u-f", "u-c", "r-f", "r-c", "c-f", "c-c", "s-f", "s-c"]
group2idx = {level: idx for idx, level in enumerate(group)}


In [48]:
# Gold error categories, one list per column of the label matrix.
# The order must match `group` (u-f, u-c, r-f, r-c, c-f, c-c, s-f, s-c), because
# extract_y_true_group uses the list index as the column index.
error_sets = [
    # utterance form
    ['Uninterpretable', 'Grammatical error'],
    # utterance content
    ['Semantic error', 'Wrong information'],
    # response form
    ["Ignore question", 'Ignore offer', 'Ignore proposal', "Ignore greeting"],
    # response content
    ["Ignore expectation"], 
    # context form
    ['Topic transition error', 'Lack of information', 'Unclear intention'],
    # context content
    ['Self-contradiction', 'Contradiction', 'Repetition'],
    # social form
    ['Lack of sociality'],
    # social content
    ['Lack of common sense']
]


In [49]:
# Build gold labels, load and rank the detector outputs by precision, then merge them.
resman = ResultManeger(result_path)
resman.extract_y_true_group(error_sets)
resman.set_data(data_name_dict, group2idx, eval_="pre")
# relabel_data indexes every detector's predictions by gold row, so the lengths must agree.
assert all(len(g.pred) == len(resman.y_all) for g in resman.groups), \
    "detector predictions and gold labels have different numbers of rows"
y_pred = resman.relabel_data(group2idx)

success load : ./X_y_data/base_y_pred/utt_content.pickle
1386 utt_content.pickle
success load : ./X_y_data/base_y_pred/utt_form.pickle
1386 utt_form.pickle
success load : ./X_y_data/base_y_pred/common.pickle
1386 common.pickle
success load : ./X_y_data/base_y_pred/impolite.pickle
1386 impolite.pickle
success load : ./X_y_data/base_y_pred/context_form.pickle
1386 context_form.pickle
success load : ./X_y_data/base_y_pred/ignore.pickle
1386 ignore.pickle
success load : ./X_y_data/base_y_pred/context_content.pickle
1386 context_content.pickle


In [50]:
# Sample-averaged multi-label scores of the merged prediction.
# zero_division=0 gives the same values as the default ("warn" acts as 0) but
# suppresses the undefined-metric warnings for rows without any label.
print('EM:', metrics.accuracy_score(resman.y_all, y_pred))
print("precision:", metrics.precision_score(resman.y_all, y_pred, average='samples', zero_division=0))
print("recall:", metrics.recall_score(resman.y_all, y_pred, average='samples', zero_division=0))
print('F-measure: ', metrics.f1_score(resman.y_all, y_pred, average='samples', zero_division=0))
print('0/1 loss: ', metrics.zero_one_loss(resman.y_all, y_pred))

EM: 0.5627705627705628
precision: 0.6566859066859068
recall 0.6885521885521886
F-measure:  0.6607022607022607
0/1 loss:  0.4372294372294372


  _warn_prf(average, modifier, msg_start, len(result))


- precision ベース

        EM: 0.5959595959595959
        F-measure:  0.7104858104858105

- recall ベース

        EM: 0.5735930735930735
        F-measure:  0.6823472823472824

- f値ベース
        
        EM: 0.5735930735930735
        F-measure:  0.6823472823472824

### 提案手法

        EM: 0.5959595959595959
        precision: 0.7028619528619529
        recall 0.7436267436267435
        F-measure:  0.7104858104858105

### ベースライン

        EM: 0.538961038961039
        precision: 0.6313131313131313
        recall 0.6611351611351611
        F-measure:  0.6345598845598845
        

In [51]:
# Count rows of the merged prediction in which no group was labelled.
zero = sum(1 for row in y_pred if 1 not in row)


In [52]:
# Show each detector group ("<file> : <level>") in the ranked order relabel_data visits.
for g in resman.groups:
    print(str(g))

ignore.pickle : r-f
context_form.pickle : c-f
utt_content.pickle : u-c
common.pickle : s-c
context_content.pickle : c-c
utt_form.pickle : u-f
impolite.pickle : s-f


In [53]:
# Number of rows for which relabel_data assigned no label at all.
zero

133

In [54]:
# Independent copy of the gold label matrix for the per-group evaluation below.
y = np.array(resman.y_all, copy=True)

In [55]:
# Same shape as resman.y_all: (number of gold rows, len(error_sets)).
y_pred.shape

(1386, 8)

In [56]:
# Per-group binary scores: column i of the gold matrix vs. column i of the merged prediction.
# zero_division=0 keeps the 0.0 results for columns with no positives (e.g. r-c)
# without emitting undefined-metric warnings.
for i, g in enumerate(group):
    print(g)
    yt = y[:, i]
    yp = y_pred[:, i]
    print("\taccuracy:", metrics.accuracy_score(yt, yp))
    print("\tprecision:", metrics.precision_score(yt, yp, zero_division=0))
    print("\trecall:", metrics.recall_score(yt, yp, zero_division=0))
    print('\tF-measure: ', metrics.f1_score(yt, yp, zero_division=0))
    print()

u-f
	acuracy: 0.9949494949494949
	precision: 0.0
	recall 0.0
	F-measure:  0.0

u-c
	acuracy: 0.8484848484848485
	precision: 0.8063492063492064
	recall 0.630272952853598
	F-measure:  0.7075208913649025

r-f
	acuracy: 0.8910533910533911
	precision: 0.7843137254901961
	recall 0.7909604519774012
	F-measure:  0.7876230661040786

r-c
	acuracy: 1.0
	precision: 0.0
	recall 0.0
	F-measure:  0.0

c-f
	acuracy: 0.6796536796536796
	precision: 0.6608910891089109
	recall 0.7585227272727273
	F-measure:  0.7063492063492064

c-c
	acuracy: 0.9357864357864358
	precision: 0.0
	recall 0.0
	F-measure:  0.0

s-f
	acuracy: 0.9949494949494949
	precision: 0.0
	recall 0.0
	F-measure:  0.0

s-c
	acuracy: 0.9884559884559885
	precision: 0.15789473684210525
	recall 1.0
	F-measure:  0.2727272727272727



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  average, "true nor predicted", 'F-score is', len(true_sum)
  _warn_prf(average, modifier, msg_start, len(result))


In [59]:
# All gold label cells as a flat list (row-major order).
a = [v for v in y.flat]

In [60]:
# Total number of gold label cells (rows x groups).
len(a)

11088

In [61]:
# Number of positive (1) gold label cells.
a.count(1)

1555

In [63]:
# Share of gold label cells that are not 1, i.e. the per-cell accuracy of an
# all-zero predictor. Computed from `a` rather than from hard-coded counts.
(len(a) - a.count(1)) / len(a)

0.8597582972582972