In [1]:
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
import numpy as np
# np.argmax([1,2,3,4])

In [2]:
def softmax(x, axis=None):
    """Numerically stable softmax that returns plain Python lists.

    Args:
        x: Array-like of scores (anything ``np.array`` accepts).
        axis: Axis to normalise over; ``None`` normalises over all elements.

    Returns:
        The probabilities, converted with ``ndarray.tolist()``.
    """
    scores = np.array(x)
    # Shifting by the maximum does not change the result but keeps exp() from
    # overflowing on large scores.
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exponentials = np.exp(shifted)
    normaliser = exponentials.sum(axis=axis, keepdims=True)
    return (exponentials / normaliser).tolist()

softmax([1,2])

[0.2689414213699951, 0.7310585786300049]

In [3]:
# Load the tokenizer that belongs to the 'ethanyt/guwenbert-base' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('ethanyt/guwenbert-base')

In [4]:
import json

# JSON Lines input: one JSON object per line. The encoding is pinned to UTF-8
# because the records hold Chinese text and the platform default encoding is
# not UTF-8 everywhere.
with open('data/valid.jsonl', 'r', encoding='utf-8') as f:
    # Blank lines (for example a trailing empty line) are skipped; json.loads
    # would raise on them.
    valid_json = [json.loads(line) for line in f if line.strip()]

In [47]:
# Pick one example: keep at most the first ten records, then take the fifth
# from the end of that slice (index 5 when at least ten records exist).
# Despite the plural name, the result is a single translation string.
translations = [e['translation'] for e in valid_json][:10][-5]

In [48]:
# First candidate line of the same example selected for `translations`
# (same slicing). Despite the plural name, this is a single string.
choices = [e['choices'][0] for e in valid_json][:10][-5]

In [49]:
# Show the selected candidate line.
choices

'数声好鸟不知处，千丈藤萝古木昏。'

In [52]:
def split_string(c, l=5, remove_dot=True):
    """Cut a verse into word groups at fixed positions selected by ``l``.

    Args:
        c: The verse text to slice.
        l: Layout selector supplied by the caller (it is not derived from
            ``c`` here). Supported values are 5, 7, 12 and 16.
        remove_dot: For the 12 and 16 layouts, ``True`` leaves out the
            one-character slots at the end of each half, while ``False``
            keeps them as separate items. It has no effect for 5 and 7.

    Returns:
        A list of substrings of ``c``, or ``None`` when ``l`` is not one of
        the supported values.
    """
    # (start, stop) pairs; a stop of None means "to the end of the string".
    layouts = {
        5: [(0, 2), (2, 5)],
        7: [(0, 2), (2, 4), (4, None)],
    }
    if remove_dot:
        layouts[12] = [(0, 2), (2, 5), (6, 8), (8, 11)]
        layouts[16] = [(0, 2), (2, 4), (4, 7), (8, 10), (10, 12), (12, 15)]
    else:
        layouts[12] = [(0, 2), (2, 5), (5, 6), (6, 8), (8, 11), (11, 12)]
        layouts[16] = [
            (0, 2), (2, 4), (4, 7), (7, 8),
            (8, 10), (10, 12), (12, 15), (15, 16),
        ]

    bounds = layouts.get(l)
    if bounds is None:
        return None
    return [c[start:stop] for start, stop in bounds]
def generate_ner_label(true_spans, split_choice, max_seq_length, len_translation):
    """Build a 0/1 label vector marking the positions of the true spans.

    The assumed sequence layout is
    ``[CLS] translation [SEP] choice [SEP] [PAD] ...``, so the first span of
    the choice starts at ``len_translation + 2``.
    NOTE(review): positions are counted in characters, which assumes one token
    per character of the translation and of each span; confirm for inputs the
    tokenizer may split differently.

    Args:
        true_spans: Collection of span strings that should be labelled 1.
        split_choice: The choice split into consecutive spans, in order.
        max_seq_length: Length of the returned label list.
        len_translation: Length of the translation that precedes the choice.

    Returns:
        A list of exactly ``max_seq_length`` integers, with 1 at positions
        covered by a span of ``split_choice`` found in ``true_spans``.
        Positions that would fall beyond ``max_seq_length`` are dropped.
    """
    ner_label = [0] * max_seq_length
    cursor = len_translation + 2  # skip [CLS] and the first [SEP]
    for each_span in split_choice:
        span_len = len(each_span)
        if each_span in true_spans:
            # Clip to the window: a slice assignment past the end of a list
            # would silently grow it beyond max_seq_length.
            stop = min(cursor + span_len, max_seq_length)
            for position in range(cursor, stop):
                ner_label[position] = 1
        cursor += span_len
    return ner_label

