In [None]:
import json
import os
import pandas as pd
import numpy as np
from pathlib import Path
import collections
from sklearn.model_selection import train_test_split
from sklearn import metrics

import sys
sys.path.append("../")
from datatools.analyzer import *
from utterance.error_tools import *

from datatools.maneger import DataManager
from datatools.preproc import Preprocessor

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import loss
import torch.optim as optim
import torch.nn.utils.rnn as rnn

In [None]:
# Relative path handed to read_conv (presumably the hand-labeled corpus
# directory -- confirm against the repository layout).
path = "../hand_labeled/"
# Names passed to read_conv to select what is loaded.
datalist = ['DCM', 'DIT', 'IRS']
# read_conv is not defined in this notebook; it arrives through one of the
# star imports in the first cell. The result is iterated below as a collection
# of conversations, each yielding utterance objects.
convs = read_conv(path, datalist)

In [None]:
def make_Xy_4test(convs, N=4):
    """Collect N-utterance context windows and binary labels.

    Every conversation is walked utterance by utterance. The cleaned text of
    each utterance (``clean_text(ut.utt)``) is appended to a running history
    that starts out as ``N`` empty strings. Whenever an utterance reports
    ``is_exist_error()``, the last ``N`` history entries become one sample.

    Args:
        convs: iterable of conversations; each conversation is an iterable of
            utterance objects exposing ``utt``, ``is_exist_error()`` and
            ``is_error_included(errors)``.
        N: number of trailing utterances kept per sample.

    Returns:
        ``(X, y)`` where ``X`` is a list of windows (lists of strings) and
        ``y`` is a parallel list holding 1 when the utterance includes one of
        the target error types and 0 otherwise.
    """
    target_errors = ["Topic transition error", 'Lack of information', 'Unclear intention']

    windows, labels = [], []
    for conv in convs:
        # Left-pad with empty strings so early utterances still yield N items.
        history = [""] * N
        for utterance in conv:
            history.append(clean_text(utterance.utt))
            if not utterance.is_exist_error():
                continue
            windows.append(history[-N:])
            labels.append(1 if utterance.is_error_included(target_errors) else 0)

    return windows, labels

In [None]:
# Context window size: each sample holds the last N cleaned utterances.
# This overrides make_Xy_4test's default of N=4.
N = 3
# X_str: list of N-string windows; y: parallel list of 0/1 labels.
X_str, y = make_Xy_4test(convs, N=N)

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the
# import cell at the top so dependencies are visible in one place.
from sentence_transformers import SentenceTransformer
# from sentence_transformers import models

# Relative path to a locally stored model directory, so loading depends on the
# notebook's working directory. Presumably a fine-tuned Sentence-BERT
# checkpoint -- confirm.
bert_path = "../../corpus/pretrained/sbert_unclear1"
sbert = SentenceTransformer(bert_path)

In [None]:
# Flatten the N-utterance windows into one list so the encoder is called once.
# A nested comprehension is linear in the number of utterances, whereas
# sum(list_of_lists, []) re-copies the accumulated list on every addition.
X_forward_all_str = [utt for window in X_str for utt in window]
# Each window built by make_Xy_4test holds exactly N utterances, so the number
# of samples is simply the number of windows.
x_length = len(X_str)
# Encode every utterance, then regroup to (samples, N, embedding_dim).
X_topic_vec = sbert.encode(X_forward_all_str).reshape(x_length, N, -1)

In [None]:
def vec2feature(vector):
    """Build one feature vector from a stacked pair of embeddings.

    Args:
        vector: array indexed as ``vector[0]`` and ``vector[1]``.

    Returns:
        1-D array: every entry of ``vector`` flattened, followed by the
        element-wise absolute difference ``|vector[0] - vector[1]|``.
    """
    first, second = vector[0], vector[1]
    gap = np.abs(first - second)
    flat = vector.flatten()
    return np.concatenate((flat, gap))

In [None]:
emb_dim = 768  # presumably the SBERT embedding width; left defined for any other cell that reads it


