In [1]:
import os
import json
from tqdm import tqdm


def load_data(filename):
    D = []
    with open(filename) as f:
        for i, l in enumerate(f):
            label = None
            items = l.strip().split('\t')
            if len(items) == 3:
                idx, text, label = items
#                 label = int(label)
            else:
                idx, text = items
            D.append((idx, text, label))
    return D

def load_ocnli_data(filename):
    D = []
    with  open(filename) as f:
        for l in f:
            label = None
            items = l.strip().split('\t')
            if len(items) == 4:
                idx, s1, s2, label = items
            else:
                idx, s1, s2 = items
            D.append((idx, s1, s2, label))
    return D


tnews_train = load_data('/home/mingming.xu/datasets/NLP/ptms_data/TNEWS_train1128.csv')
tnews_test = load_data('/home/mingming.xu/datasets/NLP/ptms_data/TNEWS_a.csv')

ocemotion_train = load_data('/home/mingming.xu/datasets/NLP/ptms_data/OCEMOTION_train1128.csv')
ocemotion_test = load_data('/home/mingming.xu/datasets/NLP/ptms_data/OCEMOTION_a.csv')

ocnli_train = load_ocnli_data('/home/mingming.xu/datasets/NLP/ptms_data/OCNLI_train1128.csv')
ocnli_test = load_ocnli_data('/home/mingming.xu/datasets/NLP/ptms_data/OCNLI_a.csv')

len(tnews_train), len(tnews_test), len(ocemotion_train), len(ocemotion_test), len(ocnli_train), len(ocnli_test)

(63360, 1500, 35694, 1500, 53387, 1500)

In [2]:
tnews_train[0], ocemotion_train[0], ocnli_train[0]

(('0', '上课时学生手机响个不停,老师一怒之下把手机摔了,家长拿发票让老师赔,大家怎么看待这种事?', '108'),
 ('0',
  "'你知道多伦多附近有什么吗?哈哈有破布耶...真的书上写的你听哦...你家那块破布是世界上最大的破布,哈哈,骗你的啦它是说尼加拉瓜瀑布是世界上最大的瀑布啦...哈哈哈''爸爸,她的头发耶!我们大扫除椅子都要翻上来我看到木头缝里有头发...一定是xx以前夹到的,你说是不是?[生病]",
  'sadness'),
 ('0', '一月份跟二月份肯定有一个月份有.', '肯定有一个月份有', '0'))

In [3]:
def precess_label(train_data):
    labels = set([d[-1] for d in train_data])
    label2id = {k:v for v, k in enumerate(labels)}
    id2label = {v:k for k, v in label2id.items()}
    return labels, label2id, id2label

tnews_labels, tnews_label2id, tnews_id2label = precess_label(tnews_train)
ocnli_labels, ocnli_label2id, ocnli_id2label = precess_label(ocnli_train)
ocemotion_labels, ocemotion_label2id, ocemotion_id2label = precess_label(ocemotion_train)

In [4]:
tnews_labels, ocnli_labels, ocemotion_labels

({'100',
  '101',
  '102',
  '103',
  '104',
  '106',
  '107',
  '108',
  '109',
  '110',
  '112',
  '113',
  '114',
  '115',
  '116'},
 {'0', '1', '2'},
 {'anger', 'disgust', 'fear', 'happiness', 'like', 'sadness', 'surprise'})

In [5]:
tnews_train = [d[:-1] + (tnews_label2id[d[-1]],) for d in tnews_train]
ocnli_train = [d[:-1] + (ocnli_label2id[d[-1]],) for d in ocnli_train]
ocemotion_train = [d[:-1] + (ocemotion_label2id[d[-1]],) for d in ocemotion_train]

tnews_train[0], ocnli_train[0], ocemotion_train[0]

