In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments, DefaultDataCollator

In [15]:
# Download (or reuse from the local ./data cache) the MRC QA dataset with its
# train / validation / test splits, then display the DatasetDict summary.
datasets = load_dataset("roberthsu2003/for_MRC_QA", cache_dir="data")
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 26936
    })
    validation: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 3524
    })
    test: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 3493
    })
})

In [16]:
# First training record: id, context, question and answers ({'answer_start': [...], 'text': [...]})
datasets["train"][0]

{'id': '1001-10-1',
 'context': '2010年引進的廣州快速公交運輸系統，屬世界第二大快速公交系統，日常載客量可達100萬人次，高峰時期每小時單向客流高達26900人次，僅次於波哥大的快速交通系統，平均每10秒鐘就有一輛巴士，每輛巴士單向行駛350小時。包括橋樑在內的站台是世界最長的州快速公交運輸系統站台，長達260米。目前廣州市區的計程車和公共汽車主要使用液化石油氣作燃料，部分公共汽車更使用油電、氣電混合動力技術。2012年底開始投放液化天然氣燃料的公共汽車，2014年6月開始投放液化天然氣插電式混合動力公共汽車，以取代液化石油氣公共汽車。2007年1月16日，廣州市政府全面禁止在市區內駕駛摩托車。違反禁令的機動車將會予以沒收。廣州市交通局聲稱禁令的施行，使得交通擁擠問題和車禍大幅減少。廣州白雲國際機場位於白雲區與花都區交界，2004年8月5日正式投入運營，屬中國交通情況第二繁忙的機場。該機場取代了原先位於市中心的無法滿足日益增長航空需求的舊機場。目前機場有三條飛機跑道，成為國內第三個擁有三跑道的民航機場。比鄰近的香港國際機場第三跑道預計的2023年落成早8年。',
 'question': '廣州的快速公交運輸系統每多久就會有一輛巴士？',
 'answers': {'answer_start': [84], 'text': ['10秒鐘']}}

In [10]:
# Tokenizer matching the checkpoint that is fine-tuned later. It loads as a fast
# tokenizer (BertTokenizerFast), which return_offsets_mapping requires further down.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese")
tokenizer

BertTokenizerFast(name_or_path='google-bert/bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [19]:
# Small 10-row slice of the training split for quick inspection; show its second record
sample_datasets = datasets["train"].select(range(10))
sample_datasets[1]

{'id': '1001-10-2',
 'context': '2010年引進的廣州快速公交運輸系統，屬世界第二大快速公交系統，日常載客量可達100萬人次，高峰時期每小時單向客流高達26900人次，僅次於波哥大的快速交通系統，平均每10秒鐘就有一輛巴士，每輛巴士單向行駛350小時。包括橋樑在內的站台是世界最長的州快速公交運輸系統站台，長達260米。目前廣州市區的計程車和公共汽車主要使用液化石油氣作燃料，部分公共汽車更使用油電、氣電混合動力技術。2012年底開始投放液化天然氣燃料的公共汽車，2014年6月開始投放液化天然氣插電式混合動力公共汽車，以取代液化石油氣公共汽車。2007年1月16日，廣州市政府全面禁止在市區內駕駛摩托車。違反禁令的機動車將會予以沒收。廣州市交通局聲稱禁令的施行，使得交通擁擠問題和車禍大幅減少。廣州白雲國際機場位於白雲區與花都區交界，2004年8月5日正式投入運營，屬中國交通情況第二繁忙的機場。該機場取代了原先位於市中心的無法滿足日益增長航空需求的舊機場。目前機場有三條飛機跑道，成為國內第三個擁有三跑道的民航機場。比鄰近的香港國際機場第三跑道預計的2023年落成早8年。',
 'question': '從哪一天開始在廣州市內騎摩托車會被沒收？',
 'answers': {'answer_start': [256], 'text': ['2007年1月16日']}}

In [20]:
#自已定義簡單資料
source_json = [{
    'id':'lesson1',
    'context':'小英的生日是1997年10月9日,女性',
    'question':'小英的生日是?',
    'answers':{'text':['1997年10月9日'],'answer_start':[6]}
    },
    {
    'id':'lesson2',
    'context':'川普日前宣布課徵加拿大和墨西哥25%的關稅',
    'question':'加拿大和墨西哥被課徵的關稅是',
    'answers':{'text':['25%'],'answer_start':[15]}
    }

]

In [81]:
from datasets import DatasetDict, Dataset

# Wrap the toy examples in a DatasetDict with a single 'train' split, so they go
# through the same code path as the real dataset.
# NOTE: this rebinds `datasets` (previously the downloaded dataset).
train_dataset = Dataset.from_list(source_json)
datasets = DatasetDict({"train": train_dataset})
sample_dataset = datasets["train"]
sample_dataset["answers"]

[{'answer_start': [6], 'text': ['1997年10月9日']},
 {'answer_start': [15], 'text': ['25%']}]

In [112]:
from pprint import pprint

# Encode (question, context) pairs. Only the second sequence (the context) may be
# truncated, so the question always survives intact. Padding to 512 shows the
# fixed-length input the model will receive.
tokenized_examples = tokenizer(
    text=datasets["train"]["question"],
    text_pair=datasets["train"]["context"],
    max_length=512,
    padding="max_length",
    truncation="only_second",
)
pprint(tokenized_examples, compact=True, width=100)

{'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          

In [113]:
# (input_id, token_type_id) pairs for the first example: type 0 covers the question
# and the padding, type 1 covers the context segment.
pprint([pair for pair in zip(tokenized_examples["input_ids"][0], tokenized_examples["token_type_ids"][0])], compact=True)

[(101, 0), (2207, 0), (5739, 0), (4638, 0), (4495, 0), (3189, 0), (3221, 0),
 (136, 0), (102, 0), (2207, 1), (5739, 1), (4638, 1), (4495, 1), (3189, 1),
 (3221, 1), (8387, 1), (2399, 1), (8108, 1), (3299, 1), (130, 1), (3189, 1),
 (117, 1), (1957, 1), (2595, 1), (102, 1), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0),
 (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)

In [114]:
from pprint import pprint

# Re-encode without padding, this time also requesting offset_mapping: the
# (start, end) character span of every token within its own source string.
# These spans are what map a character-level answer_start onto token positions.
tokenized_examples = tokenizer(
    text=datasets["train"]["question"],
    text_pair=datasets["train"]["context"],
    max_length=512,
    return_offsets_mapping=True,
    truncation="only_second",
)
print(tokenized_examples.keys())

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])


