# 基于截断策略的机器阅读理解任务实现

* 数据集： cmrc2018

* 预训练模型：hfl/chinese-macbert-base

* 数据集处理方式：对context进行截断

## Step1 导入相关包

In [1]:
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,  # 这次默认填到512
)

  from .autonotebook import tqdm as notebook_tqdm


## Step2 数据集加载

In [2]:
# 如果可以联网，直接使用load_dataset进行加载
# datasets = load_dataset("cmrc2018", cache_dir="data")
# 如果无法联网，则使用下面的方式加载数据集
datasets = DatasetDict.load_from_disk("mrc_data")
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 10142
    })
    validation: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 3219
    })
    test: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 1002
    })
})

In [3]:
datasets["train"][0]

{'id': 'TRAIN_186_QUERY_0',
 'context': '范廷颂枢机（，），圣名保禄·若瑟（），是越南罗马天主教枢机。1963年被任为主教；1990年被擢升为天主教河内总教区宗座署理；1994年被擢升为总主教，同年年底被擢升为枢机；2009年2月离世。范廷颂于1919年6月15日在越南宁平省天主教发艳教区出生；童年时接受良好教育后，被一位越南神父带到河内继续其学业。范廷颂于1940年在河内大修道院完成神学学业。范廷颂于1949年6月6日在河内的主教座堂晋铎；及后被派到圣女小德兰孤儿院服务。1950年代，范廷颂在河内堂区创建移民接待中心以收容到河内避战的难民。1954年，法越战争结束，越南民主共和国建都河内，当时很多天主教神职人员逃至越南的南方，但范廷颂仍然留在河内。翌年管理圣若望小修院；惟在1960年因捍卫修院的自由、自治及拒绝政府在修院设政治课的要求而被捕。1963年4月5日，教宗任命范廷颂为天主教北宁教区主教，同年8月15日就任；其牧铭为「我信天主的爱」。由于范廷颂被越南政府软禁差不多30年，因此他无法到所属堂区进行牧灵工作而专注研读等工作。范廷颂除了面对战争、贫困、被当局迫害天主教会等问题外，也秘密恢复修院、创建女修会团体等。1990年，教宗若望保禄二世在同年6月18日擢升范廷颂为天主教河内总教区宗座署理以填补该教区总主教的空缺。1994年3月23日，范廷颂被教宗若望保禄二世擢升为天主教河内总教区总主教并兼天主教谅山教区宗座署理；同年11月26日，若望保禄二世擢升范廷颂为枢机。范廷颂在1995年至2001年期间出任天主教越南主教团主席。2003年4月26日，教宗若望保禄二世任命天主教谅山教区兼天主教高平教区吴光杰主教为天主教河内总教区署理主教；及至2005年2月19日，范廷颂因获批辞去总主教职务而荣休；吴光杰同日真除天主教河内总教区总主教职务。范廷颂于2009年2月22日清晨在河内离世，享年89岁；其葬礼于同月26日上午在天主教河内总教区总主教座堂举行。',
 'question': '范廷颂是什么时候被任为主教的？',
 'answers': {'text': ['1963年'], 'answer_start': [30]}}

## Step3 数据预处理

In [4]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")
tokenizer

BertTokenizerFast(name_or_path='hfl/chinese-macbert-base', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [5]:
sample_dataset = datasets["train"].select(range(10))

In [6]:
tokenized_examples = tokenizer(
    text=sample_dataset["question"],
    text_pair=sample_dataset["context"],
    return_offsets_mapping=True,
    max_length=512,
    truncation="only_second",
    padding="max_length",
)
tokenized_examples.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])

In [7]:
print(
    tokenized_examples["offset_mapping"][0],
    len(tokenized_examples["offset_mapping"][0]),
)

