In [43]:
import os
import json
from tqdm import tqdm


def load_data(filename):
    D = []
    with open(filename) as f:
        for i, l in enumerate(f):
            label = None
            items = l.strip().split('\t')
            if len(items) == 3:
                idx, text, label = items
#                 label = int(label)
            else:
                idx, text = items
            D.append((idx, text, label))
    return D

def load_ocnli_data(filename):
    D = []
    with  open(filename) as f:
        for l in f:
            label = None
            items = l.strip().split('\t')
            if len(items) == 4:
                idx, s1, s2, label = items
            else:
                idx, s1, s2 = items
            D.append((idx, s1, s2, label))
    return D


tnews_train = load_data('/home/mingming.xu/datasets/NLP/ptms_data/TNEWS_train1128.csv')
tnews_test = load_data('/home/mingming.xu/datasets/NLP/ptms_data/TNEWS_a.csv')

ocemotion_train = load_data('/home/mingming.xu/datasets/NLP/ptms_data/OCEMOTION_train1128.csv')
ocemotion_test = load_data('/home/mingming.xu/datasets/NLP/ptms_data/OCEMOTION_a.csv')

ocnli_train = load_ocnli_data('/home/mingming.xu/datasets/NLP/ptms_data/OCNLI_train1128.csv')
ocnli_test = load_ocnli_data('/home/mingming.xu/datasets/NLP/ptms_data/OCNLI_a.csv')

len(tnews_train), len(tnews_test), len(ocemotion_train), len(ocemotion_test), len(ocnli_train), len(ocnli_test)

(63360, 1500, 35694, 1500, 53387, 1500)

In [44]:
tnews_train[0], ocemotion_train[0], ocnli_train[0]

(('0', '上课时学生手机响个不停,老师一怒之下把手机摔了,家长拿发票让老师赔,大家怎么看待这种事?', '108'),
 ('0',
  "'你知道多伦多附近有什么吗?哈哈有破布耶...真的书上写的你听哦...你家那块破布是世界上最大的破布,哈哈,骗你的啦它是说尼加拉瓜瀑布是世界上最大的瀑布啦...哈哈哈''爸爸,她的头发耶!我们大扫除椅子都要翻上来我看到木头缝里有头发...一定是xx以前夹到的,你说是不是?[生病]",
  'sadness'),
 ('0', '一月份跟二月份肯定有一个月份有.', '肯定有一个月份有', '0'))

In [45]:
def precess_label(train_data):
    labels = set([d[-1] for d in train_data])
    label2id = {k:v for v, k in enumerate(labels)}
    id2label = {v:k for k, v in label2id.items()}
    return labels, label2id, id2label

tnews_labels, tnews_label2id, tnews_id2label = precess_label(tnews_train)
ocnli_labels, ocnli_label2id, ocnli_id2label = precess_label(ocnli_train)
ocemotion_labels, ocemotion_label2id, ocemotion_id2label = precess_label(ocemotion_train)

In [46]:
tnews_labels, ocnli_labels, ocemotion_labels

({'100',
  '101',
  '102',
  '103',
  '104',
  '106',
  '107',
  '108',
  '109',
  '110',
  '112',
  '113',
  '114',
  '115',
  '116'},
 {'0', '1', '2'},
 {'anger', 'disgust', 'fear', 'happiness', 'like', 'sadness', 'surprise'})

In [47]:
tnews_train = [d[:-1] + (tnews_label2id[d[-1]],) for d in tnews_train]
ocnli_train = [d[:-1] + (ocnli_label2id[d[-1]],) for d in ocnli_train]
ocemotion_train = [d[:-1] + (ocemotion_label2id[d[-1]],) for d in ocemotion_train]

tnews_train[0], ocnli_train[0], ocemotion_train[0]

(('0', '上课时学生手机响个不停,老师一怒之下把手机摔了,家长拿发票让老师赔,大家怎么看待这种事?', 10),
 ('0', '一月份跟二月份肯定有一个月份有.', '肯定有一个月份有', 2),
 ('0',
  "'你知道多伦多附近有什么吗?哈哈有破布耶...真的书上写的你听哦...你家那块破布是世界上最大的破布,哈哈,骗你的啦它是说尼加拉瓜瀑布是世界上最大的瀑布啦...哈哈哈''爸爸,她的头发耶!我们大扫除椅子都要翻上来我看到木头缝里有头发...一定是xx以前夹到的,你说是不是?[生病]",
  6))