In [60]:
# Gold spans for this example: the second and the last word group of the
# punctuation-free split of the chosen line.
spans_without_punctuation = split_string(choices, len(choices))
true_spans = [spans_without_punctuation[1], spans_without_punctuation[-1]]
true_spans

['好鸟', '古木昏']

In [61]:
# Split the chosen line again, this time keeping the punctuation slots as
# separate items so the spans line up with every character of the line.
split_choice = split_string(choices,len(choices),False)
split_choice

['数声', '好鸟', '不知处', '，', '千丈', '藤萝', '古木昏', '。']

In [63]:
# Character count of the translation; generate_ner_label uses it as the
# offset of the choice inside the sequence.
# NOTE(review): this equals the token count only if the tokenizer produces
# one token per character; confirm.
len_translation = len(translations)
len_translation

48

In [64]:
# Fixed sequence length, used below both for tokenizer padding/truncation and
# for the length of the label vector.
max_seq_length = 100

In [65]:
# Encode the (translation, choice) pair as one sequence, truncated and padded
# to exactly max_seq_length tokens.
tokenized_examples = tokenizer(
            translations,
            choices,
            truncation=True,
            max_length=max_seq_length,
            padding="max_length",
        )

In [68]:
# Label vector for the example, meant to line up position by position with the
# token ids produced above.
ner_labels = generate_ner_label(true_spans,split_choice,max_seq_length,len_translation)

In [69]:
# Visual alignment check: one row per position showing the decoded token, its
# segment id and its span label.
example_token_ids = tokenized_examples['input_ids']
for position, token_id in enumerate(example_token_ids):
    segment_id = tokenized_examples['token_type_ids'][position]
    print(tokenizer.decode(token_id), '\t', segment_id, '\t', ner_labels[position])

[CLS] 	 0 	 0
这 	 0 	 0
时 	 0 	 0
从 	 0 	 0
竹 	 0 	 0
林 	 0 	 0
中 	 0 	 0
传 	 0 	 0
出 	 0 	 0
阵 	 0 	 0
阵 	 0 	 0
鸟 	 0 	 0
叫 	 0 	 0
， 	 0 	 0
可 	 0 	 0
是 	 0 	 0
偌 	 0 	 0
大 	 0 	 0
的 	 0 	 0
竹 	 0 	 0
林 	 0 	 0
中 	 0 	 0
却 	 0 	 0
发 	 0 	 0
现 	 0 	 0
不 	 0 	 0
了 	 0 	 0
鸟 	 0 	 0
儿 	 0 	 0
的 	 0 	 0
位 	 0 	 0
置 	 0 	 0
， 	 0 	 0
只 	 0 	 0
见 	 0 	 0
到 	 0 	 0
长 	 0 	 0
长 	 0 	 0
的 	 0 	 0
藤 	 0 	 0
萝 	 0 	 0
和 	 0 	 0
黄 	 0 	 0
昏 	 0 	 0
中 	 0 	 0
的 	 0 	 0
古 	 0 	 0
木 	 0 	 0
。 	 0 	 0
[SEP] 	 0 	 0
数 	 1 	 0
声 	 1 	 0
好 	 1 	 1
鸟 	 1 	 1
不 	 1 	 0
知 	 1 	 0
处 	 1 	 0
， 	 1 	 0
千 	 1 	 0
丈 	 1 	 0
藤 	 1 	 0
萝 	 1 	 0
古 	 1 	 1
木 	 1 	 1
昏 	 1 	 1
。 	 1 	 0
[SEP] 	 1 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 