def sentence2formated(vectors):
    """Turn a sequence of utterance embeddings into pairwise features.

    Each embedding is paired with the one before it (the first is paired with
    an all-zero vector) and the pair is passed through ``vec2feature``.

    Args:
        vectors: array-like of shape (num_utterances, dim).

    Returns:
        ``np.ndarray`` with one row per utterance, holding whatever
        ``vec2feature`` returns for that pair. An empty input gives an empty
        array.
    """
    vectors = np.asarray(vectors)
    # Size the zero "previous" vector from the data rather than from the
    # module-level emb_dim, so encoders with a different embedding width work
    # too. np.zeros keeps its default float64 dtype, so the output dtype is
    # the same as before.
    prev_vector = np.zeros(vectors.shape[-1])
    features = []
    for vector in vectors:
        features.append(vec2feature(np.array([prev_vector, vector])))
        prev_vector = vector
    return np.array(features)

In [None]:
# One sentence2formated() result per N-utterance window, stacked into a single
# array of per-sample feature sequences.
X_topic = np.array([ sentence2formated(vec) for vec in X_topic_vec ])

In [None]:
# from_numpy does not copy: the tensor shares memory with X_topic.
X = torch.from_numpy(X_topic)
# Labels as a float32 tensor. as_tensor with an explicit dtype replaces the
# legacy torch.Tensor(...) constructor (same dtype and values) and simply
# returns y unchanged if this cell is re-run after y is already a tensor.
y = torch.as_tensor(y, dtype=torch.float32)

In [None]:
# Hold out 30% of the samples for testing. stratify=y keeps the label ratio the
# same in both splits; the fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=5, stratify=y)

In [None]:
class Datasets(torch.utils.data.Dataset):
    """Map-style dataset that pairs ``X_data[i]`` with ``y_data[i]``."""

    def __init__(self, X_data, y_data):
        """Store both sequences; the length is taken from ``X_data`` once, here."""
        self.X_data, self.y_data = X_data, y_data
        self.datanum = len(X_data)

    def __len__(self):
        """Return the sample count recorded at construction time."""
        return self.datanum

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.X_data[idx], self.y_data[idx]

In [None]:
import copy
class TopicClassifier(nn.Module):
    """LSTM classifier over the leading ``topic_dim`` features of each step.

    The input is sliced as ``x[:, :, :topic_dim]`` and fed to a batch-first
    LSTM. The final hidden state of that LSTM is projected to ``tagset_size``
    scores and returned as log-probabilities.

    Args:
        topic_dim: number of leading features per time step fed to ``tlstm``.
        forward_dim: input size of ``flstm``.
        topic_hid: hidden size of ``tlstm`` (any value works; it does not
            affect the output shape).
        for_hid: hidden size of ``flstm``.
        tagset_size: number of output classes.
    """

    def __init__(self, topic_dim, forward_dim, topic_hid, for_hid, tagset_size):
        # Standard parent-class constructor call.
        super().__init__()
        self.tlen = topic_dim
        self.flen = forward_dim
        self.tlstm = nn.LSTM(topic_dim, topic_hid, batch_first=True)
        # NOTE(review): flstm, tanh and relu are never used by forward(). They
        # are kept so attribute names and state_dict keys stay unchanged.
        self.flstm = nn.LSTM(forward_dim, for_hid, batch_first=True)
        self.hid2out = nn.Linear(topic_hid, tagset_size)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        # Normalise over the class dimension explicitly. Leaving dim unset
        # relies on PyTorch's deprecated implicit choice and emits a warning
        # on every forward pass.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities of shape (batch, tagset_size).

        Args:
            x: tensor indexed as (batch, seq, features); only the first
                ``self.tlen`` features of each step are used.
        """
        x_topic = x[:, :, :self.tlen].to(torch.float)
        # Only the final hidden state is needed; per-step outputs are dropped.
        _, (h_n, _c_n) = self.tlstm(x_topic)
        # h_n[0] is the hidden state of the (single) LSTM layer: (batch, topic_hid).
        out = self.hid2out(h_n[0])
        return self.softmax(out)