In [1]:
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
import numpy as np
# np.argmax([1,2,3,4])

In [2]:
def softmax(x, axis=None):
    """Return the softmax of ``x`` along ``axis`` as (nested) Python lists.

    The maximum along ``axis`` is subtracted before exponentiating; this does
    not change the result mathematically but keeps ``np.exp`` from overflowing
    on large inputs. With ``axis=None`` the normalisation runs over every
    element of ``x``.
    """
    values = np.array(x)
    shifted = values - values.max(axis=axis, keepdims=True)
    exponentials = np.exp(shifted)
    normaliser = exponentials.sum(axis=axis, keepdims=True)
    return (exponentials / normaliser).tolist()
softmax([1,2])

[0.2689414213699951, 0.7310585786300049]

In [3]:
# Identifier of the pretrained checkpoint whose tokenizer is loaded here.
MODEL_NAME = 'ethanyt/guwenbert-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [4]:
import json
# Open with an explicit UTF-8 encoding: the records contain Chinese text, and
# relying on the platform's default locale encoding can fail or mis-decode.
with open('data/valid.jsonl', 'r', encoding='utf-8') as f:
    # One JSON object per line; blank lines (e.g. a trailing newline at the end
    # of the file) are skipped instead of crashing json.loads.
    valid_json = [json.loads(line) for line in f if line.strip()]

In [5]:
import pickle
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only point this at a result file produced by a trusted run.
with open('predict/ner/test_result.pkl', mode='rb') as result_file:
    rst = pickle.load(result_file)

In [7]:
# Show the dataset object together with the field names of its first row.
test_dataset = rst['test_dataset']
test_dataset, test_dataset[0].keys()

(Dataset({
     features: ['attention_mask', 'input_ids', 'labels', 'origin_idx', 'token_type_ids'],
     num_rows: 10880
 }),
 dict_keys(['attention_mask', 'input_ids', 'labels', 'origin_idx', 'token_type_ids']))

In [8]:
# Decode the first row's token ids back into text to see how an input is laid out.
first_row = rst['test_dataset'][0]
tokenizer.decode(first_row['input_ids'])

'[CLS] 昏 暗 的 灯 熄 灭 了 又 被 重 新 点 亮 。 [SEP] 渔 灯 灭 复 明 [SEP]'

In [28]:
# Shape of the label array stored alongside the test predictions.
rst['test_result'].label_ids.shape

(10880, 97)

In [57]:
# Inspect one tokenized row: print every token next to the probability the
# model assigns to class 1 and the stored label for that position.
batch_idx = 1
# Map ids to tokens one-to-one. Decoding to a string and splitting on spaces
# can merge or drop tokens (sub-word pieces, punctuation clean-up), which would
# shift the tokens against the per-position predictions and labels.
input_tokens = tokenizer.convert_ids_to_tokens(rst['test_dataset'][batch_idx]['input_ids'])
predicts = softmax(rst['test_result'].predictions[batch_idx],axis=-1)
labels = rst['test_result'].label_ids[batch_idx]
# zip stops at the shortest sequence, i.e. at the last token of the input.
for token, token_probs, label in zip(input_tokens, predicts, labels):
    print(token, '\t', token_probs[1], '\t', label)

[CLS] 	 0.036833055317401886 	 -100
昏 	 0.9407833814620972 	 -100
暗 	 0.9533089995384216 	 -100
的 	 0.728804349899292 	 -100
灯 	 0.15066255629062653 	 -100
熄 	 0.47069135308265686 	 -100
灭 	 0.47764626145362854 	 -100
了 	 0.7924788594245911 	 -100
又 	 0.9604582786560059 	 -100
被 	 0.8905993700027466 	 -100
重 	 0.9913784265518188 	 -100
新 	 0.9896129965782166 	 -100
点 	 0.9938084483146667 	 -100
亮 	 0.965327262878418 	 -100
。 	 0.0014975041849538684 	 -100
[SEP] 	 0.0004231913771945983 	 -100
残 	 0.00033823924604803324 	 1
灯 	 0.00035732300602830946 	 1
灭 	 0.003181070787832141 	 0
又 	 0.0022221505641937256 	 0
然 	 0.0022334421519190073 	 0
[SEP] 	 0.0018983433255925775 	 -100


