In [1]:
import numpy as np
import evaluate
import random
import pandas as pd
import datasets
import subprocess
import collections
from tqdm.auto import tqdm
from datasets import ClassLabel, Sequence,load_dataset,load_metric
from IPython.display import display,HTML
from transformers import AutoTokenizer,AutoModelForQuestionAnswering,TrainingArguments,Trainer,default_data_collator, Trainer

2024-03-01 15:20:08.106003: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-01 15:20:08.210013: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-03-01 15:20:08.802601: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### 任务二：QA任务

1. 加载数据

In [2]:
# Dataset switch: True selects the "squad_v2" dataset and metric, False selects "squad".
# The flag also enables the empty-answer branch in the post-processing function.
squad_v2 = False

In [3]:
# Load the dataset variant selected by `squad_v2`.
# NOTE(review): this rebinds the name `datasets`, shadowing the `datasets` module imported in the
# first cell; from here on the name refers to the loaded dataset, not the library.
datasets = load_dataset("squad_v2" if squad_v2 else "squad")

In [4]:
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

训练集（train）约 8.8 万条样本；验证集（validation）约 1.1 万条样本。

In [5]:
# 可视化数据
def show_random_elements(dataset, num_examples=2):
    """Render ``num_examples`` distinct, randomly chosen rows of ``dataset`` as an HTML table.

    Columns whose feature type is a ``ClassLabel`` (or a ``Sequence`` of
    ``ClassLabel``) are shown with their label names instead of integer ids.

    Args:
        dataset: Object supporting ``len()``, indexing with a list of row
            indices, and a ``features`` mapping of column name to feature type.
        num_examples: Number of distinct rows to display.

    Raises:
        AssertionError: If ``num_examples`` exceeds the number of rows.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Draw without replacement in one call instead of re-rolling on collisions.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        # Replace integer class ids with their human-readable names.
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(datasets["train"])

Unnamed: 0,id,title,context,question,answers
0,56d130df17492d1400aabbc0,IPod,The iPod has also been credited with accelerating shifts within the music industry. The iPod's popularization of digital music storage allows users to abandon listening to entire albums and instead be able to choose specific singles which hastened the end of the Album Era in popular music.,"The ease of collecting singles with the iPod and iTunes is credited with ending what ""era"" in pop music?","{'text': ['the Album Era'], 'answer_start': [259]}"
1,57342937d058e614000b6a66,Portugal,"Portuguese cuisine is diverse. The Portuguese consume a lot of dry cod (bacalhau in Portuguese), for which there are hundreds of recipes. There are more than enough bacalhau dishes for each day of the year. Two other popular fish recipes are grilled sardines and caldeirada, a potato-based stew that can be made from several types of fish. Typical Portuguese meat recipes, that may be made out of beef, pork, lamb, or chicken, include cozido à portuguesa, feijoada, frango de churrasco, leitão (roast suckling pig) and carne de porco à alentejana. A very popular northern dish is the arroz de sarrabulho (rice stewed in pigs blood) or the arroz de cabidela (rice and chickens meat stewed in chickens blood).",What is caldeirada?,"{'text': ['a potato-based stew that can be made from several types of fish'], 'answer_start': [275]}"


2.数据预处理

In [7]:
# Base checkpoint; the tokenizer and the question-answering model are both loaded from this name.
model_checkpoint = "distilbert-base-uncased"

In [8]:
# Tokenizer matching the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Passed to the tokenizer as `max_length` for each feature (question + context).
max_length = 384 
# Passed to the tokenizer as `stride`, so consecutive windows of a long context overlap.
doc_stride = 128 
# True when the tokenizer pads on the right; the preprocessing functions then pass the
# question first and the context second (and the reverse otherwise).
pad_on_right = tokenizer.padding_side == "right"

In [9]:
# 数据预处理
def prepare_train_features(examples):
    """Tokenize a batch of examples into training features labelled with answer token positions.

    Relies on the notebook globals ``tokenizer``, ``max_length``, ``doc_stride``
    and ``pad_on_right``.

    Args:
        examples: Batch with "question", "context" and "answers" columns; each
            answers entry has "text" and "answer_start" lists, of which only the
            first element is used. The "question" column is left-stripped in place.

    Returns:
        The tokenizer output with "overflow_to_sample_mapping" and
        "offset_mapping" removed and "start_positions" / "end_positions" added,
        one entry per feature. Features without an answer, or whose window does
        not contain the answer, point both positions at the CLS token.
    """
    # Some questions have a lot of leading whitespace, which is useless and can make the
    # truncation of the context fail (the tokenized question would take up too much room),
    # so the leading whitespace is removed.
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize with truncation and padding, but keep the overflowing part using a stride.
    # A long context therefore yields several features, each overlapping the previous one.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # One example may produce several features, so a feature -> source example mapping is needed.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mapping gives, for each token, its character span in the original text;
    # it is used below to compute the start and end token positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Label every feature.
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # The index of the CLS token is used as the label for impossible answers.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Sequence ids tell which tokens belong to the question and which to the context.
        sequence_ids = tokenized_examples.sequence_ids(i)

        # Index of the example this feature was cut from.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answer is given, label the feature with cls_index.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start and end character index of the (first) answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # First token of the context within this feature.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # Last token of the context within this feature (skips padding / special tokens).
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # If the answer lies outside this window, label the feature with the CLS index.
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move token_start_index and token_end_index to the two ends of the answer.
                # Note: if the answer is the last word (edge case), the scan may run past the last offset,
                # hence the bounds check; the loop overshoots by one, which the "- 1" below undoes.
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                # Same overshoot-by-one pattern from the right-hand side.
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [10]:
# Apply the training preprocessing to every split in batches; the original columns are
# removed so that only the columns returned by prepare_train_features remain.
tokenized_datasets = datasets.map(prepare_train_features,
                                  batched=True,
                                  remove_columns=datasets["train"].column_names)

In [11]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 88524
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 10784
    })
})

新增答案在上下文中起始和结尾位置坐标。

In [12]:
show_random_elements(tokenized_datasets["train"])

Unnamed: 0,input_ids,attention_mask,start_positions,end_positions
0,"[101, 2129, 2079, 11572, 2817, 1037, 5776, 1005, 1055, 11311, 2291, 11595, 1999, 5057, 8146, 1029, 102, 3188, 2000, 1996, 8259, 1997, 15403, 2013, 1996, 3802, 24335, 10091, 7117, 10047, 23041, 2483, 1010, 2029, 2003, 3763, 2005, 1000, 11819, 1000, 1025, 2220, 11572, 7356, 11595, 2008, 2052, 2101, 2022, 10003, 2004, 6827, 6177, 1997, 1996, 11311, 2291, 1012, 1996, 2590, 1048, 24335, 8458, 9314, 11595, 1997, 1996, 11311, 2291, 2024, 1996, 15177, 7606, 1998, 5923, 24960, 1010, 1998, 2708, 1048, 24335, 21890, 4588, 14095, 2107, 2004, 11867, 24129, 1010, 6197, 12146, 1010, 1048, 24335, 8458, 6470, 1010, 1048, 24335, 8458, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]",146,157
1,"[101, 2054, 2003, 5298, 1005, 1055, 2087, 20151, 2103, 1029, 102, 5298, 1006, 1045, 1013, 100, 1013, 1007, 1006, 13796, 1024, 100, 1010, 9092, 21369, 1007, 2003, 1037, 2110, 2284, 1999, 1996, 8252, 2142, 2163, 1012, 5298, 2003, 1996, 21460, 2922, 1998, 1996, 5550, 2087, 20151, 1997, 1996, 2753, 2142, 2163, 1012, 5298, 2003, 11356, 2011, 5612, 1998, 3448, 2000, 1996, 2167, 1010, 2167, 3792, 2000, 1996, 2264, 1010, 4108, 1010, 6041, 1010, 1998, 5900, 2000, 1996, 2148, 1010, 1998, 6751, 1998, 5284, 2000, 1996, 2225, 1012, 1996, 19682, 4020, 16083, 1996, 2789, 2112, 1997, 1996, 2110, 1010, 1998, 1996, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]",131,131


3.模型训练

In [13]:
# Model to fine-tune: the base checkpoint loaded with a question-answering head.
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

# Training hyperparameters
batch_size=64
# Used as the Trainer output directory; a later cell reloads the fine-tuned model from this path.
model_dir = f"models/{model_checkpoint}-finetuned-squad"

args = TrainingArguments(
    output_dir=model_dir,
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Data collator: assembles individual features into batches for training.
data_collator = default_data_collator

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Instantiate the trainer with the model, the training arguments and the tokenized splits.
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Start training
trainer.train()
# Write the final model directly into model_dir: intermediate checkpoints go to
# checkpoint-* subfolders, while the evaluation section reloads from model_dir itself.
trainer.save_model(model_dir)

4. 模型评估

对验证集进行预处理：保留每个特征的 offset_mapping 以及它所属样本的 example_id，以便后处理时把模型预测的 token 位置映射回原文中的答案文本。

In [15]:
def prepare_validation_features(examples):
    """Tokenize a batch of examples into evaluation features.

    Relies on the notebook globals ``tokenizer``, ``max_length``, ``doc_stride``
    and ``pad_on_right``.

    Args:
        examples: Batch with "id", "question" and "context" columns. The
            "question" column is left-stripped in place.

    Returns:
        The tokenizer output without "overflow_to_sample_mapping", with an added
        "example_id" entry per feature and with "offset_mapping" entries set to
        None for every token that is not part of the context.
    """
    # Strip leading whitespace from every question.
    examples["question"] = [question.lstrip() for question in examples["question"]]

    # Argument order and truncation side both depend on the padding side.
    if pad_on_right:
        first_column, second_column, truncation_mode = "question", "context", "only_second"
    else:
        first_column, second_column, truncation_mode = "context", "question", "only_first"

    # Truncate / pad to max_length, keeping the overflow as additional strided features.
    tokenized_examples = tokenizer(
        examples[first_column],
        examples[second_column],
        truncation=truncation_mode,
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Feature -> source example mapping (one example can yield several features).
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    tokenized_examples["example_id"] = []

    feature_count = len(tokenized_examples["input_ids"])
    for feature_idx in range(feature_count):
        # Sequence ids tell which tokens are question and which are context.
        sequence_ids = tokenized_examples.sequence_ids(feature_idx)
        context_index = 1 if pad_on_right else 0

        # Record which example this feature was cut from.
        source_idx = sample_mapping[feature_idx]
        tokenized_examples["example_id"].append(examples["id"][source_idx])

        # Blank out offsets of non-context tokens so that membership in the
        # context can later be checked with a simple `is None` test.
        masked_offsets = []
        for token_idx, span in enumerate(tokenized_examples["offset_mapping"][feature_idx]):
            if sequence_ids[token_idx] == context_index:
                masked_offsets.append(span)
            else:
                masked_offsets.append(None)
        tokenized_examples["offset_mapping"][feature_idx] = masked_offsets

    return tokenized_examples

In [16]:
# Build the evaluation features for the validation split; the original columns are
# removed so that only the columns returned by prepare_validation_features remain.
validation_features = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names
)

In [17]:
validation_features

Dataset({
    features: ['input_ids', 'attention_mask', 'offset_mapping', 'example_id'],
    num_rows: 10784
})

In [18]:
show_random_elements(validation_features)

Unnamed: 0,input_ids,attention_mask,offset_mapping,example_id
0,"[101, 2054, 3128, 2708, 7837, 2000, 2413, 3481, 1998, 5561, 16400, 1029, 102, 1999, 1996, 3500, 1997, 23810, 1010, 2703, 16400, 2139, 2474, 15451, 9077, 2001, 2445, 3094, 1997, 1037, 1016, 1010, 2199, 1011, 2158, 2486, 1997, 16017, 2015, 2139, 2474, 3884, 1998, 6505, 1012, 2010, 4449, 2020, 2000, 4047, 1996, 2332, 1005, 1055, 2455, 1999, 1996, 4058, 3028, 2013, 1996, 2329, 1012, 16400, 2628, 1996, 2799, 2008, 8292, 10626, 2239, 2018, 17715, 2041, 2176, 2086, 3041, 1010, 2021, 2073, 8292, 10626, 2239, 2018, 3132, 1996, 2501, 1997, 2413, 4447, 2000, 1996, 8940, 1997, 2599, 7766, 1010, 16400, 3833, 1998, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[None, None, None, None, None, None, None, None, None, None, None, None, None, [0, 2], [3, 6], [7, 13], [14, 16], [17, 21], [21, 22], [23, 27], [28, 33], [34, 36], [37, 39], [40, 43], [43, 46], [47, 50], [51, 56], [57, 64], [65, 67], [68, 69], [70, 71], [71, 72], [72, 75], [75, 76], [76, 79], [80, 85], [86, 88], [89, 95], [95, 96], [97, 99], [100, 102], [103, 109], [110, 113], [114, 121], [121, 122], [123, 126], [127, 133], [134, 138], [139, 141], [142, 149], [150, 153], [154, 158], [158, 159], [159, 160], [161, 165], [166, 168], [169, 172], [173, 177], [178, 184], [185, 189], [190, 193], [194, 201], [201, 202], [203, 208], [209, 217], [218, 221], [222, 227], [228, 232], [233, 235], [235, 238], [238, 240], [241, 244], [245, 251], [252, 255], [256, 260], [261, 266], [267, 274], [274, 275], [276, 279], [280, 285], [286, 288], [288, 291], [291, 293], [294, 297], [298, 305], [306, 309], [310, 316], [317, 319], [320, 326], [327, 333], [334, 336], [337, 340], [341, 347], [348, 350], [351, 355], [356, 362], [362, 363], [364, 369], [370, 381], [382, 385], 
...]",5733ea04d058e614000b6598
1,"[101, 2429, 2000, 7312, 1010, 2065, 1060, 1998, 1061, 2064, 2022, 13332, 2011, 1996, 2168, 9896, 2059, 1060, 10438, 2054, 3853, 1999, 3276, 2000, 1061, 1029, 102, 2116, 11619, 4280, 2024, 4225, 2478, 1996, 4145, 1997, 1037, 7312, 1012, 1037, 7312, 2003, 1037, 8651, 1997, 2028, 3291, 2046, 2178, 3291, 1012, 2009, 19566, 1996, 11900, 9366, 1997, 1037, 3291, 2108, 2012, 2560, 2004, 3697, 2004, 2178, 3291, 1012, 2005, 6013, 1010, 2065, 1037, 3291, 1060, 2064, 2022, 13332, 2478, 2019, 9896, 2005, 1061, 1010, 1060, 2003, 2053, 2062, 3697, 2084, 1061, 1010, 1998, 2057, 2360, 2008, 1060, 13416, 2000, 1061, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]","[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, [0, 4], [5, 15], [16, 23], [24, 27], [28, 35], [36, 41], [42, 45], [46, 53], [54, 56], [57, 58], [59, 68], [68, 69], [70, 71], [72, 81], [82, 84], [85, 86], [87, 101], [102, 104], [105, 108], [109, 116], [117, 121], [122, 129], [130, 137], [137, 138], [139, 141], [142, 150], [151, 154], [155, 163], [164, 170], [171, 173], [174, 175], [176, 183], [184, 189], [190, 192], [193, 198], [199, 201], [202, 211], [212, 214], [215, 222], [223, 230], [230, 231], [232, 235], [236, 244], [244, 245], [246, 248], [249, 250], [251, 258], [259, 260], [261, 264], [265, 267], [268, 274], [275, 280], [281, 283], [284, 293], [294, 297], [298, 299], [299, 300], [301, 302], [303, 305], [306, 308], [309, 313], [314, 323], [324, 328], [329, 330], [330, 331], [332, 335], [336, 338], [339, 342], [343, 347], [348, 349], [350, 357], [358, 360], [361, 362], ...]",56e1c9bfe3433e1400423194


In [19]:
# Predict on the validation features with the original (not yet fine-tuned) model.
# NOTE(review): `trainer` is the same object that the training cell trains, so this is only a
# baseline when that cell has not been run; on a full top-to-bottom run it predicts with the
# fine-tuned weights instead.
raw_predictions = trainer.predict(validation_features)

In [20]:
len(raw_predictions)

3

将模型输出的起始位置和结束位置的 logits 组合成候选答案：每个候选答案的得分为起始位置 logit 与结束位置 logit 之和；对每个特征分别取得分最高的 20 个起始位置和 20 个结束位置进行组合，过滤掉不合法的组合后，选择得分最高的候选作为最终答案。

In [21]:
# Intended value for the `n_best_size` argument of postprocess_qa_predictions:
# how many of the highest-scoring start and end positions are considered per feature.
n_best_size = 20

In [22]:
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    """Turn per-feature start/end logits into one predicted answer string per example.

    Relies on the notebook globals ``tokenizer`` (for the CLS token id) and
    ``squad_v2`` (to decide whether an empty answer may be returned).

    Args:
        examples: Original examples; must expose ``examples["id"]`` and yield
            rows with "id" and "context" when iterated.
        features: Tokenized features; each row has "example_id",
            "offset_mapping" (None for non-context tokens) and "input_ids".
        raw_predictions: Pair ``(all_start_logits, all_end_logits)``, each
            indexable by feature index.
        n_best_size: Number of top start and top end positions combined per feature.
        max_answer_length: Maximum answer length, counted in tokens.

    Returns:
        ``collections.OrderedDict`` mapping example id to the predicted answer
        text (possibly the empty string).
    """
    all_start_logits, all_end_logits = raw_predictions
    # Map each example index to the indices of the features cut from it.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionary to fill and return.
    predictions = collections.OrderedDict()

    # Progress message.
    print(f"正在后处理 {len(examples)} 个示例的预测，这些预测分散在 {len(features)} 个特征中。")

    for example_index, example in enumerate(tqdm(examples)):
        # Indices of the features associated with the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None  # Only used when squad_v2 is True.
        valid_answers = []

        context = example["context"]
        for feature_index in feature_indices:
            # Model predictions for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # Fetch the feature row once and reuse it for both lookups.
            feature = features[feature_index]
            # Maps token positions back to character spans in the original context.
            offset_mapping = feature["offset_mapping"]

            # Track the *minimum* null (CLS) score over all features of the example.
            # The comparison was previously inverted and kept the maximum instead.
            cls_index = feature["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or feature_null_score < min_null_score:
                min_null_score = feature_null_score

            # Combine the `n_best_size` highest start logits with the `n_best_size` highest end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip positions that are out of range or that fall outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Skip spans of negative length or longer than max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )

        if valid_answers:
            # max() returns the first highest-scoring candidate, the same element
            # a stable descending sort would put first, without sorting the list.
            best_answer = max(valid_answers, key=lambda x: x["score"])
        else:
            # In the rare case of no valid candidate, use a dummy prediction to avoid failing.
            best_answer = {"text": "", "score": 0.0}

        # Final answer: the best span, or the empty answer (squad_v2 only).
        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        else:
            # An example without any feature has no null score; keep the best answer then
            # instead of comparing against None.
            keep_span = min_null_score is None or best_answer["score"] > min_null_score
            predictions[example["id"]] = best_answer["text"] if keep_span else ""

    return predictions

In [23]:
# Pass the notebook-level n_best_size explicitly so that changing it actually takes effect.
final_predictions = postprocess_qa_predictions(
    datasets["validation"],
    validation_features,
    raw_predictions.predictions,
    n_best_size=n_best_size,
)

正在后处理 10570 个示例的预测，这些预测分散在 10784 个特征中。


  0%|          | 0/10570 [00:00<?, ?it/s]

In [24]:
# `datasets.load_metric` is deprecated; load the same metric through the `evaluate` library,
# which is already imported in the first cell.
metric = evaluate.load("squad_v2" if squad_v2 else "squad")

  metric = load_metric("squad_v2" if squad_v2 else "squad")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [25]:
# Gold answers in the format expected by the metric.
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
# Predictions in the format expected by the metric; the squad_v2 metric additionally
# requires a "no_answer_probability" field.
formatted_predictions = [
    {"id": k, "prediction_text": v, "no_answer_probability": 0.0} if squad_v2 else {"id": k, "prediction_text": v}
    for k, v in final_predictions.items()
]
metric.compute(predictions=formatted_predictions, references=references)

{'exact_match': 0.1892147587511826, 'f1': 7.183507430240401}

对原始模型进行评估，得到最终f1得分为7.1835。（此时问答输出层 qa_outputs 是随机初始化的，尚未经过训练，因此得分很低。）

导入全量样本集微调的模型进行评估

In [29]:
# Reload the fine-tuned model from the training output directory; this requires a saved
# model (config and weights) directly under `model_dir`.
trained_model = AutoModelForQuestionAnswering.from_pretrained(model_dir)

In [30]:
# Wrap the reloaded model in a Trainer, reusing the same arguments, datasets and collator,
# so it can be used for prediction below.
trained_trainer = Trainer(
    trained_model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [31]:
# Predict on the validation features with the fine-tuned model.
trained_predictions = trained_trainer.predict(validation_features)

In [32]:
# Pass the notebook-level n_best_size explicitly so that changing it actually takes effect.
final_trained_predictions = postprocess_qa_predictions(
    datasets["validation"],
    validation_features,
    trained_predictions.predictions,
    n_best_size=n_best_size,
)

正在后处理 10570 个示例的预测，这些预测分散在 10784 个特征中。


  0%|          | 0/10570 [00:00<?, ?it/s]

In [34]:
# Gold answers in the format expected by the metric.
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
# Fine-tuned predictions in the format expected by the metric; the squad_v2 metric
# additionally requires a "no_answer_probability" field.
formatted_predictions = [
    {"id": k, "prediction_text": v, "no_answer_probability": 0.0} if squad_v2 else {"id": k, "prediction_text": v}
    for k, v in final_trained_predictions.items()
]
metric.compute(predictions=formatted_predictions, references=references)

{'exact_match': 74.99526963103122, 'f1': 83.81339494622097}

使用原始模型 distilbert-base-uncased 对 squad 数据集进行预测评估，模型 f1 得分为 7.1835；使用 squad 数据集的全量训练数据对原始模型进行微调后，模型 f1 得分提升到 83.8134。

In [35]:
# Evaluation loss on the tokenized validation split.
# NOTE(review): this calls `trainer` (built around `model`), not `trained_trainer` (built around
# the reloaded `trained_model`). If the training cell was not run in this session, the loss
# reported here belongs to the un-fine-tuned model - confirm which model is meant.
trainer.evaluate(tokenized_datasets["validation"])

{'eval_loss': 5.987925052642822,
 'eval_runtime': 108.3884,
 'eval_samples_per_second': 99.494,
 'eval_steps_per_second': 1.559}