(('0', '上课时学生手机响个不停,老师一怒之下把手机摔了,家长拿发票让老师赔,大家怎么看待这种事?', 6),
 ('0', '一月份跟二月份肯定有一个月份有.', '肯定有一个月份有', 1),
 ('0',
  "'你知道多伦多附近有什么吗?哈哈有破布耶...真的书上写的你听哦...你家那块破布是世界上最大的破布,哈哈,骗你的啦它是说尼加拉瓜瀑布是世界上最大的瀑布啦...哈哈哈''爸爸,她的头发耶!我们大扫除椅子都要翻上来我看到木头缝里有头发...一定是xx以前夹到的,你说是不是?[生病]",
  2))

In [6]:
from toolkit4nlp.models import *
from toolkit4nlp.layers import *
from toolkit4nlp.utils import *
from toolkit4nlp.optimizers import *
from toolkit4nlp.tokenizers import *
from toolkit4nlp.backend import *

import tensorflow as tf

Using TensorFlow backend.


In [7]:
#bert config

config_path = '/home/mingming.xu/pretrain/NLP/nezha_base_wwm/bert_config.json'
# checkpoint_path = '/home/mingming.xu/pretrain/NLP/nezha_base_wwm/model.ckpt'
dict_path = '/home/mingming.xu/pretrain/NLP/nezha_base_wwm/vocab.txt'
checkpoint_path = './post_training/nezha_base_wwm-13'
# tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

batch_size=16
maxlen=256
epochs = 5





In [8]:
class batch_data_generator(DataGenerator):
    def __init__(self, label_mask, **kwargs):
        super(batch_data_generator, self).__init__(**kwargs)
        self.label_mask = label_mask
        
    def __iter__(self, shuffle=False):
        batch_token_ids, batch_segment_ids, batch_type_ids, batch_labels, batch_label_mask = [], [], [], [], []
        for is_end, item in self.get_sample(shuffle):
            if len(item) == 4:
                _, q, r, l = item
                token_ids, segment_ids = tokenizer.encode(q,r, maxlen=maxlen)
            else:
                _, q, l = item
                token_ids, segment_ids = tokenizer.encode(q, maxlen=maxlen)
            
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append([l])
            batch_label_mask.append(self.label_mask)
            # task type
            task_type = np.argmax(self.label_mask)
            batch_type_ids.append([task_type])
            
            if is_end or self.batch_size == len(batch_token_ids):
                batch_token_ids = pad_sequences(batch_token_ids)
                batch_segment_ids = pad_sequences(batch_segment_ids)
                batch_label_mask = pad_sequences(batch_label_mask)
                batch_labels = pad_sequences(batch_labels)
                batch_type_ids = pad_sequences(batch_type_ids)
                yield [batch_token_ids, batch_segment_ids, batch_type_ids, batch_labels, batch_label_mask], None
                batch_token_ids, batch_segment_ids, batch_type_ids, batch_labels, batch_label_mask = [], [], [], [], []
    



In [9]:
import numpy as np
np.argmax([0,1])

1

In [10]:
# task type
task_type = {'tnews': 0, 
            'ocnli':1,
            'ocemotion':2}

split = 0.8
tnews_mask = [1,0,0]
ocnli_mask = [0,1,0]
ocemotion_mask = [0,0,1]

def split_train_valid(data, split):
    n = int(len(data)*split)
    train_data = data[:n]
    valid_data = data[n:]
    return train_data, valid_data


tnews_train_data, tnews_valid_data = split_train_valid(tnews_train, split)
ocnli_train_data, ocnli_valid_data = split_train_valid(ocnli_train, split)
ocemotion_train_data, ocemotion_valid_data = split_train_valid(ocemotion_train, split)

tnews_train_generator = batch_data_generator(data=tnews_train, batch_size=batch_size, label_mask=tnews_mask)
tnews_valid_generator = batch_data_generator(data=tnews_valid_data, batch_size=batch_size, label_mask=tnews_mask)
tnews_test_generator = batch_data_generator(data=tnews_test, batch_size=batch_size, label_mask=tnews_mask)

ocnli_train_generator = batch_data_generator(data=ocnli_train_data, batch_size=batch_size, label_mask=ocnli_mask)
ocnli_valid_generator = batch_data_generator(data=ocnli_valid_data, batch_size=batch_size, label_mask=ocnli_mask)
ocnli_test_generator = batch_data_generator(data=ocnli_test, batch_size=batch_size, label_mask=ocnli_mask)