In [48]:
from toolkit4nlp.models import *
from toolkit4nlp.layers import *
from toolkit4nlp.utils import *
from toolkit4nlp.optimizers import *
from toolkit4nlp.tokenizers import *
from toolkit4nlp.backend import *

import tensorflow as tf

In [49]:
#bert config

config_path = '/home/mingming.xu/pretrain/NLP/nezha_base_wwm/bert_config.json'
checkpoint_path = '/home/mingming.xu/pretrain/NLP/nezha_base_wwm/model.ckpt'
dict_path = '/home/mingming.xu/pretrain/NLP/nezha_base_wwm/vocab.txt'

# tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

batch_size=16
maxlen=256
epochs = 5

In [50]:
class batch_data_generator(DataGenerator):
    def __init__(self, label_mask, **kwargs):
        super(batch_data_generator, self).__init__(**kwargs)
        self.label_mask = label_mask
        
    def __iter__(self, shuffle=False):
        batch_token_ids, batch_segment_ids, batch_labels, batch_label_mask = [], [], [], []
        for is_end, item in self.get_sample(shuffle):
            if len(item) == 4:
                _, q, r, l = item
                token_ids, segment_ids = tokenizer.encode(q,r, maxlen=maxlen)
            else:
                _, q, l = item
                token_ids, segment_ids = tokenizer.encode(q, maxlen=maxlen)
            
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append([l])
            batch_label_mask.append(self.label_mask)
            
            if is_end or self.batch_size == len(batch_token_ids):
                batch_token_ids = pad_sequences(batch_token_ids)
                batch_segment_ids = pad_sequences(batch_segment_ids)
                batch_label_mask = pad_sequences(batch_label_mask)
                batch_labels = pad_sequences(batch_labels)
                
                yield [batch_token_ids, batch_segment_ids, batch_labels, batch_label_mask], None
                batch_token_ids, batch_segment_ids, batch_labels, batch_label_mask = [], [], [], []
    



In [51]:
split = 0.8
tnews_mask = [1,0,0]
ocnli_mask = [0,1,0]
ocemotion_mask = [0,0,1]

def split_train_valid(data, split):
    n = int(len(data)*split)
    train_data = data[:n]
    valid_data = data[n:]
    return train_data, valid_data


tnews_train_data, tnews_valid_data = split_train_valid(tnews_train, split)
ocnli_train_data, ocnli_valid_data = split_train_valid(ocnli_train, split)
ocemotion_train_data, ocemotion_valid_data = split_train_valid(ocemotion_train, split)

tnews_train_generator = batch_data_generator(data=tnews_train, batch_size=batch_size, label_mask=tnews_mask)
tnews_valid_generator = batch_data_generator(data=tnews_valid_data, batch_size=batch_size, label_mask=tnews_mask)
tnews_test_generator = batch_data_generator(data=tnews_test, batch_size=batch_size, label_mask=tnews_mask)

ocnli_train_generator = batch_data_generator(data=ocnli_train_data, batch_size=batch_size, label_mask=ocnli_mask)
ocnli_valid_generator = batch_data_generator(data=ocnli_valid_data, batch_size=batch_size, label_mask=ocnli_mask)
ocnli_test_generator = batch_data_generator(data=ocnli_test, batch_size=batch_size, label_mask=ocnli_mask)


ocemotion_train_generator = batch_data_generator(data=ocemotion_train_data, batch_size=batch_size, label_mask=ocemotion_mask)
ocemotion_valid_generator = batch_data_generator(data=ocemotion_valid_data, batch_size=batch_size, label_mask=ocemotion_mask)
ocemotion_test_generator = batch_data_generator(data=ocemotion_test, batch_size=batch_size, label_mask=ocemotion_mask)

tnews_train_generator.take()