[16, 17, 18, 19, 20]

[0.00033823924604803324,
 0.00035732300602830946,
 0.003181070787832141,
 0.0022221505641937256,
 0.0022334421519190073]

In [90]:
# Collect, for every dataset row, the class-1 probabilities at the positions
# whose label is not -100, grouped by the row's 'origin_idx'. Rows sharing an
# origin_idx are appended in dataset order.
all_scores = {}
test_dataset = rst['test_dataset']
test_result = rst['test_result']
for batch_idx in tqdm(range(len(test_dataset))):
    origin_id = test_dataset[batch_idx]['origin_idx']
    predicts = softmax(test_result.predictions[batch_idx], axis=-1)
    labels = test_result.label_ids[batch_idx]
    # Boolean mask keeps only the rows (positions) carrying a real label.
    labelled_rows = np.array(predicts)[labels != -100]
    all_scores.setdefault(origin_id, []).append(labelled_rows[:, 1].tolist())

100%|████████████████████████████████████████████████████████████████████████| 10880/10880 [00:09<00:00, 1155.42it/s]


In [62]:
# Peek at the first validation record to see which fields it carries.
valid_json[0]

{'translation': '昏暗的灯熄灭了又被重新点亮。',
 'choices': ['渔灯灭复明', '残灯灭又然', '残灯暗复明', '残灯灭又明'],
 'answer': 3}

In [111]:
def multiple_list(l):
    """Return the product of all elements of ``l`` (1 for an empty iterable)."""
    product = 1
    for factor in l:
        product *= factor
    return product
# Score every candidate of every validation record and pick the best one.
predict_valid_json = []
for i, example in enumerate(valid_json):
    # One aggregated score per candidate: the product of its per-token scores.
    # (Summing the per-token scores is an alternative aggregation.)
    this_scores = [
        multiple_list(all_scores[i][choice_id])
        for choice_id in range(len(example['choices']))
    ]
    # Shallow copy so the original record is left untouched.
    this_dict = dict(example)
    this_dict['scores'] = this_scores
    this_dict['predict'] = np.argmax(this_scores)
    this_dict['true'] = (this_dict['predict'] == this_dict['answer'])
    predict_valid_json.append(this_dict)

In [112]:
# First scored record: the original fields plus 'scores', 'predict' and 'true'.
predict_valid_json[0]

{'translation': '昏暗的灯熄灭了又被重新点亮。',
 'choices': ['渔灯灭复明', '残灯灭又然', '残灯暗复明', '残灯灭又明'],
 'answer': 3,
 'scores': [2.202495009190283e-21,
  1.9081235868075017e-15,
  8.545825504510408e-15,
  4.655045091572122e-11],
 'predict': 3,
 'true': True}

In [113]:
# Share of records whose predicted choice equals the answer
# (despite the name, this is a ratio rather than a count).
correct_flags = [int(record['true']) for record in predict_valid_json]
true_num = np.sum(correct_flags) / len(predict_valid_json)

In [118]:
# Pull out the misclassified records and look at the tenth from the end.
wrong_predictions = [record for record in predict_valid_json if not record['true']]
wrong_predictions[-10]

{'translation': '温泉水中萃集了万物精华。',
 'choices': ['精览万殊入', '万象入冥搜', '万象划然殊', '含灵万象入'],
 'answer': 0,
 'scores': [1.974271941932789e-24,
  1.9784688151084597e-24,
  2.1180516508889757e-24,
  2.456173303756735e-21],
 'predict': 3,
 'true': False}

In [116]:
# Display the fraction of validation records that were predicted correctly.
true_num

0.8264705882352941