# SQuAD 1.1

In this notebook, we will see how to fine-tune and evaluate a model on the SQuAD 1.1 dataset.

# Dependencies

If not already done, make sure to install OneQA with the `notebooks` extras before getting started.

In [1]:
# If you want CUDA 11, uncomment and run this (for CUDA 10 or CPU you can ignore this line).
#! pip install 'torch~=1.11.0' --extra-index-url https://download.pytorch.org/whl/cu113

# Uncomment to install OneQA from source (PyPI package pending).
# The path should be the project root (e.g. '.' below).
#! pip install .[notebooks]

# Configuration

We start by setting some parameters to configure the process. Note that, depending on the GPU being used, you may need to tune the batch size.

In [1]:
# This needs to be filled in.
# output_dir = 'FILL_ME_IN'        # Save the results here.  Will overwrite if directory already exists.
output_dir = 'tmp'
# Optional parameters (feel free to leave as default).
model_name = 'xlm-roberta-base'  # Set this to select the LM.  SQuAD 1.1 is English-only; XLM-RoBERTa is a multilingual model.
cache_dir = None                 # Set this if you have a cache directory for transformers.  Alternatively set the HF_HOME env var.
train_batch_size = 8             # Set this to change the number of features per batch during training.
eval_batch_size = 8              # Set this to change the number of features per batch during evaluation.
gradient_accumulation_steps = 8  # Set this to effectively increase training batch size.
max_train_samples = 100          # Set this to use a subset of the training data (or None for all).
max_eval_samples = 20            # Set this to use a subset of the evaluation data (or None for all).
num_train_epochs = 1             # Set this to change the number of training epochs.
fp16 = False                     # Set this to True to enable fp16 (hardware support required).
num_examples_to_show = 10        # Set this to change the number of random train examples (and their features) to show.
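The batch-size settings above combine multiplicatively: each optimizer update sees `train_batch_size * gradient_accumulation_steps * number_of_devices` features. The sketch below uses hypothetical helpers (not part of OneQA or transformers) to approximate how the number of optimizer updates follows from these settings, assuming a single device, that an epoch's last partial batch is kept, and that leftover batches which do not fill an accumulation window are not counted as an extra update.

```python
import math


def effective_batch_size(per_device_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_devices: int = 1) -> int:
    """Number of features that contribute to one optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def optimization_steps(num_features: int,
                       per_device_batch_size: int,
                       gradient_accumulation_steps: int,
                       num_epochs: int,
                       num_devices: int = 1) -> int:
    """Approximate number of optimizer updates for a whole run."""
    # The last partial batch is kept, so batches per epoch are rounded up.
    batches_per_epoch = math.ceil(num_features / (per_device_batch_size * num_devices))
    # Only full accumulation windows are counted, with at least one update per epoch.
    updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    return updates_per_epoch * num_epochs


# Values from the configuration cell above (100 training features, one device).
print(effective_batch_size(8, 8))        # 64
print(optimization_steps(100, 8, 8, 1))  # 1
```

With the values above it gives 64 and 1, which agree with the "Total train batch size" and "Total optimization steps" lines in the training log further down.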

In [2]:
from transformers import TrainingArguments
from transformers.trainer_utils import set_seed

seed = 42
set_seed(seed)

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=num_train_epochs,
    evaluation_strategy='no',
    learning_rate=4e-05,
    warmup_ratio=0.1,
    weight_decay=0.1,
    save_steps=50000,
    fp16=fp16,
    seed=seed,
)
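`learning_rate` and `warmup_ratio` together define a learning-rate schedule. Assuming the default linear scheduler of `TrainingArguments`, the rate rises linearly from 0 over the warmup steps (to our understanding, `ceil(total_steps * warmup_ratio)` of them) and then decays linearly to 0. The pure-Python sketch below reproduces that shape; `warmup_steps` and `linear_schedule_lr` are hypothetical helper names, not transformers functions.

```python
import math


def warmup_steps(num_training_steps: int, warmup_ratio: float) -> int:
    # Round up so that any non-zero ratio yields at least one warmup step.
    return math.ceil(num_training_steps * warmup_ratio)


def linear_schedule_lr(step: int, base_lr: float,
                       num_warmup_steps: int, num_training_steps: int) -> float:
    """Learning rate at 0-based optimizer step `step`: linear warmup, then linear decay."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return base_lr * max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


total = 1000                     # hypothetical number of optimizer steps for a full run
warm = warmup_steps(total, 0.1)  # 100 warmup steps
for s in (0, 50, 100, 550, 1000):
    print(s, linear_schedule_lr(s, 4e-05, warm, total))
```

With the toy settings of this notebook (a single optimization step, as the training log below shows), `warmup_steps(1, 0.1)` is 1 and the sketch gives a learning rate of 0 at step 0. If the Trainer's scheduler behaves the same way, that single update would leave the weights essentially unchanged, so raise `max_train_samples` or `num_train_epochs` before reading anything into the evaluation scores.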

# Loading the Model

Here we load the model and tokenizer based on the `model_name` parameter set above. We use a model with an extractive QA task head, which we will fine-tune later.
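An extractive QA head scores every token as a possible answer start and answer end. Judging by the parameter names in the warning printed below (`qa_outputs`, `classifier.dense`, `classifier.out_proj`), the OneQA head appears to pair such a span projection with a separate classifier; this sketch covers only the span part. It is a pure-Python stand-in for a `Linear(hidden_size, 2)` layer applied to each token, with a hypothetical function name and toy numbers.

```python
from typing import List, Tuple

Vector = List[float]


def qa_logits(hidden_states: List[Vector],
              weight: Tuple[Vector, Vector],
              bias: Tuple[float, float]) -> Tuple[List[float], List[float]]:
    """Project each token's hidden state to a start logit and an end logit.

    Mirrors a Linear(hidden_size, 2) layer applied to every token: row 0 of
    `weight` produces start logits, row 1 produces end logits.
    """
    w_start, w_end = weight
    b_start, b_end = bias
    start_logits, end_logits = [], []
    for h in hidden_states:
        start_logits.append(sum(x * w for x, w in zip(h, w_start)) + b_start)
        end_logits.append(sum(x * w for x, w in zip(h, w_end)) + b_end)
    return start_logits, end_logits


# Toy example: 3 tokens with hidden size 2 (xlm-roberta-base uses 768).
hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
start, end = qa_logits(hidden, weight=([2.0, 0.0], [0.0, 3.0]), bias=(0.0, 0.5))
print(start)  # [2.0, 0.0, 2.0]
print(end)    # [0.5, 3.5, 3.5]
```

Because these weights start out randomly initialized (as the warning notes), the logits are meaningless until the head is fine-tuned.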

In [3]:
from transformers import AutoConfig, AutoTokenizer
from oneqa.mrc.models.heads.extractive import EXTRACTIVE_HEAD
from oneqa.mrc.models.task_model import ModelForDownstreamTasks

from oneqa.mrc.trainers.mrc import MRCTrainer

task_heads = EXTRACTIVE_HEAD
config = AutoConfig.from_pretrained(
    model_name,
    cache_dir=cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    use_fast=True,
    config=config,
)
model = ModelForDownstreamTasks.from_config(
    config,
    model_name,
    task_heads=task_heads,
    cache_dir=cache_dir,
)
model.set_task_head(next(iter(task_heads)))

print(model)  # Examine the model structure

Some weights of XLMRobertaModelForDownstreamTasks were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['task_heads.qa_head.qa_outputs.bias', 'task_heads.qa_head.classifier.dense.bias', 'task_heads.qa_head.classifier.out_proj.bias', 'task_heads.qa_head.qa_outputs.weight', 'task_heads.qa_head.classifier.dense.weight', 'task_heads.qa_head.classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


XLMRobertaModelForDownstreamTasks(
  (roberta): XLMRobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (L

# Loading Data

Here we load the SQuAD 1.1 dataset using Hugging Face's `datasets` library.
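Each SQuAD record stores its answers as text plus a character offset (`answers = {'text': [...], 'answer_start': [...]}` in the Hugging Face `squad` schema). Below is a minimal sketch of turning that into character spans; `answer_char_spans` is a hypothetical helper. The context is an excerpt of one training example from the examples table further down. Using the offsets that table reports for the Latin-name question, the sketch checks that `context[73:99]` is exactly the answer text, consistent with the table's `end_positions` being exclusive character offsets.

```python
from typing import Dict, List, Tuple


def answer_char_spans(context: str, answers: Dict[str, list]) -> List[Tuple[int, int]]:
    """Convert SQuAD-style answers into (start, end) character spans.

    `answers` follows the Hugging Face `squad` schema,
    {'text': [...], 'answer_start': [...]}. `end` is exclusive, so
    context[start:end] reproduces the answer text.
    """
    spans = []
    for text, start in zip(answers["text"], answers["answer_start"]):
        end = start + len(text)
        # Sanity check: the stored offset must point at the answer text.
        if context[start:end] != text:
            raise ValueError(f"answer {text!r} not found at offset {start}")
        spans.append((start, end))
    return spans


# Excerpt of a training context (see the examples table further down).
context = ('The university is affiliated with the Congregation of Holy Cross '
           '(Latin: Congregatio a Sancta Cruce, abbreviated postnominals: "CSC").')
answers = {"text": ["Congregatio a Sancta Cruce"], "answer_start": [73]}
print(answer_char_spans(context, answers))  # [(73, 99)]
```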

In [5]:
import datasets
import random

raw_datasets = datasets.load_dataset(
    'squad',
    'plain_text',
    cache_dir=cache_dir,
)

train_examples = raw_datasets["train"]
if max_train_samples is not None:
    # Keep only the first `max_train_samples` training examples
    train_examples = train_examples.select(range(max_train_samples))

print(f"Using {train_examples.num_rows} train examples.")

eval_examples = raw_datasets["validation"]
if max_eval_samples is not None:
    # Draw a random subset of `max_eval_samples` validation examples (reproducible via set_seed above)
    random_idxs = random.sample(range(len(eval_examples)), max_eval_samples)
    eval_examples = eval_examples.select(random_idxs)

print(f"Using {eval_examples.num_rows} eval examples.")

Reusing dataset squad (/u/mabornea/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


Using 100 train examples.
Using 20 eval examples.


# Preprocessing

Here we preprocess the data to create features which can be given to the model.
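Contexts longer than the model's maximum input are split into overlapping windows ("features"), and `stride=128` controls the overlap. The sketch below illustrates the idea on token positions only. It assumes the Hugging Face tokenizer convention that `stride` is the number of tokens shared by consecutive windows, ignores the room taken by the question and special tokens, and uses hypothetical helper names. An answer can land in several windows or in none, so each window needs its own label; how answer-less windows are labeled is up to the preprocessor (the Hugging Face question-answering examples point them at the CLS token). In this notebook each example apparently fit in a single window, since the feature and example counts printed below match.

```python
from typing import List, Optional, Tuple


def window_spans(num_tokens: int, max_len: int, stride: int) -> List[Tuple[int, int]]:
    """Split token positions [0, num_tokens) into windows of at most max_len.

    Consecutive windows share `stride` tokens.
    """
    if not 0 <= stride < max_len:
        raise ValueError("stride must satisfy 0 <= stride < max_len")
    spans, start = [], 0
    while True:
        end = min(start + max_len, num_tokens)
        spans.append((start, end))
        if end >= num_tokens:
            return spans
        start = end - stride  # step back so the next window overlaps by `stride`


def window_label(window: Tuple[int, int], answer: Tuple[int, int]) -> Optional[Tuple[int, int]]:
    """Answer position relative to the window, or None if it is not fully inside.

    `answer` holds inclusive (start, end) token indices into the full context;
    `window` is a half-open (start, end) range as returned by window_spans.
    """
    w_start, w_end = window
    a_start, a_end = answer
    if w_start <= a_start and a_end < w_end:
        return a_start - w_start, a_end - w_start
    return None


windows = window_spans(num_tokens=10, max_len=4, stride=2)
print(windows)  # [(0, 4), (2, 6), (4, 8), (6, 10)]
print([window_label(w, (6, 7)) for w in windows])  # [None, None, (2, 3), (0, 1)]
```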

In [6]:
from oneqa.mrc.processors.preprocessors.squad import SQUADPreprocessor

preprocessor = SQUADPreprocessor(
    stride=128,
    tokenizer=tokenizer,
)

# Train Feature Creation
with training_args.main_process_first(desc="train dataset map pre-processing"):
    train_examples, train_dataset = preprocessor.process_train(train_examples)

print(f"Preprocessing produced {train_dataset.num_rows} train features from {train_examples.num_rows} examples.")

# Validation Feature Creation
with training_args.main_process_first(desc="validation dataset map pre-processing"):
    eval_examples, eval_dataset = preprocessor.process_eval(eval_examples)

print(f"Preprocessing produced {eval_dataset.num_rows} eval features from {eval_examples.num_rows} examples.")

Loading cached processed dataset at /u/mabornea/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-beec2b718128c871.arrow
Loading cached processed dataset at /u/mabornea/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-6ea01577bc09ec02.arrow


Preprocessing produced 100 train features from 100 examples.


Preprocessing produced 20 eval features from 20 examples.


In [7]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

# Based on https://github.com/huggingface/notebooks/blob/main/examples/question_answering.ipynb
def show_elements(dataset):
    df = pd.DataFrame(dataset)
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [8]:
import random

def trim_document(example, max_len=500):
    example['context'] = example['context'][0]
    doc_len = len(example['context'])
    if doc_len > max_len:
        example['context'] = f"{example['context'][:max_len - 3]}..."        
    return example

random_idxs = random.sample(range(len(train_examples)), num_examples_to_show)
random_train_examples = train_examples.select(random_idxs).remove_columns(['passage_candidates'])
random_train_examples = random_train_examples.map(trim_document)

show_elements(random_train_examples)  # Show random train examples

,title,context,question,example_id,target,language
0,University_of_Notre_Dame,"As of 2012[update] research continued in many fields. The university president, John Jenkins, described his hope that Notre Dame would become ""one of the pre–eminent research institutions in the world"" in his inaugural address. The university has many multi-disciplinary institutes devoted to research in varying fields, including the Medieval Institute, the Kellogg Institute for International Studies, the Kroc Institute for International Peace studies, and the Center for Social Concerns. Recen...",In what year did Notre Dame begin to host the Global Adaptation Index?,5733b5344776f419006610e0,"{'end_positions': [757], 'passage_indices': [0], 'start_positions': [753], 'yes_no_answer': ['NONE']}",UNKNOWN
1,University_of_Notre_Dame,"The university first offered graduate degrees, in the form of a Master of Arts (MA), in the 1854–1855 academic year. The program expanded to include Master of Laws (LL.M.) and Master of Civil Engineering in its early stages of growth, before a formal graduate school education was developed with a thesis not required to receive the degrees. This changed in 1924 with formal requirements developed for graduate degrees, including offering Doctorate (PhD) degrees. Today each of the five colleges o...",What type of degree is an M.Div.?,5733a7bd4776f41900660f6c,"{'end_positions': [642], 'passage_indices': [0], 'start_positions': [624], 'yes_no_answer': ['NONE']}",UNKNOWN
2,University_of_Notre_Dame,"This Main Building, and the library collection, was entirely destroyed by a fire in April 1879, and the school closed immediately and students were sent home. The university founder, Fr. Sorin and the president at the time, the Rev. William Corby, immediately planned for the rebuilding of the structure that had housed virtually the entire University. Construction was started on the 17th of May and by the incredible zeal of administrator and workers the building was completed before the fall s...",On what date was the rebuilding of The Main Building begun at Notre Dame after the fire that claimed the previous?,57338653d058e614000b5c83,"{'end_positions': [396], 'passage_indices': [0], 'start_positions': [385], 'yes_no_answer': ['NONE']}",UNKNOWN
3,University_of_Notre_Dame,"The university is affiliated with the Congregation of Holy Cross (Latin: Congregatio a Sancta Cruce, abbreviated postnominals: ""CSC""). While religious affiliation is not a criterion for admission, more than 93% of students identify as Christian, with over 80% of the total being Catholic. Collectively, Catholic Mass is celebrated over 100 times per week on campus, and a large campus ministry program provides for the faith needs of the community. There are multitudes of religious statues and ar...",What is Congregation of Holy Cross in Latin?,5733b7f74776f4190066112d,"{'end_positions': [99], 'passage_indices': [0], 'start_positions': [73], 'yes_no_answer': ['NONE']}",UNKNOWN
4,University_of_Notre_Dame,"This Main Building, and the library collection, was entirely destroyed by a fire in April 1879, and the school closed immediately and students were sent home. The university founder, Fr. Sorin and the president at the time, the Rev. William Corby, immediately planned for the rebuilding of the structure that had housed virtually the entire University. Construction was started on the 17th of May and by the incredible zeal of administrator and workers the building was completed before the fall s...",In what year was the Main Building at Notre Dame razed in a fire?,57338653d058e614000b5c81,"{'end_positions': [94], 'passage_indices': [0], 'start_positions': [90], 'yes_no_answer': ['NONE']}",UNKNOWN
5,University_of_Notre_Dame,"As of 2012[update] research continued in many fields. The university president, John Jenkins, described his hope that Notre Dame would become ""one of the pre–eminent research institutions in the world"" in his inaugural address. The university has many multi-disciplinary institutes devoted to research in varying fields, including the Medieval Institute, the Kellogg Institute for International Studies, the Kroc Institute for International Peace studies, and the Center for Social Concerns. Recen...",The Kellogg Institute for International Studies is part of which university?,5733b5344776f419006610de,"{'end_positions': [128], 'passage_indices': [0], 'start_positions': [118], 'yes_no_answer': ['NONE']}",UNKNOWN
6,University_of_Notre_Dame,"Father Joseph Carrier, C.S.C. was Director of the Science Museum and the Library and Professor of Chemistry and Physics until 1874. Carrier taught that scientific research and its promise for progress were not antagonistic to the ideals of intellectual and moral culture endorsed by the Church. One of Carrier's students was Father John Augustine Zahm (1851–1921) who was made Professor and Co-Director of the Science Department at age 23 and by 1900 was a nationally prominent scientist and natur...",What professorship did Father Josh Carrier hold at Notre Dame?,5733b0fb4776f41900661042,"{'end_positions': [119], 'passage_indices': [0], 'start_positions': [85], 'yes_no_answer': ['NONE']}",UNKNOWN
7,University_of_Notre_Dame,"The university first offered graduate degrees, in the form of a Master of Arts (MA), in the 1854–1855 academic year. The program expanded to include Master of Laws (LL.M.) and Master of Civil Engineering in its early stages of growth, before a formal graduate school education was developed with a thesis not required to receive the degrees. This changed in 1924 with formal requirements developed for graduate degrees, including offering Doctorate (PhD) degrees. Today each of the five colleges o...",Which department at Notre Dame is the only one to not offer a PhD program?,5733a7bd4776f41900660f6d,"{'end_positions': [795], 'passage_indices': [0], 'start_positions': [757], 'yes_no_answer': ['NONE']}",UNKNOWN
8,University_of_Notre_Dame,"In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nu...",Which individual worked on projects at Notre Dame that eventually created neoprene?,5733b1da4776f4190066106b,"{'end_positions': [245], 'passage_indices': [0], 'start_positions': [222], 'yes_no_answer': ['NONE']}",UNKNOWN
9,University_of_Notre_Dame,"In 2014 the Notre Dame student body consisted of 12,179 students, with 8,448 undergraduates, 2,138 graduate and professional and 1,593 professional (Law, M.Div., Business, M.Ed.) students. Around 21–24% of students are children of alumni, and although 37% of students come from the Midwestern United States, the student body represents all 50 states and 100 countries. As of March 2007[update] The Princeton Review ranked the school as the fifth highest 'dream school' for parents to send their ch...",How many teams participate in the Notre Dame Bookstore Basketball tournament?,5733b5df4776f41900661107,"{'end_positions': [1454], 'passage_indices': [0], 'start_positions': [1446], 'yes_no_answer': ['NONE']}",UNKNOWN


In [9]:
from oneqa.mrc.data_models.target_type import TargetType

def target_type_as_str(feature):
    feature['target_type'] = TargetType(feature['target_type']).name
    return feature

random_train_dataset = train_dataset.filter(lambda feature: feature['example_idx'] in random_idxs).remove_columns(['attention_mask', 'offset_mapping'])
show_elements(random_train_dataset.map(target_type_as_str))  # Show random train features

,example_id,input_ids,example_idx,start_positions,end_positions,target_type
0,5733a7bd4776f41900660f6c,"[0, 4865, 10644, 111, 79385, 83, 142, 276, 5, 192576, 5, 32, 2, 2, 581, 152363, 5117, 122399, 150180, 79385, 7, 4, 23, 70, 3173, 111, 10, 18897, 111, 64624, 15, 8218, 247, 23, 70, 543, 12338, 132869, 11663, 108858, 6602, 5, 581, 1528, 71062, 297, 47, 26698, 18897, 111, 36293, 7, 15, 23708, 5, 594, 5, 16, 136, 18897, 111, 18543, 123470, 23, 6863, 39395, 36541, 7, 111, 75678, 4, 8108, 10, 23113, 150180, 10696, 53019, 509, 126809, 678, 10, 159688, 959, 56065, 47, 53299, 70, 79385, 7, 5, 3293, 98816, 23, 58410, 678, 23113, 96679, 126809, 100, 150180, ...]",25,145,148,SPAN_ANSWER
1,5733a7bd4776f41900660f6d,"[0, 130078, 130625, 99, 52151, 67388, 83, 70, 4734, 1632, 47, 959, 18645, 10, 101862, 1528, 32, 2, 2, 581, 152363, 5117, 122399, 150180, 79385, 7, 4, 23, 70, 3173, 111, 10, 18897, 111, 64624, 15, 8218, 247, 23, 70, 543, 12338, 132869, 11663, 108858, 6602, 5, 581, 1528, 71062, 297, 47, 26698, 18897, 111, 36293, 7, 15, 23708, 5, 594, 5, 16, 136, 18897, 111, 18543, 123470, 23, 6863, 39395, 36541, 7, 111, 75678, 4, 8108, 10, 23113, 150180, 10696, 53019, 509, 126809, 678, 10, 159688, 959, 56065, 47, 53299, 70, 79385, 7, 5, 3293, 98816, 23, 58410, 678, ...]",28,182,188,SPAN_ANSWER
2,5733b0fb4776f41900661042,"[0, 4865, 16030, 16070, 6777, 160960, 146393, 3980, 25388, 16401, 99, 52151, 67388, 32, 2, 2, 160960, 33876, 3980, 25388, 4, 313, 5, 294, 5, 441, 5, 509, 31068, 111, 70, 28745, 25946, 136, 70, 103835, 136, 43552, 111, 83230, 38904, 136, 165712, 7, 24189, 186868, 5, 3980, 25388, 189924, 450, 57456, 25188, 136, 6863, 103036, 100, 42658, 3542, 959, 63212, 6126, 48242, 47, 70, 6397, 7, 111, 91768, 289, 136, 14392, 29394, 22, 31004, 297, 390, 70, 84084, 5, 6561, 111, 3980, 25388, 25, 7, 25921, 509, 160960, 4939, 104734, 13, 825, 18337, 7435, 11703, 74668, 69072, 2750, 509, ...]",53,37,43,SPAN_ANSWER
3,5733b1da4776f4190066106b,"[0, 130078, 11651, 79786, 98, 77635, 99, 52151, 67388, 450, 155605, 75935, 16169, 26344, 13, 32, 2, 2, 360, 156999, 4, 24748, 825, 18337, 15, 98385, 825, 18337, 25, 7, 82953, 16, 88303, 142, 39395, 32382, 80208, 11814, 47, 69101, 60520, 47, 24911, 111, 37511, 10792, 70760, 115774, 5, 62, 67688, 122815, 4, 43552, 27971, 13450, 15497, 100512, 70, 5117, 15672, 47, 25379, 10, 135051, 26008, 5, 360, 66426, 4, 160960, 109112, 128365, 1760, 51339, 297, 39395, 4488, 98, 62822, 132539, 7, 450, 509, 11814, 47, 28282, 16169, 26344, 13, 5, 148027, 111, 72249, 6, 34053, 27744, 7, 99, 70, ...]",57,70,73,SPAN_ANSWER
4,5733b5344776f419006610de,"[0, 581, 23203, 4867, 177, 43975, 100, 8357, 132268, 83, 2831, 111, 3129, 152363, 32, 2, 2, 1301, 111, 1324, 1065, 117008, 268, 25188, 136475, 23, 5941, 44457, 7, 5, 581, 152363, 13918, 4, 4939, 234339, 4, 151552, 1919, 15673, 450, 52151, 67388, 2806, 24209, 44, 3630, 111, 70, 479, 1104, 13, 7732, 18, 25188, 38016, 7, 23, 70, 8999, 58, 23, 1919, 33428, 141, 29823, 5, 581, 152363, 1556, 5941, 6024, 9, 141223, 6635, 32872, 90, 30396, 3674, 47, 25188, 23, 285, 38543, 44457, 7, 4, 26719, 70, 11214, 13, 1405, 43975, 4, 70, 23203, 4867, 177, 43975, 100, ...]",69,41,42,SPAN_ANSWER
5,5733b5344776f419006610e0,"[0, 360, 2367, 6602, 6777, 52151, 67388, 9842, 47, 27980, 70, 13453, 91903, 2320, 31471, 32, 2, 2, 1301, 111, 1324, 1065, 117008, 268, 25188, 136475, 23, 5941, 44457, 7, 5, 581, 152363, 13918, 4, 4939, 234339, 4, 151552, 1919, 15673, 450, 52151, 67388, 2806, 24209, 44, 3630, 111, 70, 479, 1104, 13, 7732, 18, 25188, 38016, 7, 23, 70, 8999, 58, 23, 1919, 33428, 141, 29823, 5, 581, 152363, 1556, 5941, 6024, 9, 141223, 6635, 32872, 90, 30396, 3674, 47, 25188, 23, 285, 38543, 44457, 7, 4, 26719, 70, 11214, 13, 1405, 43975, 4, 70, 23203, 4867, 177, 43975, ...]",71,171,171,SPAN_ANSWER
6,5733b5df4776f41900661107,"[0, 11249, 5941, 87199, 42938, 13, 23, 70, 52151, 67388, 83266, 7535, 234333, 233547, 32, 2, 2, 360, 1049, 70, 52151, 67388, 9836, 14361, 35060, 71, 111, 427, 4, 156918, 25921, 4, 678, 382, 4, 165116, 1379, 88610, 63614, 4, 116, 4, 141535, 150180, 136, 23182, 136, 10285, 11591, 23182, 15, 2729, 434, 4, 276, 5, 192576, 5, 4, 14249, 4, 276, 5, 69489, 5, 16, 25921, 5, 62, 67688, 952, 1104, 304, 11267, 111, 25921, 621, 20020, 111, 228140, 4, 136, 102971, 138, 14427, 111, 25921, 1380, 1295, 70, 23166, 1177, 48850, 14098, 46684, 4, 70, 9836, 14361, 33636, ...]",75,350,351,SPAN_ANSWER
7,5733b7f74776f4190066112d,"[0, 4865, 83, 237626, 19, 111, 152239, 47832, 23, 42845, 32, 2, 2, 581, 152363, 83, 148272, 71, 678, 70, 237626, 19, 111, 152239, 47832, 15, 2729, 2311, 12, 237626, 10, 192437, 41649, 329, 4, 1563, 105160, 14, 27686, 1305, 175574, 8080, 12, 44, 441, 14495, 51029, 51404, 167821, 261, 67666, 2320, 83, 959, 10, 166220, 19, 100, 606, 21150, 4, 1286, 3501, 483, 11587, 111, 25921, 135812, 237, 14949, 4, 678, 645, 20668, 111, 70, 3622, 8035, 129574, 5, 138521, 5844, 538, 4, 129574, 74227, 83, 176016, 71, 645, 805, 20028, 117, 5895, 98, 78132, 4, 136, 10, 21334, ...]",83,29,33,SPAN_ANSWER
8,57338653d058e614000b5c81,"[0, 360, 2367, 6602, 509, 70, 12321, 104919, 99, 52151, 67388, 1954, 297, 23, 10, 11476, 32, 2, 2, 3293, 12321, 104919, 4, 136, 70, 35773, 1294, 42486, 4, 509, 167969, 163684, 297, 390, 10, 11476, 23, 7071, 176447, 4, 136, 70, 10696, 155738, 109312, 136, 25921, 3542, 9325, 5368, 5, 581, 152363, 14037, 56, 4, 21894, 5, 13965, 73, 136, 70, 13918, 99, 70, 1733, 4, 70, 80893, 5, 25031, 5631, 1272, 4, 109312, 203251, 100, 70, 456, 146049, 111, 70, 45646, 450, 1902, 18276, 71, 20513, 538, 70, 64194, 12535, 5, 195769, 509, 26859, 98, 70, 729, 927, ...]",89,38,38,SPAN_ANSWER
9,57338653d058e614000b5c83,"[0, 2161, 2367, 5622, 509, 70, 456, 146049, 111, 581, 12321, 104919, 186, 6967, 99, 52151, 67388, 7103, 70, 11476, 450, 63043, 297, 70, 96362, 32, 2, 2, 3293, 12321, 104919, 4, 136, 70, 35773, 1294, 42486, 4, 509, 167969, 163684, 297, 390, 10, 11476, 23, 7071, 176447, 4, 136, 70, 10696, 155738, 109312, 136, 25921, 3542, 9325, 5368, 5, 581, 152363, 14037, 56, 4, 21894, 5, 13965, 73, 136, 70, 13918, 99, 70, 1733, 4, 70, 80893, 5, 25031, 5631, 1272, 4, 109312, 203251, 100, 70, 456, 146049, 111, 70, 45646, 450, 1902, 18276, 71, 20513, 538, 70, 64194, ...]",91,107,110,SPAN_ANSWER


# Fine-tuning

Here we fine-tune the model on the training set.
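At evaluation time the post-processor turns per-token start and end logits back into answer text. A common recipe, used in the Hugging Face question-answering examples, keeps the `n_best_size` highest start and end candidates, drops spans that are reversed, longer than `max_answer_length` tokens, or outside the context, and ranks the rest by the sum of the two logits. The sketch below implements that plain sum. The `SQUADPostProcessor` configured below uses `WEIGHTED_SUM_TARGET_TYPE_AND_SCORE_DIFF`, whose scoring differs and is not reproduced here. `best_span` is a hypothetical name, and the offsets play the role of the tokenizer's `offset_mapping`.

```python
from typing import List, Optional, Tuple

Span = Tuple[float, int, int]


def best_span(start_logits: List[float], end_logits: List[float],
              offsets: List[Optional[Tuple[int, int]]],
              n_best_size: int = 20,
              max_answer_length: int = 30) -> Optional[Span]:
    """Return (score, char_start, char_end) of the best valid span, or None.

    offsets[i] is the (start, end) character range of token i in the context,
    or None for tokens outside the context (question, special tokens).
    """
    def top_k(logits: List[float]) -> List[int]:
        # Indices of the n_best_size largest logits, highest first.
        return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:n_best_size]

    best = None
    for s in top_k(start_logits):
        for e in top_k(end_logits):
            if offsets[s] is None or offsets[e] is None:
                continue  # span must lie inside the context
            if e < s or e - s + 1 > max_answer_length:
                continue  # reversed or too long
            score = start_logits[s] + end_logits[e]
            if best is None or score > best[0]:
                best = (score, offsets[s][0], offsets[e][1])
    return best


context = "Notre Dame is in Indiana"
# Token 0 stands for a special token such as [CLS]; it has no context offsets.
offsets = [None, (0, 5), (6, 10), (11, 13), (14, 16), (17, 24)]
start_logits = [5.0, 0.1, 0.2, 0.0, 0.3, 3.0]
end_logits = [5.0, 0.0, 0.5, 0.0, 0.1, 2.5]
score, start, end = best_span(start_logits, end_logits, offsets)
print(score, context[start:end])  # 5.5 Indiana
```

Even though token 0 has the highest logits, it is skipped because it has no context offsets, so "Indiana" (3.0 + 2.5) wins.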

In [10]:
from operator import attrgetter
import datasets
from transformers import DataCollatorWithPadding
from oneqa.mrc.data_models.eval_prediction_with_processing import EvalPredictionWithProcessing
from oneqa.mrc.processors.postprocessors.squad import SQUADPostProcessor
from oneqa.mrc.processors.postprocessors.scorers import SupportedSpanScorers

# If using mixed precision we pad for efficient hardware acceleration
using_mixed_precision = any(attrgetter('fp16', 'bf16')(training_args))
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=64 if using_mixed_precision else None)

# noinspection PyProtectedMember
postprocessor = SQUADPostProcessor(
    k=3,
    n_best_size=20,
    max_answer_length=30,
    scorer_type=SupportedSpanScorers.WEIGHTED_SUM_TARGET_TYPE_AND_SCORE_DIFF,
    single_context_multiple_passages=preprocessor._single_context_multiple_passages,
)

squad_metric = datasets.load_metric("squad")  # load once instead of on every evaluation

def compute_metrics(p: EvalPredictionWithProcessing):
    return squad_metric.compute(predictions=p.processed_predictions, references=p.label_ids)

trainer = MRCTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    eval_examples=eval_examples if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    post_process_function=postprocessor.process_references_and_predictions,  # see QuestionAnsweringTrainer in the Hugging Face question-answering examples
    compute_metrics=compute_metrics,
)

train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too for easy upload

metrics = train_result.metrics
max_train_samples = max_train_samples or len(train_dataset)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

***** Running training *****
  Num examples = 100
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 8
  Total optimization steps = 1


Step,Training Loss




Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to tmp
Configuration saved in tmp/config.json
Model weights saved in tmp/pytorch_model.bin
tokenizer config file saved in tmp/tokenizer_config.json
Special tokens file saved in tmp/special_tokens_map.json


***** train metrics *****
  epoch                    =       0.62
  total_flos               =    10421GF
  train_loss               =     4.4133
  train_runtime            = 0:00:31.81
  train_samples            =        100
  train_samples_per_second =      3.143
  train_steps_per_second   =      0.031


# Evaluation

Here we evaluate the model on the validation set.
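`compute_metrics` above delegates to the `squad` metric. To our understanding, that metric follows the official SQuAD v1.1 evaluation script:

- Answers are normalized: lowercased, punctuation and the articles a/an/the removed, whitespace collapsed.
- Each prediction is scored against every gold answer with exact match and token-level F1.
- The best score per question is kept.
- The averages are reported as percentages, so `eval_f1 = 3.8277` further down means about 3.8 out of 100.

The sketch below re-implements that logic in plain Python for illustration. The function names are ours, not the metric's API, and the toy predictions and answers are made up.

```python
import re
import string
from collections import Counter
from typing import Dict, List

_PUNCTUATION = set(string.punctuation)


def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and the articles a/an/the, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in _PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction: str, truth: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1(prediction: str, truth: str) -> float:
    """Token-level F1 between the normalized prediction and gold answer."""
    pred_tokens = normalize_answer(prediction).split()
    true_tokens = normalize_answer(truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_scores(predictions: Dict[str, str],
                 references: Dict[str, List[str]]) -> Dict[str, float]:
    """Average best-over-references EM and F1, reported as percentages."""
    em_total = f1_total = 0.0
    for qid, golds in references.items():
        pred = predictions.get(qid, "")  # a missing prediction counts as an empty answer
        em_total += max(exact_match(pred, g) for g in golds)
        f1_total += max(f1(pred, g) for g in golds)
    n = len(references)
    return {"exact_match": 100.0 * em_total / n, "f1": 100.0 * f1_total / n}


print(squad_scores({"q1": "the Pittsburgh Steel", "q2": "1879"},
                   {"q1": ["Pittsburgh Steelers"], "q2": ["1879", "April 1879"]}))
# {'exact_match': 50.0, 'f1': 75.0}
```

A truncated span such as "Pittsburgh Steel" earns partial F1 credit but no exact-match credit, which is how a run can report a non-zero `eval_f1` alongside `eval_exact_match = 0.0`.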

In [11]:
metrics = trainer.evaluate()

max_eval_samples = max_eval_samples or len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


100%|██████████| 20/20 [00:00<00:00, 540.89it/s]
100%|██████████| 20/20 [00:00<00:00, 4289.31it/s]


***** eval metrics *****
  epoch            =   0.62
  eval_exact_match =    0.0
  eval_f1          = 3.8277
  eval_samples     =     20


# Predictions

Here we examine the model predictions.
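The predictions file read below holds one `{'id', 'prediction_text'}` record per question. This is the prediction format of the `squad` metric used in `compute_metrics`. To rescore it offline, the gold answers have to be aligned by id into the metric's reference format, `{'id', 'answers': {'text', 'answer_start'}}`. Below is a minimal sketch that assumes raw SQuAD records such as `raw_datasets['validation']`, since the preprocessed `eval_examples` may use a different schema. `to_metric_inputs` is a hypothetical helper, and the toy ids and offsets are made up.

```python
from typing import Dict, Iterable, List, Tuple


def to_metric_inputs(predictions: List[Dict],
                     examples: Iterable[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Pair each prediction with its gold answers, matched by question id.

    predictions: [{'id': ..., 'prediction_text': ...}, ...]
    examples: SQuAD-style records with 'id' and 'answers' keys.
    Returns (preds, refs) in the same order, one entry per prediction.
    """
    gold = {ex["id"]: ex["answers"] for ex in examples}
    preds, refs = [], []
    for p in predictions:
        qid = p["id"]
        if qid not in gold:
            raise KeyError(f"no reference answers for id {qid}")
        preds.append({"id": qid, "prediction_text": p["prediction_text"]})
        refs.append({"id": qid, "answers": gold[qid]})
    return preds, refs


# Toy data; in the notebook the predictions would come from the JSON file below.
toy_predictions = [{"id": "q1", "prediction_text": "1879"}]
toy_examples = [
    {"id": "q1", "answers": {"text": ["1879"], "answer_start": [90]}},
    {"id": "q2", "answers": {"text": ["CSC"], "answer_start": [128]}},
]
preds, refs = to_metric_inputs(toy_predictions, toy_examples)
print(refs)  # [{'id': 'q1', 'answers': {'text': ['1879'], 'answer_start': [90]}}]
```

The two returned lists can then be passed to the metric's `compute(predictions=..., references=...)`, as `compute_metrics` does during evaluation.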

In [12]:
import json
import os
from pprint import pprint

with open(os.path.join(output_dir, 'eval_predictions_processed.json'), 'r') as f:
    predictions = json.load(f)

pprint(predictions)

[{'id': '573786b51c4567190057448d',
  'prediction_text': 'radius () of the Earth to the gravitational accelerat'},
 {'id': '56e1ddfce3433e14004231d8',
  'prediction_text': 'question of whether P equals NP is one'},
 {'id': '56bf48cc3aeaaa14008c95af',
  'prediction_text': 'team to have worn white as the designated home team in '
                     'the Super Bowl was the Pittsburgh Steel'},
 {'id': '5725c337271a42140099d164',
  'prediction_text': 'cles fringed with tentilla ("little tentacles") that are '
                     'covered with colloblasts, sticky cells that capture pre'},
 {'id': '5725e44238643c19005ace36', 'prediction_text': 'minutes. On one'},
 {'id': '571ccfbadd7acb1400e4c164',
  'prediction_text': 'hydrogen and oxygen in the explosive ratio 2:1. Contra'},
 {'id': '56f827caa6d7ea1400e1743a', 'prediction_text': 'claring Luther an out'},
 {'id': '56e127bccd28a01900c6765f', 'prediction_text': 'e'},
 {'id': '572976183f37b31900478431',
  'prediction_text': 'rose being expor