In [None]:
import pandas as pd
import numpy as np
from simpletransformers.classification import MultiLabelClassificationArgs
from multi_label_classification_model_wrapper import MultiLabelClassificationModelWrapper
import sklearn
import sklearn.metrics
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
import re
import os
from mbti_util import MbtiUtil
from denoicer import Denoicer

In [None]:
# Shared helper instances created once for the whole notebook.
# NOTE(review): mbti_util is not referenced in any of the visible cells — confirm it is still needed.
mbti_util = MbtiUtil()
# Used by read_dataset_to_frame() to normalize each sentence via denoice.normalize_text().
denoice = Denoicer()

In [None]:
# Load the emotion corpus and split it into training and evaluation data
def read_dataset_to_frame():
    """Read the emotion corpus TSV and return ``(train_df, eval_df)``.

    The corpus is shuffled with a fixed seed, rows with an unusable sentence
    or label vector are dropped, each sentence is normalized with
    ``denoice.normalize_text`` and the remaining samples are split with
    ``train_test_split`` (fixed seed, default split ratio).

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: training and evaluation frames,
        each with the columns ``text`` (normalized sentence) and ``labels``
        (``np.float32`` array parsed from the label column).
    """
    corpus = pd.read_csv('./database/emotion/emotion_vector_corpus_ones.tsv', sep='\t')
    # The emotion scores live in ONE column whose header is this whole
    # space-separated string; each cell is split on ' ' into a vector below.
    label_column = 'Joy Sadness Anticipation Surprise Anger Fear Disgust Trust'
    shuffled = sklearn.utils.shuffle(corpus, random_state=1)

    samples = []
    for sentence, label_text in shuffled[['Sentence', label_column]].values:
        # Missing cells are read by pandas as float NaN, not str; such rows
        # cannot be stripped/split, so skip them.
        if not isinstance(sentence, str) or not isinstance(label_text, str):
            continue
        labels = np.array(label_text.split(' '), dtype=np.float32)
        # Skip rows whose label vector contains a NaN entry. The check must look
        # at the row's values (the previous regex tested the column header and
        # therefore never filtered anything).
        if np.isnan(labels).any():
            continue
        samples.append([denoice.normalize_text(sentence.strip()), labels])

    # Split into training data (75%) and evaluation data (25%)
    train_data, eval_data = train_test_split(samples, random_state=1)
    train_df = pd.DataFrame(train_data, columns=['text', 'labels'])
    print("read train data.")
    eval_df = pd.DataFrame(eval_data, columns=['text', 'labels'])
    print("read eval data.")
    return train_df, eval_df
# Build the train/eval frames once at cell level; the training and evaluation
# cells below read these two globals.
train_df, eval_df = read_dataset_to_frame()

In [None]:
# Train the emotion estimation model
def train(model_type, model_name, train_df, eval_df):
    """Fine-tune a multi-label emotion classifier, then evaluate it.

    Args:
        model_type: Model type key passed to MultiLabelClassificationModelWrapper.
        model_name: Pretrained model identifier passed to the wrapper. With
            every '/' replaced by '-', it also names the output directory
            under ``./model/``.
        train_df: Training frame handed to ``model.train_model``.
        eval_df: Evaluation frame handed to ``model.eval_model``.

    Returns:
        tuple: ``(result, model_outputs, wrong_predictions)`` as produced by
        ``model.eval_model(eval_df)``.
    """
    # NOTE(review): use_early_stopping is enabled, but no evaluation data is
    # passed to train_model below — confirm early stopping can actually trigger.
    model_args = MultiLabelClassificationArgs(
        num_train_epochs=5,
        train_batch_size=64,
        eval_batch_size=16,
        use_early_stopping=True
    )
    model = MultiLabelClassificationModelWrapper(
        model_type=model_type,
        model_name=model_name,
        num_labels=8,
        args=model_args,
        use_cuda=True
    )

    # Hub-style names such as "org/model" contain '/', which would otherwise
    # create a nested directory; flatten it into a single folder name.
    safe_name = model_name.replace('/', '-')
    output_dir = f'./model/{safe_name}'
    # makedirs also creates a missing ./model parent and tolerates an existing
    # directory, where exists() + mkdir() would raise or race.
    os.makedirs(output_dir, exist_ok=True)

    # Train the model
    model.train_model(train_df, output_dir=output_dir)
    # Evaluate the model
    result, model_outputs, wrong_predictions = model.eval_model(eval_df)
    return result, model_outputs, wrong_predictions
    