In [115]:
# Offsets for the first example. (0, 0) marks the special tokens ([CLS]/[SEP]), and the
# offsets restart at 0 where the context begins, because each sequence is indexed separately.
pprint(tokenized_examples["offset_mapping"][0], compact=True)

[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (0, 0), (0, 1),
 (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 10), (10, 11), (11, 13), (13, 14),
 (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (0, 0)]


In [116]:
# offset_mapping is only needed to compute answer labels and is not a model input,
# so move it out of the encoding. The guard keeps the cell re-runnable: a second
# unguarded pop would raise KeyError because the key is already gone.
if "offset_mapping" in tokenized_examples:
    offset_mapping = tokenized_examples.pop("offset_mapping")
pprint(offset_mapping[0], compact=True)

[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (0, 0), (0, 1),
 (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 10), (10, 11), (11, 13), (13, 14),
 (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (0, 0)]


In [117]:
# Segment of each token in example 0: None = special token, 0 = question, 1 = context
print(tokenized_examples.sequence_ids(batch_index=0))

[None, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]


In [122]:
# Map each toy answer's character span onto token positions, then decode those tokens
# back as a sanity check that the labels cover exactly the answer.
answers_column = sample_dataset['answers']  # read the column once, not once per iteration
for idx, offset in enumerate(offset_mapping):
    answer = answers_column[idx]
    start_char = answer['answer_start'][0]              # first character of the answer
    end_char = start_char + len(answer['text'][0])      # one past the last character

    # Context tokens are those with sequence id 1; they run up to the next None ([SEP]).
    sequence_ids = tokenized_examples.sequence_ids(idx)
    context_start = sequence_ids.index(1)
    context_end = sequence_ids.index(None, context_start) - 1

    # Truncation can cut the answer off. Label (0, 0) unless the answer lies entirely
    # inside the kept context; a partial span would otherwise give start > end.
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_token_pos = 0
        end_token_pos = 0
    else:
        # Last token starting at or before start_char, i.e. the token that contains the
        # answer start, even when it is a multi-character token such as "1997".
        token_id = context_start
        while token_id <= context_end and offset[token_id][0] <= start_char:
            token_id += 1
        start_token_pos = token_id - 1

        # First token ending at or after end_char, i.e. the token that contains the answer end.
        token_id = context_end
        while token_id >= context_start and offset[token_id][1] >= end_char:
            token_id -= 1
        end_token_pos = token_id + 1

    print("token answer decode:", tokenizer.decode(tokenized_examples['input_ids'][idx][start_token_pos:end_token_pos + 1]))
    

    

token answer decode: 1997 年 10 月 9 日
token answer decode: 25 %


In [None]:
def precess_func(examples, max_length=512):
    """Tokenize a batch of QA examples and attach start/end token labels.

    Intended for ``Dataset.map(precess_func, batched=True)``. ``examples`` maps
    column names to lists and needs 'question', 'context' and 'answers'. Each
    answer looks like ``{'answer_start': [int, ...], 'text': [str, ...]}``, where
    ``answer_start`` is a character offset into the context. Only the first answer
    is used.

    Uses the module-level ``tokenizer``, which must be a fast tokenizer because
    ``return_offsets_mapping`` is requested.

    Args:
        examples: batch of examples as described above.
        max_length: maximum total sequence length. Only the context is truncated,
            and every sequence is padded to this length.

    Returns:
        The tokenizer output (without ``offset_mapping``) plus ``start_positions``
        and ``end_positions``, which are inclusive token indices. An example whose
        answer is missing or not fully inside the (possibly truncated) context is
        labelled (0, 0), i.e. the [CLS] token.
    """
    tokenized_examples = tokenizer(
        text=examples['question'],
        text_pair=examples['context'],
        max_length=max_length,
        padding="max_length",
        return_offsets_mapping=True,
        truncation='only_second',
    )

    # offset_mapping is only needed to compute labels; it is not a model input.
    offset_mapping = tokenized_examples.pop("offset_mapping")
    start_positions = []
    end_positions = []
    for idx, offset in enumerate(offset_mapping):
        start_token_pos, end_token_pos = _answer_token_span(
            offset, tokenized_examples.sequence_ids(idx), examples['answers'][idx]
        )
        start_positions.append(start_token_pos)
        end_positions.append(end_token_pos)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples


def _answer_token_span(offset, sequence_ids, answer):
    """Return the inclusive (start, end) token span of the first answer, or (0, 0) if it is unavailable."""
    if not answer['answer_start'] or 1 not in sequence_ids:
        # No answer given, or the context was truncated away entirely.
        return 0, 0
    start_char = answer['answer_start'][0]
    end_char = start_char + len(answer['text'][0])  # exclusive

    # Context tokens have sequence id 1 and run up to the following None ([SEP]).
    context_start = sequence_ids.index(1)
    context_end = sequence_ids.index(None, context_start) - 1

    # Label (0, 0) unless the whole answer survived truncation.
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        return 0, 0

    # Token containing the answer start: the last one whose start offset is <= start_char.
    token_id = context_start
    while token_id <= context_end and offset[token_id][0] <= start_char:
        token_id += 1
    start_token_pos = token_id - 1

    # Token containing the answer end: the first one whose end offset is >= end_char.
    token_id = context_end
    while token_id >= context_start and offset[token_id][1] >= end_char:
        token_id -= 1
    end_token_pos = token_id + 1
    return start_token_pos, end_token_pos

In [138]:
# Reload the full dataset, because the toy-data cell rebound `datasets` to a one-split DatasetDict.
datasets = load_dataset('roberthsu2003/for_MRC_QA', cache_dir='data')
sample_datasets = datasets['train'].select(range(10000, 10011))
# Slicing with [:] yields the {column: list} batch dict that map(batched=True) passes in.
# Passing the Dataset object itself relies on Dataset[col] returning a plain list.
precess_func(sample_datasets[:])

In [144]:
# Preprocess every split in batches, replacing the raw id/context/question/answers
# columns with model inputs plus start/end labels.
tokenied_datasets = datasets.map(
    precess_func,
    batched=True,
    remove_columns=datasets["train"].column_names,
)
tokenied_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 26936
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 3524
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 3493
    })
})

In [145]:
# Pretrained Chinese BERT encoder with a newly initialised span head (qa_outputs) that still needs fine-tuning
model = AutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-chinese")

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google-bert/bert-base-chinese and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [147]:
# 3 epochs at batch size 32; evaluate and checkpoint into ./models_for_qa at the end
# of every epoch, and log the training loss every 50 steps.
args = TrainingArguments(
    output_dir="models_for_qa",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
)

In [148]:
# Every example was already padded to max_length during preprocessing, so the
# default collator (which does no padding) only needs to stack tensors.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenied_datasets["train"],
    eval_dataset=tokenied_datasets["validation"],
    data_collator=DefaultDataCollator(),
)