In [26]:
# List the fields returned by the tokenizer.
tokenized_examples.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [17]:
# NOTE(review): stale cell. The recorded output is a whole five-character
# line, so it was produced while `choices` was a list; with the `choices`
# assignment earlier in this notebook (a single string) this expression
# returns only its first character.
choices[0]

'渔灯灭复明'

In [32]:
# NOTE(review): stale cell. The recorded value (14) comes from a different
# example than the 48-character translation selected earlier in this notebook.
len(translations)

14

In [40]:
# NOTE(review): generate_ner_label is also defined in an earlier cell of this
# notebook; keep a single definition once the notebook is cleaned up.
def generate_ner_label(true_spans, split_choice, max_seq_length, len_translation):
    """Build a 0/1 label vector marking the positions of the true spans.

    The assumed sequence layout is
    ``[CLS] translation [SEP] choice [SEP] [PAD] ...``, so the first span of
    the choice starts at ``len_translation + 2``.
    NOTE(review): positions are counted in characters, which assumes one token
    per character of the translation and of each span; confirm for inputs the
    tokenizer may split differently.

    Args:
        true_spans: Collection of span strings that should be labelled 1.
        split_choice: The choice split into consecutive spans, in order.
        max_seq_length: Length of the returned label list.
        len_translation: Length of the translation that precedes the choice.

    Returns:
        A list of exactly ``max_seq_length`` integers, with 1 at positions
        covered by a span of ``split_choice`` found in ``true_spans``.
        Positions that would fall beyond ``max_seq_length`` are dropped.
    """
    ner_label = [0] * max_seq_length
    cursor = len_translation + 2  # skip [CLS] and the first [SEP]
    for each_span in split_choice:
        span_len = len(each_span)
        if each_span in true_spans:
            # Clip to the window: a slice assignment past the end of a list
            # would silently grow it beyond max_seq_length.
            stop = min(cursor + span_len, max_seq_length)
            for position in range(cursor, stop):
                ner_label[position] = 1
        cursor += span_len
    return ner_label

In [None]:
# NOTE(review): placeholder from a cell that was never executed (its prompt is
# `In [None]`); running it would replace `true_spans` with a list holding one
# empty string.
true_spans = ['']

In [38]:
# Hand-written prototype of the labelling logic for a single span.
max_length = 30
first_ner_label = [0] * max_length
# Offset of the choice inside the sequence: translation length plus 2.
prefix_len = len(translations)+2
# The span starts 2 characters into the choice and is 3 characters long.
this_start = 2
this_len = 3
# NOTE(review): a slice assignment that reaches past the end of the list
# would lengthen it instead of failing; harmless with these constants.
first_ner_label[prefix_len+this_start:prefix_len+this_start+this_len] = [1] * this_len
print(first_ner_label)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [39]:
# Visual alignment check for the hand-built label vector: decoded token,
# segment id and label for every position.
example_token_ids = tokenized_examples['input_ids']
for position, token_id in enumerate(example_token_ids):
    segment_id = tokenized_examples['token_type_ids'][position]
    print(tokenizer.decode(token_id), '\t', segment_id, '\t', first_ner_label[position])

[CLS] 	 0 	 0
昏 	 0 	 0
暗 	 0 	 0
的 	 0 	 0
灯 	 0 	 0
熄 	 0 	 0
灭 	 0 	 0
了 	 0 	 0
又 	 0 	 0
被 	 0 	 0
重 	 0 	 0
新 	 0 	 0
点 	 0 	 0
亮 	 0 	 0
。 	 0 	 0
[SEP] 	 0 	 0
渔 	 1 	 0
灯 	 1 	 0
灭 	 1 	 1
复 	 1 	 1
明 	 1 	 1
[SEP] 	 1 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0
[PAD] 	 0 	 0


In [1]:
from dataload.binary_ner_dataset import *