# English rendering of the note in the string below:
# "Uncomment one of the following lines to select the model to train."
'''
以下のいずれかのコメントを外して学習するモデルを選択する
'''
# model_type, model_name = ('roberta_waseda_ja', 'nlp-waseda/roberta-base-japanese')
# model_type, model_name = ('twhinbert', 'Twitter/twhin-bert-base')
# model_type, model_name = ('bert', 'cl-tohoku/bert-base-japanese-whole-word-masking')
# model_type, model_name = ('xlnet', 'hajime9652/xlnet-japanese')
# model_type, model_name = ('xlmroberta', 'xlm-roberta-base')

# NOTE: with every option above commented out, model_type / model_name are not
# assigned in this cell, so the call below raises NameError on a fresh kernel
# unless one line is uncommented first.
result, model_outputs, wrong_predictions = train(
    model_type=model_type, model_name=model_name, train_df=train_df, eval_df=eval_df)

In [None]:
def evaluate_emotion_model(eval_df, model_type, model_name):
    """Print a classification report for a model on ``eval_df``.

    Args:
        eval_df: Frame with a ``text`` column (inputs) and a ``labels`` column
            (one iterable of label values per row).
        model_type: Model type key passed to MultiLabelClassificationModelWrapper.
        model_name: Model identifier or checkpoint path passed to the wrapper;
            also printed as the report heading.

    Returns:
        None. ``model_name`` and the report text are printed.
    """
    classifier = MultiLabelClassificationModelWrapper(
        model_type=model_type,
        model_name=model_name,
        num_labels=8,
        use_cuda=True,
    )

    # Ground truth: cast every label entry of every row to int.
    y_true = np.array(
        [[int(flag) for flag in row] for row in eval_df['labels'].values]
    )

    texts = eval_df['text'].to_list()
    predictions, _ = classifier.predict(texts)

    emotion_names = [
        'Joy', 'Sadness', 'Anticipation', 'Surprise',
        'Anger', 'Fear', 'Disgust', 'Trust',
    ]
    report = sklearn.metrics.classification_report(
        y_true, predictions, target_names=emotion_names
    )
    print(model_name)
    print(report)

# English rendering of the note in the string below:
# "Uncomment one of the following lines to select the model to evaluate."
'''
以下のいずれかのコメントを外して評価するモデルを選択する
'''
# NOTE(review): these checkpoint paths are under ./model/trained/, while train()
# writes to ./model/<name> — presumably checkpoints are moved manually; confirm.
# model_type, model_name = ('roberta_waseda_ja', './model/trained/nlp-waseda-roberta-base-japanese/checkpoint-7555-epoch-5')
# model_type, model_name = ('twhinbert', './model/trained/Twitter-twhin-bert-base/checkpoint-7555-epoch-5')
# model_type, model_name = ('bert', './model/trained/cl-tohoku-bert-base-japanese-whole-word-masking/checkpoint-7555-epoch-5')
# model_type, model_name = ('xlnet', './model/trained/hajime9652-xlnet-japanese/checkpoint-7555-epoch-5')
# model_type, model_name = ('xlmroberta', './model/trained/xlm-roberta-base/checkpoint-7555-epoch-5')
# NOTE: if no line above is uncommented, model_type / model_name keep whatever
# value an earlier cell assigned (or are undefined on a fresh kernel).
evaluate_emotion_model(eval_df=eval_df, model_type=model_type, model_name=model_name)