# Loading the dataset

In [1]:
# Select which GPU to use
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

In [2]:

# This flag is the difference between SQUAD v1 or 2 (if you're using another dataset, it indicates if impossible
# answers are allowed or not).
squad_v2 = False
model_checkpoint = "bert-base-uncased"
batch_size = 16

To use your own dataset, follow the format described in the [Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files).
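
As a rough illustration of that format, the sketch below builds one SQuAD-style record by hand and checks that `answer_start` really points at the answer text inside `context`. The field names mirror the `squad` features used in this notebook; the record contents and the helper `check_record` are made up for illustration.

```python
# Hypothetical record laid out like the SQuAD examples used in this notebook.
context = "The tokenizer splits text into tokens. BERT uses WordPiece."
record = {
    "id": "example-0001",          # made-up identifier
    "title": "Toy_Article",        # made-up title
    "context": context,
    "question": "What does BERT use?",
    "answers": {
        "text": ["WordPiece"],
        # Character offset of the answer inside the context.
        "answer_start": [context.index("WordPiece")],
    },
}


def check_record(rec):
    """Return True if every answer text is found at its stated character offset."""
    answers = rec["answers"]
    return all(
        rec["context"][start:start + len(text)] == text
        for text, start in zip(answers["text"], answers["answer_start"])
    )


print(check_record(record))
```

Running a check like this over a custom dataset before training is a cheap way to catch offsets that drifted during data cleaning.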

In [3]:
from datasets import load_dataset, load_metric

datasets = load_dataset("squad_v2" if squad_v2 else "squad")

datasets

Reusing dataset squad (/raid/wuxiaojun/.cache/huggingface/datasets/squad/plain_text/1.0.0/6b6c4172d0119c74515f44ea0b8262efe4897f2ddb6613e5e915840fdc309c16)


DatasetDict({
    train: Dataset({
        features: ['answers', 'context', 'id', 'question', 'title'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['answers', 'context', 'id', 'question', 'title'],
        num_rows: 10570
    })
})

In [4]:
datasets["train"][0]

{'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]},
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'id': '5733be284776f41900661182',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'title': 'University_of_Notre_Dame'}

In [5]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))
    
show_random_elements(datasets["train"])   

Unnamed: 0,answers,context,id,question,title
0,"{'text': ['China'], 'answer_start': [110]}","Early in 1953, the French asked Eisenhower for help in French Indochina against the Communists, supplied from China, who were fighting the First Indochina War. Eisenhower sent Lt. General John W. ""Iron Mike"" O'Daniel to Vietnam to study and assess the French forces there. Chief of Staff Matthew Ridgway dissuaded the President from intervening by presenting a comprehensive estimate of the massive military deployment that would be necessary. Eisenhower stated prophetically that ""this war would absorb our troops by divisions.""",573272fab9d445190005eb2e,Who was providing supplies to the Vietnamese communists fighting against France?,Dwight_D._Eisenhower
1,"{'text': ['511'], 'answer_start': [937]}","San Diego is served by the San Diego Trolley light rail system, by the SDMTS bus system, and by Coaster and Amtrak Pacific Surfliner commuter rail; northern San Diego county is also served by the Sprinter light rail line. The Trolley primarily serves downtown and surrounding urban communities, Mission Valley, east county, and coastal south bay. A planned Mid-Coast extension of the Trolley will operate from Old Town to University City and the University of California, San Diego along the I-5 Freeway, with planned operation by 2018. The Amtrak and Coaster trains currently run along the coastline and connect San Diego with Los Angeles, Orange County, Riverside, San Bernardino, and Ventura via Metrolink and the Pacific Surfliner. There are two Amtrak stations in San Diego, in Old Town and the Santa Fe Depot downtown. San Diego transit information about public transportation and commuting is available on the Web and by dialing ""511"" from any phone in the area.",5730349204bcaa1900d7736e,What number can you dial from any phone for public transportation information in San Diego?,San_Diego
2,"{'text': ['1920'], 'answer_start': [24]}","The term was created in 1920 by Hans Winkler, professor of botany at the University of Hamburg, Germany. The Oxford Dictionary suggests the name to be a blend of the words gene and chromosome. However, see omics for a more thorough discussion. A few related -ome words already existed—such as biome, rhizome, forming a vocabulary into which genome fits systematically.",56dc54a514d3a41400c267c6,In what year was the word genome first created?,Genome
3,"{'text': ['ACID-compliant transaction processing'], 'answer_start': [512]}","XML databases are a type of structured document-oriented database that allows querying based on XML document attributes. XML databases are mostly used in enterprise database management, where XML is being used as the machine-to-machine data interoperability standard. XML database management systems include commercial software MarkLogic and Oracle Berkeley DB XML, and a free use software Clusterpoint Distributed XML/JSON Database. All are enterprise software database platforms and support industry standard ACID-compliant transaction processing with strong database consistency characteristics and high level of database security.",572f78f5b2c2fd1400568164,What type of processing is used in enterprise database software?,Database
4,"{'text': ['automobiles and high-powered rifles'], 'answer_start': [355]}","The Arabian oryx, a species of large antelope, once inhabited much of the desert areas of the Middle East. However, the species' striking appearance made it (along with the closely related scimitar-horned oryx and addax) a popular quarry for sport hunters, especially foreign executives of oil companies working in the region.[citation needed] The use of automobiles and high-powered rifles destroyed their only advantage: speed, and they became extinct in the wild exclusively due to sport hunting in 1972. The scimitar-horned oryx followed suit, while the addax became critically endangered. However, the Arabian oryx has now made a comeback and been upgraded from “extinct in the wild” to “vulnerable” due to conservation efforts like captive breeding",5736342b506b47140023658e,What destroyed the Arabian oryx only advantage of speed.,Hunting


# Preprocessing the dataset

In [6]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# The preprocessing below relies on offset mappings and sequence ids, which only fast tokenizers provide.
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [7]:
print(tokenizer("What is your name?", "My name is Sylvain."))
tokenizer(["What is your name?", "My name is Sylvain."])

{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


{'input_ids': [[101, 2054, 2003, 2115, 2171, 1029, 102], [101, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]}
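
The two outputs above differ because the first call encodes one question/context pair, while the second encodes a batch of two independent sentences. The sketch below reproduces only the `token_type_ids` and `attention_mask` layout of the pair case, as read off the first output: `[CLS]`, the first sequence and its `[SEP]` get type 0, the second sequence and the final `[SEP]` get type 1. The helper `pair_layout` is hypothetical and does no real tokenization.

```python
def pair_layout(num_first_tokens, num_second_tokens):
    """Mimic the BERT pair layout: [CLS] first ... [SEP] second ... [SEP]."""
    first_segment = 1 + num_first_tokens + 1   # [CLS] + tokens + [SEP]
    second_segment = num_second_tokens + 1     # tokens + [SEP]
    token_type_ids = [0] * first_segment + [1] * second_segment
    attention_mask = [1] * (first_segment + second_segment)  # no padding here
    return token_type_ids, attention_mask


# "What is your name?" gives 5 tokens and "My name is Sylvain." gives 7 tokens
# in the output above (the name is split into several word pieces).
token_type_ids, attention_mask = pair_layout(5, 7)
print(token_type_ids)
print(len(attention_mask))
```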

In [8]:

max_length = 384 # The maximum length of a feature (question and context)
doc_stride = 128 # The allowed overlap between two parts of the context when it has to be split.

# Find the first training example whose tokenized question + context exceeds max_length.
for i, example in enumerate(datasets["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > max_length:
        break
example = datasets["train"][i]

print(len(tokenizer(example["question"], example["context"])["input_ids"]))

# Plain truncation can lose information, possibly including the answer itself
print(len(tokenizer(example["question"], example["context"], max_length=max_length, truncation="only_second")["input_ids"]))

396
384


In [9]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second", # truncate only the context, never the question
    return_overflowing_tokens=True, # return the truncated-off tokens as additional features
    stride=doc_stride
)

print([len(x) for x in tokenized_example["input_ids"]]) # one length per feature: here two features, of 384 and 157 tokens

for x in tokenized_example["input_ids"][:2]:
    print(tokenizer.decode(x))

[384, 157]
[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 ncaa tournaments. former player austin carr holds the record for most points scored in a single game of the tournament with 61. although the team has never won the ncaa tournament, they were named by the helms athletic foundation as national champions twice. the team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending ucla's record 88 - game winning streak in 1974. the team has beaten an additional eight number - one teams, and those nine wins rank second, to ucla's 10, all - time in wins against the top team. the team plays in newly renovated purcell pavilion ( within the edmund p. joyce center ), which reopened for the beginning of the 2009 – 2010 season. the team is coached by mike brey, who, as of the 2014 – 15 season, his fiftee
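
The lengths `[384, 157]` can be reproduced with plain arithmetic. The sketch below is a simplified model of the windowing, not the tokenizer's actual implementation: it assumes that `stride` is the number of context tokens shared by consecutive features and that each feature is laid out as `[CLS] question [SEP] context-chunk [SEP]`. The question here tokenizes to 14 tokens, so the untruncated length of 396 leaves 379 context tokens. The function name `window_lengths` is hypothetical.

```python
def window_lengths(num_context_tokens, max_length, stride, num_question_tokens):
    """Return the length of each feature produced by a sliding window over the context."""
    overhead = num_question_tokens + 3      # [CLS], [SEP] after question, final [SEP]
    capacity = max_length - overhead        # context tokens that fit in one feature
    if capacity <= stride:
        raise ValueError("stride must be smaller than the per-feature context capacity")
    lengths = []
    start = 0
    while True:
        end = min(start + capacity, num_context_tokens)
        lengths.append(end - start + overhead)
        if end == num_context_tokens:
            break
        start += capacity - stride          # step back by `stride` tokens to overlap
    return lengths


# 396 total tokens = 14 question tokens + 3 special tokens + 379 context tokens.
feature_lengths = window_lengths(379, max_length=384, stride=128, num_question_tokens=14)
print(feature_lengths)
```

Under these assumptions the second feature holds the 12 leftover context tokens plus the 128 overlapping ones, which is why it is much longer than the leftover alone.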

In [10]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True, # character offsets of every token, needed to locate the answer
    stride=doc_stride
)

print(len(tokenized_example["offset_mapping"][0]))
print(len(tokenized_example["offset_mapping"][1]))
print(tokenized_example["offset_mapping"][0][:100]) # (start_char, end_char) of each token
# The very first token ([CLS]) has (0, 0) because it doesn't correspond to any part of the question/context,
# then the second token corresponds to characters 0 to 3 of the question.

384
157
[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61), (0, 0), (0, 3), (4, 7), (7, 8), (8, 9), (10, 20), (21, 25), (26, 29), (30, 34), (35, 36), (36, 37), (37, 40), (41, 45), (45, 46), (47, 50), (51, 53), (54, 58), (59, 61), (62, 69), (70, 73), (74, 78), (79, 86), (87, 91), (92, 96), (96, 97), (98, 101), (102, 106), (107, 115), (116, 118), (119, 121), (122, 126), (127, 138), (138, 139), (140, 146), (147, 153), (154, 160), (161, 165), (166, 171), (172, 175), (176, 182), (183, 186), (187, 191), (192, 198), (199, 205), (206, 208), (209, 210), (211, 217), (218, 222), (223, 225), (226, 229), (230, 240), (241, 245), (246, 248), (248, 249), (250, 258), (259, 262), (263, 267), (268, 271), (272, 277), (278, 281), (282, 285), (286, 290), (291, 301), (301, 302), (303, 307), (308, 312), (313, 318), (319, 321), (322, 325), (326, 330), (330, 331), (332, 340), (341, 351), (352, 354), (355, 363), (364, 373

In [11]:
first_token_id = tokenized_example["input_ids"][0][1]
offsets = tokenized_example["offset_mapping"][0][1]
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])

how How
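
To make the idea of an offset mapping concrete, here is a toy tokenizer that records the character span of every token. It is only a sketch: it splits on words and punctuation with a regular expression, whereas the real WordPiece tokenizer also breaks rare words into sub-word pieces. The question string is reconstructed from the decoded output above (its casing beyond the first word is assumed), and `toy_tokenize_with_offsets` is a hypothetical helper.

```python
import re


def toy_tokenize_with_offsets(text):
    """Split into lowercase word/punctuation tokens, keeping (start, end) offsets."""
    return [
        (match.group().lower(), (match.start(), match.end()))
        for match in re.finditer(r"\w+|[^\w\s]", text)
    ]


question = "How many wins does the Notre Dame men's basketball team have?"
for token, (start, end) in toy_tokenize_with_offsets(question)[:4]:
    # Slicing the original text with the offsets recovers the surface form.
    print(token, (start, end), repr(question[start:end]))
```

For this particular question the toy offsets coincide with the first entries of the real offset mapping printed earlier, because none of its words is split into sub-word pieces.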


In [12]:
# Tell question tokens (0) apart from context tokens (1); special tokens are None
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)

[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [13]:
answers = example["answers"]
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1

# End token index of the current span in the text.
token_end_index = len(tokenized_example["input_ids"][0]) - 1
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1

# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
offsets = tokenized_example["offset_mapping"][0]
if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
    # Move the token_start_index and token_end_index to the two ends of the answer.
    # Note: we could go after the last offset if the answer is the last word (edge case).
    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1
    while offsets[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1
    print(start_position, end_position)
else:
    print("The answer is not in this feature.")

23 26


In [14]:
# Check that the located token span decodes back to the answer
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
print(answers["text"][0])

over 1, 600
over 1,600
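
The character-to-token conversion can be tried on toy data without loading a tokenizer. The sketch below is a variant of the loop above rather than a copy of it: it restricts the scan to the context tokens, and returns `None` when the answer is not fully inside the feature. The offsets and sequence ids are hand-written for the pair ("Who?", "The cat sat."), and `char_span_to_token_span` is a hypothetical helper name.

```python
def char_span_to_token_span(offsets, sequence_ids, start_char, end_char, context_id=1):
    """Map a character span in the context to (start_token, end_token), or None."""
    context_tokens = [i for i, s in enumerate(sequence_ids) if s == context_id]
    if not context_tokens:
        return None
    first, last = context_tokens[0], context_tokens[-1]
    # The answer must lie entirely inside this feature's slice of the context.
    if offsets[first][0] > start_char or offsets[last][1] < end_char:
        return None
    start_token = first
    while start_token < last and offsets[start_token + 1][0] <= start_char:
        start_token += 1
    end_token = last
    while end_token > first and offsets[end_token - 1][1] >= end_char:
        end_token -= 1
    return start_token, end_token


# Tokens: [CLS] who ? [SEP] the cat sat . [SEP]
toy_offsets = [(0, 0), (0, 3), (3, 4), (0, 0), (0, 3), (4, 7), (8, 11), (11, 12), (0, 0)]
toy_sequence_ids = [None, 0, 0, None, 1, 1, 1, 1, None]

print(char_span_to_token_span(toy_offsets, toy_sequence_ids, 4, 7))      # answer "cat"
print(char_span_to_token_span(toy_offsets, toy_sequence_ids, 4, 11))     # answer "cat sat"
print(char_span_to_token_span(toy_offsets, toy_sequence_ids, 100, 103))  # not in this feature
```

Bounding both scans by the context range keeps the result on a context token even when the answer is the last word, since special tokens carry the offset `(0, 0)`.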


In [15]:
pad_on_right = tokenizer.padding_side == "right"

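- The analysis above is wrapped into a single `prepare_train_features` function.
- If the answer does not fall inside a given feature (a long example is split into several features, such as the 384-token and 157-token features above), both the start and end positions are set to the index of the `[CLS]` token.
- Alternatively, such features could be discarded when impossible answers are not allowed (`squad_v2 = False`); the function below keeps them and labels them with the `[CLS]` index.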

In [16]:
def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. As a result,
    # one example may yield several features when its context is long, each of those features having a
    # context that slightly overlaps the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [17]:
features = prepare_train_features(datasets['train'][:5])
print(len(features['input_ids']))
print(tokenizer.decode(features['input_ids'][0]))
features

5
[CLS] to whom did the virgin mary allegedly appear in 1858 in lourdes france? [SEP] architecturally, the school has a catholic character. atop the main building's gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

{'input_ids': [[101, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102, 6549, 2135, 1010, 1996, 2082, 2038, 1037, 3234, 2839, 1012, 10234, 1996, 2364, 2311, 1005, 1055, 2751, 8514, 2003, 1037, 3585, 6231, 1997, 1996, 6261, 2984, 1012, 3202, 1999, 2392, 1997, 1996, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5267, 1000, 1012, 2279, 2000, 1996, 2364, 2311, 2003, 1996, 13546, 1997, 1996, 6730, 2540, 1012, 3202, 2369, 1996, 13546, 2003, 1996, 24665, 23052, 1010, 1037, 14042, 2173, 1997, 7083, 1998, 9185, 1012, 2009, 2003, 1037, 15059, 1997, 1996, 24665, 23052, 2012, 10223, 26371, 1010, 2605, 2073, 1996, 6261, 2984, 22353, 2135, 2596, 2000, 3002, 16595, 9648, 4674, 2061, 12083, 9711, 2271, 1999, 8517, 1012, 2012, 1996, 2203, 1997, 1996, 2364, 3298, 1006, 1998, 1999, 1037, 3622, 2240, 2008, 8539, 2083, 1017, 11342, 1998, 1996, 2

In [18]:
# Apply the preprocessing to the whole dataset
tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names)

Loading cached processed dataset at /raid/wuxiaojun/.cache/huggingface/datasets/squad/plain_text/1.0.0/6b6c4172d0119c74515f44ea0b8262efe4897f2ddb6613e5e915840fdc309c16/cache-7e9ed40609d0af0c.arrow
Loading cached processed dataset at /raid/wuxiaojun/.cache/huggingface/datasets/squad/plain_text/1.0.0/6b6c4172d0119c74515f44ea0b8262efe4897f2ddb6613e5e915840fdc309c16/cache-82a69902ea3daed6.arrow


# Fine-tuning the model

In [19]:

from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [20]:
model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    "test-squad",  # output directory for checkpoints and logs
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    weight_decay=0.01
)

In [21]:
from transformers import default_data_collator

data_collator = default_data_collator

In [22]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [23]:
trainer.train()



Epoch,Training Loss,Validation Loss
1,1.0981,1.040891


TrainOutput(global_step=5533, training_loss=1.3255303336508382, metrics={'train_runtime': 772.0603, 'train_samples_per_second': 114.659, 'train_steps_per_second': 7.167, 'total_flos': 2.2209777555757056e+16, 'train_loss': 1.3255303336508382, 'epoch': 1.0})
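
The `global_step=5533` value can be sanity-checked from the other logged numbers. This is a back-of-the-envelope sketch under stated assumptions: one visible GPU, no gradient accumulation, one epoch, and a feature count estimated from runtime times throughput rather than read from the dataset.

```python
import math

# Values copied from the TrainOutput above.
train_runtime = 772.0603            # seconds
train_samples_per_second = 114.659
batch_size = 16                     # per-device batch size set earlier

# Estimated number of training features seen in one epoch.
estimated_features = round(train_runtime * train_samples_per_second)
# One optimizer step per batch, with a smaller final batch.
steps_per_epoch = math.ceil(estimated_features / batch_size)

print(estimated_features, steps_per_epoch)
```

The estimate is larger than the 87599 training examples because long contexts are split into several features.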

# Evaluating the model

In [25]:
import torch

for batch in trainer.get_eval_dataloader():
    break
batch = {k: v.to(trainer.args.device) for k, v in batch.items()}
with torch.no_grad():
    output = trainer.model(**batch)
output.keys()

odict_keys(['loss', 'start_logits', 'end_logits'])

In [26]:
output.start_logits.shape, output.end_logits.shape # 16 examples in the batch, each of sequence length 384

(torch.Size([16, 384]), torch.Size([16, 384]))

In [27]:
output.start_logits.argmax(dim=-1), output.end_logits.argmax(dim=-1) # most likely start and end token positions

(tensor([ 46,  57,  78,  43, 118, 162,  72,  35, 162,  34,  73,  41,  80,  91,
         156,  35], device='cuda:0'),
 tensor([ 47,  58,  81,  44, 118, 109,  75,  37, 109,  36,  76,  42,  83,  94,
         158,  35], device='cuda:0'))

In [28]:
n_best_size = 20

import numpy as np

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
# Gather the indices the best start/end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        if start_index <= end_index: # We need to refine that test to check the answer is inside the context
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": "" # We need to find a way to get back the original substring corresponding to the answer in the context
                }
            )
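
The slicing idiom `[-1 : -n_best_size - 1 : -1]` and the span scoring can be checked on a tiny example in pure Python. The logits below are invented, the helper names are hypothetical, and the length cap `max_answer_length` is an extra filter assumed for this sketch.

```python
def top_indices(logits, n_best_size):
    """Indices of the n_best_size largest logits, best first."""
    ascending = sorted(range(len(logits)), key=lambda i: logits[i])  # like np.argsort
    return ascending[-1 : -n_best_size - 1 : -1]


def best_spans(start_logits, end_logits, n_best_size, max_answer_length):
    """Score every (start, end) pair among the top candidates, best first."""
    candidates = []
    for start in top_indices(start_logits, n_best_size):
        for end in top_indices(end_logits, n_best_size):
            # Skip spans that end before they start or that are too long.
            if end < start or end - start + 1 > max_answer_length:
                continue
            candidates.append((start, end, start_logits[start] + end_logits[end]))
    return sorted(candidates, key=lambda c: c[2], reverse=True)


toy_start_logits = [0.1, 2.5, -1.0, 3.2, 0.7]
toy_end_logits = [0.0, 0.3, 1.9, 0.4, 2.8]

print(top_indices(toy_start_logits, 2))
for start, end, score in best_spans(toy_start_logits, toy_end_logits, 2, 3):
    print(start, end, round(score, 2))
```

Summing the two logits treats the start and end choices as independent, which is why the pairs have to be filtered for validity afterwards.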

In [29]:
def prepare_validation_features(examples):
    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. As a result,
    # one example may yield several features when its context is long, each of those features having a
    # context that slightly overlaps the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [31]:
validation_features = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names
)

len(validation_features)

Loading cached processed dataset at /raid/wuxiaojun/.cache/huggingface/datasets/squad/plain_text/1.0.0/6b6c4172d0119c74515f44ea0b8262efe4897f2ddb6613e5e915840fdc309c16/cache-e241a90cda44b4cf.arrow


10784
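
There are 10784 validation features for 10570 examples, so some examples own more than one feature. One way to group them, sketched here on made-up ids, is to invert the `example_id` column into a mapping from example index to feature indices. The ids and variable names are illustrative only.

```python
from collections import defaultdict

# Hypothetical ids: three examples, the first and last split into several features.
example_ids = ["a", "b", "c"]
feature_example_ids = ["a", "a", "b", "c", "c", "c"]

# Position of each example in the original dataset.
example_id_to_index = {ex_id: i for i, ex_id in enumerate(example_ids)}

# For every example index, the list of feature indices that came from it.
features_per_example = defaultdict(list)
for feature_index, ex_id in enumerate(feature_example_ids):
    features_per_example[example_id_to_index[ex_id]].append(feature_index)

print(dict(features_per_example))
```

With such a mapping, the candidate answers from all features of one example can be pooled before picking the best-scoring span.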

In [32]:
raw_predictions = trainer.predict(validation_features)

In [33]:

validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))


max_answer_length = 30

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
offset_mapping = validation_features[0]["offset_mapping"]
# The first feature comes from the first example. In the general case, we will need to match the example_id to
# an example index.
context = datasets["validation"][0]["context"]

# Gather the indices the best start/end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
        # to part of the input_ids that are not in the context.
        if (
            start_index >= len(offset_mapping)
            or end_index >= len(offset_mapping)
            or offset_mapping[start_index] is None
            or offset_mapping[end_index] is None
        ):
            continue
        # Don't consider answers with a length that is either < 0 or > max_answer_length.
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
            continue
        # Map the token span back to characters in the original context.
        start_char = offset_mapping[start_index][0]
        end_char = offset_mapping[end_index][1]
        valid_answers.append(
            {
                "score": start_logits[start_index] + end_logits[end_index],
                "text": context[start_char: end_char]
            }
        )

valid_answers = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best_size]
valid_answers

[{'score': 12.56584, 'text': 'Denver Broncos'},
 {'score': 11.295288,
  'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 9.633663, 'text': 'Carolina Panthers'},
 {'score': 8.797333,
  'text': 'The American Football Conference (AFC) champion Denver Broncos'},
 {'score': 8.000433,
  'text': 'American Football Conference (AFC) champion Denver Broncos'},
 {'score': 7.5267816,
  'text': 'The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 7.518034, 'text': 'Denver'},
 {'score': 7.1832376, 'text': 'AFC) champion Denver Broncos'},
 {'score': 7.1333537, 'text': 'Broncos'},
 {'score': 6.7298813,
  'text': 'American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 6.7220583,
  'text': 'Denver Broncos defeated the National Football Conference (NFC

In [34]:
datasets["validation"][0]["answers"]

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answer_start': [177, 177, 177]}
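
The top-scoring candidate, 'Denver Broncos', matches all three reference answers. A simplified exact-match check in the spirit of the SQuAD metric could look like the sketch below; the normalization (lowercasing, dropping punctuation and English articles) is written from memory as an approximation, not taken from this notebook, and both function names are hypothetical.

```python
import re
import string


def normalize_answer(text):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, references):
    """True if the prediction equals any reference after normalization."""
    return any(normalize_answer(prediction) == normalize_answer(ref) for ref in references)


references = ["Denver Broncos", "Denver Broncos", "Denver Broncos"]
print(exact_match("Denver Broncos", references))
print(exact_match("the Denver Broncos.", references))
print(exact_match("Carolina Panthers", references))
```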


# References

- https://github.com/huggingface/notebooks/blob/master/examples/question_answering.ipynb
- https://mp.weixin.qq.com/s?__biz=MzIwNDA5NDYzNA==&mid=2247492410&idx=1&sn=2774db9bde601c2dbaafcaf93c93f1f4&chksm=96c7ceffa1b047e9afec2d0cfb5ad94c0ed80795e5598e9f55a1a6c281b253fe2fac14583eea&scene=178&cur_album_id=1364202321906941952#rd
- https://github.com/datawhalechina/competition-baseline/blob/e99c4ce87b214ce5380f8d63b54043761031dd98/competition/AIWIN2021/AIWIN-%E4%BF%9D%E9%99%A9%E6%96%87%E6%9C%AC%E9%97%AE%E7%AD%94-submit.ipynb