[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (

In [8]:
offset_mapping = tokenized_examples.pop("offset_mapping")

In [9]:
tokenized_examples["token_type_ids"][0]

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [10]:
tokenized_examples.sequence_ids(0)  # 特殊token的id为None

[None,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 None,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1

In [11]:
cs = tokenized_examples.sequence_ids(0).index(1)
cs

17

In [12]:
ce = tokenized_examples.sequence_ids(0).index(None, cs) - 1
ce

510

In [13]:
offset_mapping[0]

[(0, 0),
 (0, 1),
 (1, 2),
 (2, 3),
 (3, 4),
 (4, 5),
 (5, 6),
 (6, 7),
 (7, 8),
 (8, 9),
 (9, 10),
 (10, 11),
 (11, 12),
 (12, 13),
 (13, 14),
 (14, 15),
 (0, 0),
 (0, 1),
 (1, 2),
 (2, 3),
 (3, 4),
 (4, 5),
 (5, 6),
 (6, 7),
 (7, 8),
 (8, 9),
 (9, 10),
 (10, 11),
 (11, 12),
 (12, 13),
 (13, 14),
 (14, 15),
 (15, 16),
 (16, 17),
 (17, 18),
 (18, 19),
 (19, 20),
 (20, 21),
 (21, 22),
 (22, 23),
 (23, 24),
 (24, 25),
 (25, 26),
 (26, 27),
 (27, 28),
 (28, 29),
 (29, 30),
 (30, 34),
 (34, 35),
 (35, 36),
 (36, 37),
 (37, 38),
 (38, 39),
 (39, 40),
 (40, 41),
 (41, 45),
 (45, 46),
 (46, 47),
 (47, 48),
 (48, 49),
 (49, 50),
 (50, 51),
 (51, 52),
 (52, 53),
 (53, 54),
 (54, 55),
 (55, 56),
 (56, 57),
 (57, 58),
 (58, 59),
 (59, 60),
 (60, 61),
 (61, 62),
 (62, 63),
 (63, 67),
 (67, 68),
 (68, 69),
 (69, 70),
 (70, 71),
 (71, 72),
 (72, 73),
 (73, 74),
 (74, 75),
 (75, 76),
 (76, 77),
 (77, 78),
 (78, 79),
 (79, 80),
 (80, 81),
 (81, 82),
 (82, 83),
 (83, 84),
 (84, 85),
 (85, 86),
 (86, 87

In [14]:
offset_mapping[0][cs]

(0, 1)

In [15]:
offset_mapping[0][ce]

(533, 534)

In [16]:
for idx, offset in enumerate(offset_mapping):
    answer = sample_dataset[idx]["answers"]
    start_char = answer["answer_start"][0]
    end_char = start_char + len(answer["text"][0])

    # 定位答案在token中的起始位置和结束位置
    # 一种策略，我们要拿到context的起始和结束，然后从左右两侧向答案逼近
    context_start = tokenized_examples.sequence_ids(idx).index(1)  # context的起始位置
    context_end = (
        tokenized_examples.sequence_ids(idx).index(None, context_start) - 1
    )  # context的结束位置

    # 判断答案是否在context中
    # offset是一个tuple，第一个元素是token在context中的起始位置，第二个元素是token在context中的结束位置
    if offset[context_end][1] < start_char or offset[context_start][0] > end_char:
        start_token_pos = 0
        end_token_pos = 0
    else:
        token_id = context_start
        while token_id <= context_end and offset[token_id][0] < start_char:
            token_id += 1
        start_token_pos = token_id
        token_id = context_end
        while token_id >= context_start and offset[token_id][1] > end_char:
            token_id -= 1
        end_token_pos = token_id

    print(
        answer,
        start_char,
        end_char,
        context_start,
        context_end,
        start_token_pos,
        end_token_pos,
    )
    print(
        "token answer decode:",
        tokenizer.decode(
            tokenized_examples["input_ids"][idx][start_token_pos : end_token_pos + 1]
        ),
    )

{'text': ['1963年'], 'answer_start': [30]} 30 35 17 510 47 48
token answer decode: 1963 年
{'text': ['1990年被擢升为天主教河内总教区宗座署理'], 'answer_start': [41]} 41 62 15 510 53 70
token answer decode: 1990 年 被 擢 升 为 天 主 教 河 内 总 教 区 宗 座 署 理
{'text': ['范廷颂于1919年6月15日在越南宁平省天主教发艳教区出生'], 'answer_start': [97]} 97 126 15 510 100 124
token answer decode: 范 廷 颂 于 1919 年 6 月 15 日 在 越 南 宁 平 省 天 主 教 发 艳 教 区 出 生
{'text': ['1994年3月23日，范廷颂被教宗若望保禄二世擢升为天主教河内总教区总主教并兼天主教谅山教区宗座署理'], 'answer_start': [548]} 548 598 17 510 0 0
token answer decode: [CLS]
{'text': ['范廷颂于2009年2月22日清晨在河内离世'], 'answer_start': [759]} 759 780 12 510 0 0
token answer decode: [CLS]
{'text': ['《全美超级模特儿新秀大赛》第十季'], 'answer_start': [26]} 26 42 21 510 47 62
token answer decode: 《 全 美 超 级 模 特 儿 新 秀 大 赛 》 第 十 季
{'text': ['有前途的新面孔'], 'answer_start': [247]} 247 254 20 510 232 238
token answer decode: 有 前 途 的 新 面 孔
{'text': ['《Jet》、《东方日报》、《Elle》等'], 'answer_start': [706]} 706 726 20 510 0 0
token answer decode: [CLS]
{'text': ['售货员'], 'answer_start': [202]}

In [17]:
def process_func(examples):
    tokenized_examples = tokenizer(
        text=examples["question"],
        text_pair=examples["context"],
        return_offsets_mapping=True,
        max_length=384,
        truncation="only_second",
        padding="max_length",
    )
    offset_mapping = tokenized_examples.pop("offset_mapping")
    start_positions = []
    end_positions = []
    for idx, offset in enumerate(offset_mapping):
        answer = examples["answers"][idx]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        # 定位答案在token中的起始位置和结束位置
        # 一种策略，我们要拿到context的起始和结束，然后从左右两侧向答案逼近
        context_start = tokenized_examples.sequence_ids(idx).index(1)
        context_end = (
            tokenized_examples.sequence_ids(idx).index(None, context_start) - 1
        )
        # 判断答案是否在context中
        if offset[context_end][1] < start_char or offset[context_start][0] > end_char:
            start_token_pos = 0
            end_token_pos = 0
        else:
            token_id = context_start
            while token_id <= context_end and offset[token_id][0] < start_char:
                token_id += 1
            start_token_pos = token_id
            token_id = context_end
            while token_id >= context_start and offset[token_id][1] > end_char:
                token_id -= 1
            end_token_pos = token_id
        start_positions.append(start_token_pos)
        end_positions.append(end_token_pos)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples

In [18]:
tokenied_datasets = datasets.map(
    process_func, batched=True, remove_columns=datasets["train"].column_names
)
tokenied_datasets

Map: 100%|██████████| 10142/10142 [00:02<00:00, 4639.43 examples/s]
Map: 100%|██████████| 3219/3219 [00:00<00:00, 4966.25 examples/s]
Map: 100%|██████████| 1002/1002 [00:00<00:00, 4952.19 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 10142
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 3219
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 1002
    })
})

## Step4 加载模型

In [19]:
model = AutoModelForQuestionAnswering.from_pretrained("hfl/chinese-macbert-base")

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at hfl/chinese-macbert-base and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Step5 配置TrainingArguments

In [20]:
args = TrainingArguments(
    output_dir="models_for_qa",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    num_train_epochs=3,
)

## Step6 配置Trainer

In [21]:
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=tokenied_datasets["train"],
    eval_dataset=tokenied_datasets["validation"],
    data_collator=DefaultDataCollator(),
)

## Step7 模型训练

In [22]:
trainer.train()

  5%|▌         | 50/951 [00:44<12:58,  1.16it/s]

{'loss': 3.258, 'grad_norm': 14.40850830078125, 'learning_rate': 4.737118822292324e-05, 'epoch': 0.16}


 11%|█         | 100/951 [01:26<12:17,  1.15it/s]

{'loss': 1.7846, 'grad_norm': 9.209692001342773, 'learning_rate': 4.474237644584648e-05, 'epoch': 0.32}


 16%|█▌        | 150/951 [02:09<11:45,  1.14it/s]

{'loss': 1.6263, 'grad_norm': 12.401537895202637, 'learning_rate': 4.211356466876972e-05, 'epoch': 0.47}


 21%|██        | 200/951 [02:53<10:54,  1.15it/s]

{'loss': 1.46, 'grad_norm': 11.141395568847656, 'learning_rate': 3.9484752891692956e-05, 'epoch': 0.63}


 26%|██▋       | 250/951 [03:36<10:03,  1.16it/s]

{'loss': 1.3871, 'grad_norm': 8.307114601135254, 'learning_rate': 3.6855941114616195e-05, 'epoch': 0.79}


 32%|███▏      | 300/951 [04:20<10:03,  1.08it/s]

{'loss': 1.3883, 'grad_norm': 10.909239768981934, 'learning_rate': 3.4227129337539433e-05, 'epoch': 0.95}


                                                 
 33%|███▎      | 317/951 [05:06<09:53,  1.07it/s]

{'eval_loss': 1.061065673828125, 'eval_runtime': 29.6714, 'eval_samples_per_second': 108.488, 'eval_steps_per_second': 3.404, 'epoch': 1.0}


 37%|███▋      | 350/951 [05:38<09:34,  1.05it/s]  

{'loss': 1.1174, 'grad_norm': 9.251436233520508, 'learning_rate': 3.159831756046267e-05, 'epoch': 1.1}


 42%|████▏     | 400/951 [06:26<08:34,  1.07it/s]

{'loss': 1.0215, 'grad_norm': 7.94811487197876, 'learning_rate': 2.8969505783385907e-05, 'epoch': 1.26}


 47%|████▋     | 450/951 [07:13<07:47,  1.07it/s]

{'loss': 0.9878, 'grad_norm': 11.15049934387207, 'learning_rate': 2.6340694006309152e-05, 'epoch': 1.42}


 53%|█████▎    | 500/951 [08:00<07:11,  1.05it/s]

{'loss': 0.9804, 'grad_norm': 8.740620613098145, 'learning_rate': 2.3711882229232387e-05, 'epoch': 1.58}


 58%|█████▊    | 550/951 [08:48<06:22,  1.05it/s]

{'loss': 0.9498, 'grad_norm': 16.194324493408203, 'learning_rate': 2.1083070452155626e-05, 'epoch': 1.74}


 63%|██████▎   | 600/951 [09:36<05:36,  1.04it/s]

{'loss': 0.9515, 'grad_norm': 13.047072410583496, 'learning_rate': 1.8454258675078864e-05, 'epoch': 1.89}


                                                 
 67%|██████▋   | 634/951 [10:36<04:55,  1.07it/s]

{'eval_loss': 1.1030303239822388, 'eval_runtime': 28.4941, 'eval_samples_per_second': 112.971, 'eval_steps_per_second': 3.545, 'epoch': 2.0}


 68%|██████▊   | 650/951 [10:52<04:48,  1.04it/s]

{'loss': 0.871, 'grad_norm': 7.3335700035095215, 'learning_rate': 1.5825446898002103e-05, 'epoch': 2.05}


 74%|███████▎  | 700/951 [11:38<03:55,  1.06it/s]

{'loss': 0.6473, 'grad_norm': 8.387836456298828, 'learning_rate': 1.3196635120925343e-05, 'epoch': 2.21}


 79%|███████▉  | 750/951 [12:25<03:02,  1.10it/s]

{'loss': 0.6918, 'grad_norm': 7.074014663696289, 'learning_rate': 1.056782334384858e-05, 'epoch': 2.37}


 84%|████████▍ | 800/951 [13:10<02:20,  1.07it/s]

{'loss': 0.6354, 'grad_norm': 9.357627868652344, 'learning_rate': 7.93901156677182e-06, 'epoch': 2.52}


 89%|████████▉ | 850/951 [13:56<01:30,  1.12it/s]

{'loss': 0.6249, 'grad_norm': 10.509344100952148, 'learning_rate': 5.310199789695059e-06, 'epoch': 2.68}


 95%|█████████▍| 900/951 [14:42<00:46,  1.09it/s]

{'loss': 0.6547, 'grad_norm': 5.342777729034424, 'learning_rate': 2.6813880126182968e-06, 'epoch': 2.84}


100%|█████████▉| 950/951 [15:28<00:00,  1.10it/s]

{'loss': 0.6423, 'grad_norm': 10.86176586151123, 'learning_rate': 5.257623554153523e-08, 'epoch': 3.0}


                                                 
100%|██████████| 951/951 [15:58<00:00,  1.12it/s]

{'eval_loss': 1.1907379627227783, 'eval_runtime': 28.5134, 'eval_samples_per_second': 112.894, 'eval_steps_per_second': 3.542, 'epoch': 3.0}


100%|██████████| 951/951 [15:59<00:00,  1.01s/it]

{'train_runtime': 959.953, 'train_samples_per_second': 31.695, 'train_steps_per_second': 0.991, 'train_loss': 1.1405312878977738, 'epoch': 3.0}





TrainOutput(global_step=951, training_loss=1.1405312878977738, metrics={'train_runtime': 959.953, 'train_samples_per_second': 31.695, 'train_steps_per_second': 0.991, 'total_flos': 5962661340337152.0, 'train_loss': 1.1405312878977738, 'epoch': 3.0})

## Step8 模型预测

In [23]:
from transformers import pipeline

pipe = pipeline("question-answering", model=model, tokenizer=tokenizer, device=0)

In [24]:
pipe(question="小明在哪里上班？", context="小明在北京上班。")

{'score': 0.7531976699829102, 'start': 3, 'end': 5, 'answer': '北京'}

: 