([array([[ 101,  677, 6440, 3198, 2110, 4495, 2797, 3322, 1510,  702,  679,
           977,  117, 5439, 2360,  671, 2584,  722,  678, 2828, 2797, 3322,
          3035,  749,  117, 2157, 7270, 2897, 1355, 4873, 6375, 5439, 2360,
          6608,  117, 1920, 2157, 2582,  720, 4692, 2521, 6821, 4905,  752,
           136,  102],
         [ 101, 1555, 6617, 4384, 4413, 5500,  819, 3300, 7361, 1062, 1385,
          1068,  754, 2454, 3309, 1726, 1908,  677, 3862, 6395, 1171,  769,
          3211, 2792, 2190, 1062, 1385, 8109, 2399, 2399, 2428, 2845, 1440,
          4638,  752, 1400, 2144, 3417, 7309, 6418, 1141, 4638, 1062, 1440,
           102,    0],
         [ 101, 6858, 6814,  704,  792, 1062, 1385,  743,  749,  753, 2797,
          2791,  117, 7674,  802, 6963,  802,  749,  117, 4385, 1762, 1297,
          2157,  679, 2682, 1297,  749,  511, 2582,  720, 1905, 4415,  136,
           102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0],
         [ 101, 827

In [52]:
class SwitchLoss(Loss):
    """计算三种cls 的loss，然后通过 loss mask 过滤掉非当前任务的loss
    这里也可以利用loss mask对不同task 的loss 加权
    """
    def compute_loss(self, inputs, mask=None):
        tnew_pred, ocnli_pred, ocemotion_pred, y_true, type_input = inputs
        
        train_loss = tf.case([(tf.equal(tf.argmax(type_input[0]), 0), lambda: K.sparse_categorical_crossentropy(y_true, tnews_cls)),
                       (tf.equal(tf.argmax(type_input[0]), 1), lambda: K.sparse_categorical_crossentropy(y_true,ocnli_cls)),
                        (tf.equal(tf.argmax(type_input[0]), 2), lambda: K.sparse_categorical_crossentropy(y_true, ocemotion_cls))
                       ], exclusive=True)
        return K.mean(train_loss)
        

In [53]:
bert = build_transformer_model(checkpoint_path=checkpoint_path, config_path=config_path, model='nezha', with_pool=True)
output = Dropout(0.1)(bert.output)

tnews_cls = Dense(units=len(tnews_labels), activation='softmax')(output)
ocnli_cls = Dense(units=len(ocnli_labels), activation='softmax')(output)
ocemotion_cls = Dense(units=len(ocemotion_labels), activation='softmax')(output)

y_input = Input(shape=(None, ))
type_input = Input(shape=(None,))

train_output = SwitchLoss(0)([tnews_cls, ocnli_cls, ocemotion_cls, y_input, type_input])

train_model = Model(bert.inputs + [y_input, type_input], train_output)

tnews_model = Model(bert.inputs, tnews_cls)
ocnli_model = Model(bert.inputs, ocnli_cls)
ocemotion_model = Model(bert.inputs, ocemotion_cls)

In [54]:
train_model.summary()

Model: "model_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input-Token (InputLayer)        (None, None)         0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, None)         0                                            
__________________________________________________________________________________________________
Embedding-Token (Embedding)     (None, None, 768)    16226304    Input-Token[0][0]                
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, None, 768)    1536        Input-Segment[0][0]              
____________________________________________________________________________________________

In [82]:
grad_accum_steps = 3
Opt = extend_with_weight_decay(Adam)
Opt = extend_with_gradient_accumulation(Opt)
exclude_from_weight_decay = ['Norm', 'bias']
Opt = extend_with_piecewise_linear_lr(Opt)
para = {
    'learning_rate': 2e-5,
    'weight_decay_rate': 0.01,
    'exclude_from_weight_decay': exclude_from_weight_decay,
    'grad_accum_steps': grad_accum_steps,
    'lr_schedule': {int(len(train_generator) * 0.1 * epochs / grad_accum_steps): 1, int(len(train_generator) * epochs / grad_accum_steps): 0},
}

opt = Opt(**para)

train_model.compile(opt)



In [72]:
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, classification_report, f1_score


def get_f1(l_t, l_p):
    marco_f1_score = f1_score(l_t, l_p, average='macro')
    return marco_f1_score

def print_result(l_t, l_p):
    marco_f1_score = f1_score(l_t, l_p, average='macro')
    print(marco_f1_score)
    print(f"{'confusion_matrix':*^80}")
    print(confusion_matrix(l_t, l_p, ))
    print(f"{'classification_report':*^80}")
    print(classification_report(l_t, l_p, ))

In [73]:
def get_predict(model, data):
    preds, trues = [],[]
    for (t, s, y, _),_ in tqdm(data):
        pred = model.predict([t,s]).argmax(-1)
        preds.extend(pred.tolist())
        trues.extend(y.tolist())
    return trues, preds

    
def evaluate():
    tnews_trues, tnews_preds = get_predict(tnews_model, tnews_valid_generator)
    ocnli_trues, ocnli_preds = get_predict(ocnli_model, ocnli_valid_generator)
    ocemotion_trues, ocemotion_preds = get_predict(ocemotion_model, ocemotion_valid_generator)
    
    tnews_f1 = get_f1(tnews_trues, tnews_preds)
    ocnli_f1 = get_f1(ocnli_trues, ocnli_preds)
    ocemotion_f1 = get_f1(ocemotion_trues, ocemotion_preds)
    
    print_result(tnews_trues, tnews_preds)
    print_result(ocnli_trues, ocnli_preds)
    print_result(ocemotion_trues, ocemotion_preds)

    score = (tnews_f1 + ocnli_f1 + ocemotion_f1) / 3
    return score

In [74]:
class Evaluator(keras.callbacks.Callback):
    def __init__(self, save_path):
        self.save_path = save_path
        self.best_f1 = 0.
    
    def on_epoch_end(self, epoch, logs=None):
        avg_f1 = evaluate()
        if self.best_f1 < avg_f1:
            self.best_f1 = avg_f1
            self.model.save_weights(self.save_path)
        
        print('epoch: {} f1 is:{},  best f1 is:{}'.format(epoch +1, avg_f1, self.best_f1))

In [59]:
evaluate()

100%|██████████| 792/792 [00:15<00:00, 50.83it/s]
100%|██████████| 668/668 [00:18<00:00, 35.33it/s]
100%|██████████| 447/447 [00:27<00:00, 15.99it/s]
  _warn_prf(average, modifier, msg_start, len(result))


0.0027631714609753586
********************************confusion_matrix********************************
[[   0    0    0    1    0    0    0    0    0  840    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0   56    0    0    0    0
     0]
 [   0    0    0    1    0    0    0    0    0 1393    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0  490    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0  974    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0 1171    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0  907    0    0    0    0
     0]
 [   0    0    0    5    0    0    0    0    0 1195    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0  985    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0  268    0    0    0    0
     0]
 [   0    0    0    0    0    0    0    0    0  817    0    0    0    0
     0]
 [   0    0    0 

0.06390782321885995

In [75]:
class data_generator(DataGenerator):
    def __iter__(self, shuffle=False):
        for is_end, item in self.get_sample(shuffle):
            yield item

train_batch_data = list(tnews_train_generator.__iter__(shuffle=True)) + list(ocnli_train_generator.__iter__(shuffle=True))
train_batch_data += list(ocemotion_train_generator.__iter__(shuffle=True))
train_generator = data_generator(data=train_batch_data, batch_size=1)

In [83]:
model_save_path = 'best_model.weights'
evaluator = Evaluator(save_path)

train_model.fit_generator(train_generator.generator(),
                         steps_per_epoch=len(train_generator),
                         epochs=epochs,
                          callbacks=[evaluator]
                         )



Epoch 1/5


100%|██████████| 792/792 [00:15<00:00, 50.10it/s]
100%|██████████| 668/668 [00:19<00:00, 34.67it/s]
100%|██████████| 447/447 [00:29<00:00, 15.10it/s]


0.9692006472762918
********************************confusion_matrix********************************
[[ 812    0   11    0    2    4    4    1    1    0    2    1    1    2
     0]
 [   0   56    0    0    0    0    0    0    0    0    0    0    0    0
     0]
 [   1    0 1349    0    0    4    0   29    1    0    0    1    5    2
     2]
 [   0    0    1  471    1    1    0   10    2    0    0    2    0    1
     1]
 [   3    0    0    2  952    6    1    2    0    0    0    0    3    4
     1]
 [   2    0    1    1    3 1130    4    1    0    0    1    0    3   25
     0]
 [   1    0    3    0    2    2  851    0    0    0    1    0   38    8
     1]
 [   0    4   29    0    0    0    1 1154    0    0    1    8    3    0
     0]
 [   0    0   30    2    1    1    2    2  938    1    1    0    2    4
     1]
 [   0    0    0    0    0    4    0    0    0  263    0    0    0    1
     0]
 [   1    0    1    1    1    1    0    1    0    1  800    3    0    6
     1]
 [   0    0    2    

100%|██████████| 792/792 [00:15<00:00, 50.81it/s]
100%|██████████| 668/668 [00:19<00:00, 34.44it/s]
100%|██████████| 447/447 [00:29<00:00, 15.01it/s]


0.9750340960011027
********************************confusion_matrix********************************
[[ 819    0    6    0    5    2    4    1    0    0    2    0    1    1
     0]
 [   0   54    0    0    0    0    0    2    0    0    0    0    0    0
     0]
 [   6    0 1357    1    1    2    0    9    9    0    1    1    7    0
     0]
 [   0    0    2  477    0    1    0    4    2    0    0    2    0    0
     2]
 [   0    1    0    0  954    7    2    2    2    0    0    0    3    3
     0]
 [   1    0    3    0    1 1152    2    0    1    1    0    1    5    4
     0]
 [   0    0    1    0    1    1  864    1    1    0    0    0   35    2
     1]
 [   0    1   17    4    1    1    0 1157    0    0    0    8   11    0
     0]
 [   0    0    6    1    1    1    0    3  966    0    0    0    0    4
     3]
 [   0    0    0    0    0    0    0    0    0  266    0    0    1    1
     0]
 [   2    0    4    1    1    0    0    1    0    2  797    4    1    3
     1]
 [   0    0    0    

100%|██████████| 792/792 [00:15<00:00, 50.07it/s]
100%|██████████| 668/668 [00:19<00:00, 34.59it/s]
100%|██████████| 447/447 [00:29<00:00, 15.02it/s]


0.9787213268832121
********************************confusion_matrix********************************
[[ 830    0    4    0    2    2    1    1    0    0    0    0    1    0
     0]
 [   0   56    0    0    0    0    0    0    0    0    0    0    0    0
     0]
 [   5    0 1362    1    1    4    1   13    1    0    3    1    0    1
     1]
 [   0    1    2  479    0    1    0    2    2    0    0    2    0    0
     1]
 [   5    1    1    0  960    1    0    0    0    0    0    0    1    3
     2]
 [   1    0    0    1    6 1154    2    0    0    2    0    1    0    4
     0]
 [   2    0    2    0    1    1  852    2    0    0    1    0   39    4
     3]
 [   0    5    7    2    1    0    0 1173    1    0    1    8    1    0
     1]
 [   1    0    4    1    0    0    1    0  970    0    0    0    0    2
     6]
 [   0    0    0    0    0    0    0    0    0  266    0    2    0    0
     0]
 [   0    0    1    0    1    0    0    0    0    0  812    2    0    1
     0]
 [   0    0    0    

100%|██████████| 792/792 [00:15<00:00, 49.95it/s]
100%|██████████| 668/668 [00:19<00:00, 34.57it/s]
100%|██████████| 447/447 [00:29<00:00, 15.04it/s]


0.983222848493567
********************************confusion_matrix********************************
[[ 822    0    7    0    5    2    1    2    0    0    1    0    0    1
     0]
 [   0   56    0    0    0    0    0    0    0    0    0    0    0    0
     0]
 [   2    0 1368    1    2    5    2   11    1    0    0    0    0    2
     0]
 [   0    0    0  476    0    0    0    4    4    0    0    1    0    0
     5]
 [   0    1    0    0  965    4    0    0    1    0    0    0    0    3
     0]
 [   3    0    1    1    0 1153    2    0    0    1    0    0    2    8
     0]
 [   3    0    0    0    0    1  881    1    0    0    0    0   18    2
     1]
 [   0    1    7    0    2    0    1 1185    0    0    1    3    0    0
     0]
 [   0    0    6    1    0    0    0    0  972    0    0    0    0    3
     3]
 [   0    0    0    0    0    0    0    0    0  267    0    1    0    0
     0]
 [   2    0    3    0    1    0    1    0    2    0  803    3    0    1
     1]
 [   0    0    0    0

100%|██████████| 792/792 [00:15<00:00, 49.99it/s]
100%|██████████| 668/668 [00:19<00:00, 34.53it/s]
100%|██████████| 447/447 [00:29<00:00, 15.02it/s]


0.9807532205109957
********************************confusion_matrix********************************
[[ 826    0    6    0    3    3    1    1    0    0    1    0    0    0
     0]
 [   0   56    0    0    0    0    0    0    0    0    0    0    0    0
     0]
 [   4    0 1371    0    2    2    0    7    3    0    1    1    1    1
     1]
 [   0    0    1  485    0    0    0    2    1    0    0    0    0    0
     1]
 [   1    1    1    0  966    1    0    0    0    0    0    0    0    3
     1]
 [   1    0    1    0    6 1155    2    0    0    0    2    0    0    4
     0]
 [   1    0    3    0    2    1  865    1    0    0    2    1   27    3
     1]
 [   0    3   16    8    2    0    0 1153    0    0    0   12    3    1
     2]
 [   0    0    4    1    1    0    1    0  970    0    0    0    1    4
     3]
 [   0    0    0    0    0    1    1    0    0  262    2    1    0    1
     0]
 [   0    0    2    0    0    0    0    0    0    0  812    2    0    1
     0]
 [   0    0    0    

<keras.callbacks.callbacks.History at 0x7f1d90cc2e20>

In [79]:
train_model.load_weights(model_save_path)
evaluate()

100%|██████████| 792/792 [00:14<00:00, 55.96it/s]
100%|██████████| 668/668 [00:17<00:00, 39.13it/s]
100%|██████████| 447/447 [00:26<00:00, 16.89it/s]


0.9566871717092731
********************************confusion_matrix********************************
[[ 814    0    6    1    7    2    1    1    4    0    1    0    2    2
     0]
 [   0   55    0    0    0    0    0    1    0    0    0    0    0    0
     0]
 [   9    0 1321    2    1    4    2   37    6    0    2    1    4    2
     3]
 [   0    0    1  475    1    0    0    8    1    0    0    2    0    0
     2]
 [   0    1    4    1  949    6    1    1    2    0    2    1    1    4
     1]
 [   3    0    7    1    3 1134    1    1    0    7    2    0    4    7
     1]
 [   8    0    3    0    2    1  800    2    0    0    0    1   83    4
     3]
 [   0    2   40    8    1    0    1 1123    2    1    0    8   12    1
     1]
 [   1    0   20    3    2    4    3    3  923    2    0    0    3    5
    16]
 [   0    0    0    0    0    2    0    0    0  262    1    2    0    1
     0]
 [   0    0    2    2    2    1    2    4    0    2  795    2    3    1
     1]
 [   0    0    2    

0.7277198571870023

In [80]:
def predict_to_file(result_path):
    _, tnews_preds = get_predict(tnews_model, tnews_test_generator)
    _, ocnli_preds = get_predict(ocnli_model, ocnli_test_generator)
    _, ocemotion_preds = get_predict(ocemotion_model, ocemotion_test_generator)
    
    tnews_result, ocnli_result, ocemotion_result = [], [], []
    
    for (d, p) in zip(tnews_test, tnews_preds):
        tnews_result.append({'id': d[0], 'label': tnews_id2label[p]})
    
    for (d, p) in zip(ocnli_test, ocnli_preds):
        ocnli_result.append({'id': d[0], 'label': ocnli_id2label[p]})
    
    for (d, p) in zip(ocemotion_test, ocemotion_preds):
        ocemotion_result.append({'id': d[0], 'label': ocemotion_id2label[p]})
        
    with open(os.path.join(result_path, 'tnews_predict.json'), 'w') as f:
        for d in tnews_result:
            f.write(json.dumps(d) + '\n')
    
    with open(os.path.join(result_path, 'ocnli_predict.json'), 'w') as f:
        for d in ocnli_result:
            f.write(json.dumps(d) + '\n')
    
    with open(os.path.join(result_path, 'ocemotion_predict.json'), 'w') as f:
        for d in ocemotion_result:
            f.write(json.dumps(d) + '\n')

In [81]:
predict_to_file('./result')

100%|██████████| 94/94 [00:01<00:00, 53.05it/s]
100%|██████████| 94/94 [00:02<00:00, 35.43it/s]
100%|██████████| 94/94 [00:05<00:00, 17.50it/s]