ocemotion_train_generator = batch_data_generator(data=ocemotion_train_data, batch_size=batch_size, label_mask=ocemotion_mask)
ocemotion_valid_generator = batch_data_generator(data=ocemotion_valid_data, batch_size=batch_size, label_mask=ocemotion_mask)
ocemotion_test_generator = batch_data_generator(data=ocemotion_test, batch_size=batch_size, label_mask=ocemotion_mask)

tnews_train_generator.take()

([array([[ 101,  677, 6440, 3198, 2110, 4495, 2797, 3322, 1510,  702,  679,
           977,  117, 5439, 2360,  671, 2584,  722,  678, 2828, 2797, 3322,
          3035,  749,  117, 2157, 7270, 2897, 1355, 4873, 6375, 5439, 2360,
          6608,  117, 1920, 2157, 2582,  720, 4692, 2521, 6821, 4905,  752,
           136,  102],
         [ 101, 1555, 6617, 4384, 4413, 5500,  819, 3300, 7361, 1062, 1385,
          1068,  754, 2454, 3309, 1726, 1908,  677, 3862, 6395, 1171,  769,
          3211, 2792, 2190, 1062, 1385, 8109, 2399, 2399, 2428, 2845, 1440,
          4638,  752, 1400, 2144, 3417, 7309, 6418, 1141, 4638, 1062, 1440,
           102,    0],
         [ 101, 6858, 6814,  704,  792, 1062, 1385,  743,  749,  753, 2797,
          2791,  117, 7674,  802, 6963,  802,  749,  117, 4385, 1762, 1297,
          2157,  679, 2682, 1297,  749,  511, 2582,  720, 1905, 4415,  136,
           102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0],
         [ 101, 827

In [11]:
ocnli_train_generator.take()

([array([[ 101,  671, 3299, ...,    0,    0,    0],
         [ 101,  671, 3299, ...,    0,    0,    0],
         [ 101,  671, 3299, ...,    0,    0,    0],
         ...,
         [ 101,  752, 2141, ...,    0,    0,    0],
         [ 101,  800,  809, ...,    0,    0,    0],
         [ 101,  800,  809, ...,    0,    0,    0]]),
  array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]]),
  array([[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]]),
  array([[1],
         [2],
         [0],
         [1],
         [1],
         [2],
         [0],
         [1],
         [2],
         [0],
         [2],
         [0],
         [1],
         [2],
         [1],
         [2]]),


In [12]:
class data_generator(DataGenerator):
    def __iter__(self, shuffle=False):
        for is_end, item in self.get_sample(shuffle):
            yield item

train_batch_data = list(tnews_train_generator.__iter__(shuffle=True)) + list(ocnli_train_generator.__iter__(shuffle=True))
train_batch_data += list(ocemotion_train_generator.__iter__(shuffle=True))
train_generator = data_generator(data=train_batch_data, batch_size=1)

In [13]:
class SwitchLoss(Loss):
    """计算三种cls 的loss，然后通过 loss mask 过滤掉非当前任务的loss
    这里也可以利用loss mask对不同task 的loss 加权
    """
    def compute_loss(self, inputs, mask=None):
        tnew_pred, ocnli_pred, ocemotion_pred, y_true, type_input = inputs
        
        train_loss = tf.case([(tf.equal(tf.argmax(type_input[0]), 0), lambda: K.sparse_categorical_crossentropy(y_true, tnews_cls)),
                       (tf.equal(tf.argmax(type_input[0]), 1), lambda: K.sparse_categorical_crossentropy(y_true,ocnli_cls)),
                        (tf.equal(tf.argmax(type_input[0]), 2), lambda: K.sparse_categorical_crossentropy(y_true, ocemotion_cls))
                       ], exclusive=True)
        return K.mean(train_loss)
        

In [14]:
type_in = Input(shape=(1,))
type_emb = Embedding(3, 128)(type_in)
type_emb = Reshape((128,))(type_emb)

bert = build_transformer_model(checkpoint_path=checkpoint_path, config_path=config_path, model='nezha', additional_input_layers=type_in, layer_norm_cond=type_emb)
output = Lambda(lambda x: x[:,0])(bert.output)
output = Dropout(0.1)(output)

tnews_cls = Dense(units=len(tnews_labels), activation='softmax')(output)
ocnli_cls = Dense(units=len(ocnli_labels), activation='softmax')(output)
ocemotion_cls = Dense(units=len(ocemotion_labels), activation='softmax')(output)

y_input = Input(shape=(None, ))
type_input = Input(shape=(None,))

train_output = SwitchLoss(0)([tnews_cls, ocnli_cls, ocemotion_cls, y_input, type_input])

train_model = Model(bert.inputs + [y_input, type_input], train_output)

tnews_model = Model(bert.inputs, tnews_cls)
ocnli_model = Model(bert.inputs, ocnli_cls)
ocemotion_model = Model(bert.inputs, ocemotion_cls)

In [15]:
train_model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input-Token (InputLayer)        (None, None)         0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, None)         0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
Embedding-Token (Embedding)     (None, None, 768)    16226304    Input-Token[0][0]                
____________________________________________________________________________________________

In [16]:
tnews_model.inputs

[<tf.Tensor 'Input-Token:0' shape=(None, None) dtype=float32>,
 <tf.Tensor 'Input-Segment:0' shape=(None, None) dtype=float32>,
 <tf.Tensor 'input_1:0' shape=(None, 1) dtype=float32>]

In [17]:
grad_accum_steps = 3
Opt = extend_with_weight_decay(Adam)
Opt = extend_with_gradient_accumulation(Opt)
exclude_from_weight_decay = ['Norm', 'bias']
Opt = extend_with_piecewise_linear_lr(Opt)
para = {
    'learning_rate': 2e-5,
    'weight_decay_rate': 0.01,
    'exclude_from_weight_decay': exclude_from_weight_decay,
    'grad_accum_steps': grad_accum_steps,
    'lr_schedule': {int(len(train_generator) * 0.1 * epochs / grad_accum_steps): 1, int(len(train_generator) * epochs / grad_accum_steps): 0},
}

opt = Opt(**para)

train_model.compile(opt)



In [18]:
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, classification_report, f1_score


def get_f1(l_t, l_p):
    marco_f1_score = f1_score(l_t, l_p, average='macro')
    return marco_f1_score

def print_result(l_t, l_p):
    marco_f1_score = f1_score(l_t, l_p, average='macro')
    print(marco_f1_score)
    print(f"{'confusion_matrix':*^80}")
    print(confusion_matrix(l_t, l_p, ))
    print(f"{'classification_report':*^80}")
    print(classification_report(l_t, l_p, ))

In [19]:

d, _ = tnews_valid_generator.take()
d[2]

array([[0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0]])

In [20]:
def get_predict(model, data):
    preds, trues = [],[]
    for (t, s,task, y, _),_ in tqdm(data):
        pred = model.predict([t,s, task]).argmax(-1)
        preds.extend(pred.tolist())
        trues.extend(y.tolist())
    return trues, preds

    
def evaluate():
    tnews_trues, tnews_preds = get_predict(tnews_model, tnews_valid_generator)
    ocnli_trues, ocnli_preds = get_predict(ocnli_model, ocnli_valid_generator)
    ocemotion_trues, ocemotion_preds = get_predict(ocemotion_model, ocemotion_valid_generator)
    
    tnews_f1 = get_f1(tnews_trues, tnews_preds)
    ocnli_f1 = get_f1(ocnli_trues, ocnli_preds)
    ocemotion_f1 = get_f1(ocemotion_trues, ocemotion_preds)
    
    print_result(tnews_trues, tnews_preds)
    print_result(ocnli_trues, ocnli_preds)
    print_result(ocemotion_trues, ocemotion_preds)

    score = (tnews_f1 + ocnli_f1 + ocemotion_f1) / 3
    return score

In [21]:
class Evaluator(keras.callbacks.Callback):
    def __init__(self, save_path):
        self.save_path = save_path
        self.best_f1 = 0.
    
    def on_epoch_end(self, epoch, logs=None):
        avg_f1 = evaluate()
        if self.best_f1 < avg_f1:
            self.best_f1 = avg_f1
            self.model.save_weights(self.save_path)
        
        print('epoch: {} f1 is:{},  best f1 is:{}'.format(epoch +1, avg_f1, self.best_f1))

In [22]:
evaluate()

100%|██████████| 792/792 [00:17<00:00, 45.62it/s]
100%|██████████| 668/668 [00:19<00:00, 33.53it/s]
100%|██████████| 447/447 [00:29<00:00, 15.23it/s]
  _warn_prf(average, modifier, msg_start, len(result))


0.007639658843312296
********************************confusion_matrix********************************
[[   0    0    0    0  120    0    0   15    3    2    0    0    0    0
   701]
 [   0    0    2    0   96    0    0   29    6    2    3    1    0    0
   806]
 [   0    0    0    0   26    0    0    2    0    0    1    0    0    0
   239]
 [   0    0    1    0  340    0    0   39    2   10    2    0    0    0
   772]
 [   0    0    0    0  111    0    0    3    3    7    4    0    0    0
   708]
 [   0    0    0    0   90    0    1   11    0    4    0    0    0    0
   384]
 [   0    0    0    0  143    0    0   13    2    1    0    0    0    0
   658]
 [   0    0    2    0  129    0    0   12    5    6    1    0    0    0
   819]
 [   0    0    0    0  145    0    0   21    2   10    0    0    0    0
   807]
 [   0    0    2    0  184    0    0   21   10    1    1    0    0    0
   952]
 [   0    0    1    0   59    0    0    8    1    3    3    0    0    0
   547]
 [   0    0    0  

0.08836648988495747

In [23]:
model_save_path = 'best_adapt_model.weights'
evaluator = Evaluator(model_save_path)

train_model.fit_generator(train_generator.generator(),
                         steps_per_epoch=len(train_generator),
                         epochs=epochs,
                          callbacks=[evaluator]
                         )



Epoch 1/5


100%|██████████| 792/792 [00:16<00:00, 48.28it/s]
100%|██████████| 668/668 [00:20<00:00, 33.05it/s]
100%|██████████| 447/447 [00:30<00:00, 14.73it/s]


0.6247624202901726
********************************confusion_matrix********************************
[[614  21   2   9   6   1  13  16   6 107   2   2  11  31   0]
 [  0 575  18  29 107   3  76   8   1  73  17   9  17  12   0]
 [  1  22 181  10   2   4   8   2   1  18  11   1   4   3   0]
 [ 10  16  10 810  26   8  27  11  12  67  12  51  77  29   0]
 [ 15  41   6 104 482  27  21  10  27  32  39   8   8  16   0]
 [  0   7   2   5  23 373   9   2   3   6  17  23   1  19   0]
 [ 11  34  11  12  13  13 609   7   2  38  14   9  11  33   0]
 [ 50  12   3  13   9   4  16 717  10 113   5   7   7   8   0]
 [  6  11   5  19  37  11  13  11 679  38  14  25  14 102   0]
 [ 27  66  42  32  15   5  34  22   4 886   4   8   6  20   0]
 [  1   9  19  14  50   9  13   5   4   8 433  33   3  21   0]
 [  5  10   1  39  21  52  23   4  10  21  58 654   4 298   0]
 [ 43  20   4 237   7   1  18   9   6  40   3  11 489  19   0]
 [ 77   9   2  22  10  18  35   6  35  40  13 200  25 902   0]
 [  0   0   0   0 

100%|██████████| 792/792 [00:16<00:00, 48.07it/s]
100%|██████████| 668/668 [00:20<00:00, 32.98it/s]
100%|██████████| 447/447 [00:30<00:00, 14.71it/s]


0.773027970966573
********************************confusion_matrix********************************
[[ 657   16    0   14    2    1    7   14    3   65    0    1   30   31
     0]
 [   0  826    4   29   10    0   33    3    0   17    9    8    2    3
     1]
 [   0   27  187   16    0    5    5    1    3   17    3    1    2    1
     0]
 [   1    4    2  980   16    3    7    4    2   19    3   14   97   14
     0]
 [   2  110    2   91  516   10    8    8   22   18   29    7    6    7
     0]
 [   0   11    0    4   18  408    3    0    7    5    2   23    2    7
     0]
 [   2   41    5   18    9    6  649    8    6   23    4   15   14   17
     0]
 [   9    8    0   17    5    1    3  864   11   34    4    7    7    4
     0]
 [   1   16    3   35   32    5    3    5  784   16    2   22   11   50
     0]
 [  18   68   17   72    8    4   16   26    6  897    1    4   12   22
     0]
 [   0   22    9   21   17    9    5    0    5    6  490   29    4    5
     0]
 [   1    7    0   66

100%|██████████| 792/792 [00:16<00:00, 47.65it/s]
100%|██████████| 668/668 [00:20<00:00, 32.82it/s]
100%|██████████| 447/447 [00:30<00:00, 14.64it/s]


0.866887784449254
********************************confusion_matrix********************************
[[ 725    3    0    2    1    1    6    8    1   30    3    5    9   47
     0]
 [   0  834    1    6   16    1   13    9    2   34    7   17    2    2
     1]
 [   1    8  224    1    0    0    4    0    0   23    5    2    0    0
     0]
 [   4    0    2  959    7    1    1    3    3   17    8   39  108   14
     0]
 [   2   25    1   54  663   17    1    2   16   14   10   21    2    8
     0]
 [   0    3    0    2    2  429    0    0    2    4    4   42    0    2
     0]
 [   2    8    4    8    1    4  729    7    1   13    5   16    8   11
     0]
 [   5    2    0    2    3    2    5  917    5   15    0   10    3    5
     0]
 [   1    7    1    9   11    6    1    5  888    9    0   25    3   19
     0]
 [   6   15    7    8    3    1    5   18    1 1079    0    7    1   20
     0]
 [   1    1    1    4   11    1    2    0    1    0  565   34    0    1
     0]
 [   0    1    0    6

100%|██████████| 792/792 [00:16<00:00, 47.61it/s]
100%|██████████| 668/668 [00:20<00:00, 32.69it/s]
100%|██████████| 447/447 [00:30<00:00, 14.67it/s]


0.9322898989632218
********************************confusion_matrix********************************
[[ 765    4    1    2    4    0    3   11    2   13    2    4   14   16
     0]
 [   0  915    3    1    6    0    7    1    1    6    1    4    0    0
     0]
 [   0    3  261    1    0    0    0    0    0    3    0    0    0    0
     0]
 [   0    3    1 1085    3    1    2    0    5    7    0    1   56    2
     0]
 [   0   20    1   12  774    2    2    0    4    0   10    6    2    3
     0]
 [   0    2    2    1    7  461    0    0    1    0    0   12    1    3
     0]
 [   1    3    3    0    0    2  792    2    2    3    4    2    1    2
     0]
 [   1    4    1    1    4    2    1  931    4   12    2    2    5    3
     1]
 [   1    5    2    2   19    1    1    2  919    2    1    4    5   21
     0]
 [   4   18   15    3    2    1    4    5    0 1107    0    2    4    6
     0]
 [   0    4    1    2    2    3    0    0    1    0  605    4    0    0
     0]
 [   1    1    1   1

100%|██████████| 792/792 [00:16<00:00, 47.97it/s]
100%|██████████| 668/668 [00:20<00:00, 32.99it/s]
100%|██████████| 447/447 [00:30<00:00, 14.68it/s]


0.9535715351183268
********************************confusion_matrix********************************
[[ 791    2    0    2    2    0    3   13    4    5    2    1    3   13
     0]
 [   1  918    0    0    7    0    2    0    0    3    8    2    2    1
     1]
 [   0    1  267    0    0    0    0    0    0    0    0    0    0    0
     0]
 [   1    1    1 1131    3    1    2    1    2    2    1    3   14    3
     0]
 [   0    5    1    8  798    3    1    0    3    1   12    1    0    3
     0]
 [   0    0    1    0    1  477    0    1    1    0    1    5    0    3
     0]
 [   3    1    1    0    1    1  799    0    0    1    6    2    0    2
     0]
 [   1    5    1    2    4    0    0  950    0    6    1    2    1    0
     1]
 [   0    5    0    2    3    3    0    4  953    2    0    2    1   10
     0]
 [   3   17    4   10    1    1    2    1    3 1120    0    1    2    6
     0]
 [   0    1    1    0    3    1    0    0    0    0  615    1    0    0
     0]
 [   0    0    0    

<keras.callbacks.callbacks.History at 0x7f9275fd2d90>

In [27]:
train_model.load_weights(model_save_path)
evaluate()

100%|██████████| 792/792 [00:15<00:00, 50.88it/s]
100%|██████████| 668/668 [00:17<00:00, 37.18it/s]
100%|██████████| 447/447 [00:27<00:00, 16.51it/s]


0.9322898989632218
********************************confusion_matrix********************************
[[ 765    4    1    2    4    0    3   11    2   13    2    4   14   16
     0]
 [   0  915    3    1    6    0    7    1    1    6    1    4    0    0
     0]
 [   0    3  261    1    0    0    0    0    0    3    0    0    0    0
     0]
 [   0    3    1 1085    3    1    2    0    5    7    0    1   56    2
     0]
 [   0   20    1   12  774    2    2    0    4    0   10    6    2    3
     0]
 [   0    2    2    1    7  461    0    0    1    0    0   12    1    3
     0]
 [   1    3    3    0    0    2  792    2    2    3    4    2    1    2
     0]
 [   1    4    1    1    4    2    1  931    4   12    2    2    5    3
     1]
 [   1    5    2    2   19    1    1    2  919    2    1    4    5   21
     0]
 [   4   18   15    3    2    1    4    5    0 1107    0    2    4    6
     0]
 [   0    4    1    2    2    3    0    0    1    0  605    4    0    0
     0]
 [   1    1    1   1

0.721807193339421

In [28]:
def predict_to_file(result_path):
    _, tnews_preds = get_predict(tnews_model, tnews_test_generator)
    _, ocnli_preds = get_predict(ocnli_model, ocnli_test_generator)
    _, ocemotion_preds = get_predict(ocemotion_model, ocemotion_test_generator)
    
    tnews_result, ocnli_result, ocemotion_result = [], [], []
    
    for (d, p) in zip(tnews_test, tnews_preds):
        tnews_result.append({'id': d[0], 'label': tnews_id2label[p]})
    
    for (d, p) in zip(ocnli_test, ocnli_preds):
        ocnli_result.append({'id': d[0], 'label': ocnli_id2label[p]})
    
    for (d, p) in zip(ocemotion_test, ocemotion_preds):
        ocemotion_result.append({'id': d[0], 'label': ocemotion_id2label[p]})
        
    with open(os.path.join(result_path, 'tnews_predict.json'), 'w') as f:
        for d in tnews_result:
            f.write(json.dumps(d) + '\n')
    
    with open(os.path.join(result_path, 'ocnli_predict.json'), 'w') as f:
        for d in ocnli_result:
            f.write(json.dumps(d) + '\n')
    
    with open(os.path.join(result_path, 'ocemotion_predict.json'), 'w') as f:
        for d in ocemotion_result:
            f.write(json.dumps(d) + '\n')

In [29]:
predict_to_file('./result')

100%|██████████| 94/94 [00:01<00:00, 48.42it/s]
100%|██████████| 94/94 [00:02<00:00, 33.94it/s]
100%|██████████| 94/94 [00:05<00:00, 17.20it/s]