In [3]:
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
import numpy as np
# np.argmax([1,2,3,4])
# Load the tokenizer that belongs to the 'ethanyt/guwenbert-base' checkpoint;
# it is used below to decode the token ids of a batch.
tokenizer = AutoTokenizer.from_pretrained('ethanyt/guwenbert-base')

In [2]:
# NOTE(review): `unit_test` is not defined in this notebook; presumably it
# comes from the star import of dataload.binary_ner_dataset. Confirm, and
# prefer an explicit import.
dataloader = unit_test()

Using custom data configuration default-6cb4e0ff663b7510
Reusing dataset json (/home/zhangkechi/.cache/huggingface/datasets/json/default-6cb4e0ff663b7510/0.0.0/c2d554c3377ea79c7664b93dc65d0803b45e3279000f993c7bfd18937fd7f426)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?ba/s]

In [4]:
# Print only the first batch to inspect its fields. `batch` stays bound after
# the loop and is reused by the cells below.
for batch in dataloader:
    print(batch)
    break

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'input_ids': tensor([[  0, 153,  11,  ...,   1,   1,   1],
        [  0, 153,  11,  ...,   1,   1,   1],
        [  0, 153,  11,  ...,   1,   1,   1],
        [  0, 153,  11,  ...,   1,   1,   1]]), 'origin_idx': tensor([0, 0, 0, 0]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100]])}


In [10]:
# Shape of the 0/1 mask of label positions whose value is not -100.
# NOTE(review): -100 presumably marks positions excluded from the loss;
# confirm in the dataset code.
batch['labels'].ne(-100).long().shape

torch.Size([4, 1024])

In [11]:
# Check the dtype of the segment ids in the batch.
batch['token_type_ids'].dtype

torch.int64

In [12]:
# For one row of the batch, print the segment id next to a 0/1 flag telling
# whether that position carries a label other than -100.
batch_idx = 3
segment_row = batch['token_type_ids'][batch_idx]
label_mask_row = batch['labels'].ne(-100).long()[batch_idx]
for position in range(batch['input_ids'].shape[1]):
    print(segment_row[position].item(), '\t', label_mask_row[position].item())

0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
1 	 1
1 	 1
1 	 1
1 	 1
1 	 1
1 	 1
1 	 1
1 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 0
0 	 

In [5]:
# For one row of the batch, print decoded token, label, attention mask and
# segment id at every position.
batch_idx = 3
token_row = batch['input_ids'][batch_idx]
label_row = batch['labels'][batch_idx]
attention_row = batch['attention_mask'][batch_idx]
segment_row = batch['token_type_ids'][batch_idx]
for position in range(batch['input_ids'].shape[1]):
    print(tokenizer.decode(token_row[position].item()), '\t', label_row[position].item(), '\t', attention_row[position].item(), '\t', segment_row[position].item())

[CLS] 	 -100 	 1 	 0
诗 	 -100 	 1 	 0
人 	 -100 	 1 	 0
啊 	 -100 	 1 	 0
， 	 -100 	 1 	 0
你 	 -100 	 1 	 0
竟 	 -100 	 1 	 0
像 	 -100 	 1 	 0
在 	 -100 	 1 	 0
遥 	 -100 	 1 	 0
远 	 -100 	 1 	 0
的 	 -100 	 1 	 0
地 	 -100 	 1 	 0
方 	 -100 	 1 	 0
站 	 -100 	 1 	 0
立 	 -100 	 1 	 0
船 	 -100 	 1 	 0
头 	 -100 	 1 	 0
。 	 -100 	 1 	 0
[SEP] 	 -100 	 1 	 0
行 	 0 	 1 	 1
人 	 0 	 1 	 1
迢 	 0 	 1 	 1
递 	 0 	 1 	 1
木 	 1 	 1 	 1
兰 	 1 	 1 	 1
舟 	 1 	 1 	 1
[SEP] 	 -100 	 1 	 1
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 

[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 

[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 	 0 	 0
[PAD] 	 -100 