# SQuAD 1.1

In this notebook, we will see how to fine-tune and evaluate a model on the SQuAD 1.1 dataset.

# Dependencies

If you have not already done so, install PrimeQA with the `notebooks` extras before getting started.

In [1]:
# If you want CUDA 11 uncomment and run this (for CUDA 10 or CPU you can ignore this line).
#! pip install 'torch~=1.11.0' --extra-index-url https://download.pytorch.org/whl/cu113

# Uncomment to install PrimeQA from source (PyPI package pending).
# The path should be the project root (e.g. '.' below).
#! pip install .[notebooks]

# Configuration

We start by setting some parameters to configure the process. Depending on the GPU being used, you may need to tune the batch size.

In [1]:
# This needs to be filled in.
output_dir = 'FILL_ME_IN'        # Save the results here.  Will overwrite if directory already exists.

# Optional parameters (feel free to leave as default).
model_name = 'roberta-base'      # Set this to select the LM.  SQuAD 1.1 is an English dataset, so we use RoBERTa (base).
cache_dir = None                 # Set this if you have a cache directory for transformers.  Alternatively set the HF_HOME env var.
train_batch_size = 8             # Set this to change the number of features per batch during training.
eval_batch_size = 8              # Set this to change the number of features per batch during evaluation.
gradient_accumulation_steps = 8  # Set this to effectively increase training batch size.
max_train_samples = 100          # Set this to use a subset of the training data (or None for all).
max_eval_samples = 20            # Set this to use a subset of the evaluation data (or None for all).
num_train_epochs = 1             # Set this to change the number of training epochs.
fp16 = False                     # Set this to True to enable fp16 mixed precision (hardware support required).
num_examples_to_show = 10        # Set this to change the number of random train examples (and their features) to show.
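The training log further down reports a total train batch size of 64 and a single optimization step for 100 training features. A minimal sketch of that arithmetic, assuming the Trainer counts only complete gradient-accumulation windows as updates (the helper `training_schedule` is hypothetical, not a PrimeQA or Transformers API):

```python
import math

def training_schedule(num_features, per_device_batch_size, grad_accum_steps,
                      num_epochs, num_devices=1):
    """Return (effective batch size, batches per epoch, total optimizer updates).

    Hypothetical helper; assumes incomplete accumulation windows are not counted
    as updates, with at least one update per epoch.
    """
    # Each optimizer update consumes grad_accum_steps batches on every device.
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    # Batches per epoch on each device; the last batch may be partial.
    batches_per_epoch = math.ceil(num_features / (per_device_batch_size * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    total_updates = math.ceil(num_epochs * updates_per_epoch)
    return effective_batch, batches_per_epoch, total_updates

effective, batches, updates = training_schedule(100, 8, 8, 1)
print(effective, batches, updates)  # 64 13 1
# Share of the epoch's batches consumed by those updates (8 accumulation steps each).
print(round(updates * 8 / batches, 2))  # 0.62
```

Halving `train_batch_size` while doubling `gradient_accumulation_steps` keeps the effective batch size at 64, which is the usual way to fit training onto a smaller GPU.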

In [2]:
from transformers import TrainingArguments
from transformers.trainer_utils import set_seed

seed = 42
set_seed(seed)

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=num_train_epochs,
    evaluation_strategy='no',
    learning_rate=4e-05,
    warmup_ratio=0.1,
    weight_decay=0.1,
    save_steps=50000,
    fp16=fp16,
    seed=seed,
)
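`TrainingArguments` uses a linear learning-rate schedule by default (`lr_scheduler_type='linear'`), so `warmup_ratio=0.1` ramps the rate from 0 up to `learning_rate` over the first 10% of optimization steps and then decays it linearly back to 0. A sketch of that schedule, assuming the number of warmup steps is `warmup_ratio * total_steps` rounded up (`linear_warmup_lr` is a hypothetical helper):

```python
import math

def linear_warmup_lr(step, base_lr, total_steps, warmup_ratio):
    """Learning rate at 0-based optimizer step `step`: linear warmup, then linear decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)  # assumed rounding
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 towards base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from base_lr to 0 at total_steps.
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))

for step in (0, 5, 10, 55, 100):
    print(step, linear_warmup_lr(step, 4e-05, 100, 0.1))
```

Under these assumptions a run as short as the demo below (one optimization step) spends its only step in warmup: `linear_warmup_lr(0, 4e-05, 1, 0.1)` returns 0.0, so meaningful fine-tuning needs many more steps.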

# Loading the Model

Here we load the model and tokenizer based on the `model_name` parameter set above. We use a model with an extractive QA task head, which we will later fine-tune.

In [3]:
from transformers import AutoConfig, AutoTokenizer
from primeqa.mrc.models.heads.extractive import EXTRACTIVE_HEAD
from primeqa.mrc.models.task_model import ModelForDownstreamTasks

from primeqa.mrc.trainers.mrc import MRCTrainer

task_heads = EXTRACTIVE_HEAD
config = AutoConfig.from_pretrained(
    model_name,
    cache_dir=cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    use_fast=True,
    config=config,
)
model = ModelForDownstreamTasks.from_config(
    config,
    model_name,
    task_heads=task_heads,
    cache_dir=cache_dir,
)
model.set_task_head(next(iter(task_heads)))

print(model)  # Examine the model structure

Some weights of RobertaModelForDownstreamTasks were not initialized from the model checkpoint at roberta-base and are newly initialized: ['task_heads.qa_head.classifier.out_proj.weight', 'task_heads.qa_head.classifier.out_proj.bias', 'task_heads.qa_head.qa_outputs.weight', 'task_heads.qa_head.classifier.dense.weight', 'task_heads.qa_head.qa_outputs.bias', 'task_heads.qa_head.classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RobertaModelForDownstreamTasks(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNor

# Loading Data

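Here we load the SQuAD 1.1 dataset using Hugging Face's `datasets` library. The training subset keeps the first `max_train_samples` examples, while the evaluation subset is sampled at random.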

In [4]:
import datasets
import random

raw_datasets = datasets.load_dataset(
    'squad',
    'plain_text',
    cache_dir=cache_dir,
)

train_examples = raw_datasets["train"]
if max_train_samples is not None:
    # Keep only the first max_train_samples training examples
    train_examples = train_examples.select(range(max_train_samples))

print(f"Using {train_examples.num_rows} train examples.")

eval_examples = raw_datasets["validation"]
if max_eval_samples is not None:
    # Draw a random (seeded) subset of max_eval_samples validation examples
    random_idxs = random.sample(range(len(eval_examples)), max_eval_samples)
    eval_examples = eval_examples.select(random_idxs)

print(f"Using {eval_examples.num_rows} eval examples.")

Reusing dataset squad (/u/mabornea/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)



Using 100 train examples.
Using 20 eval examples.


# Preprocessing

Here we preprocess the data to create features that can be fed to the model.

In [5]:
from primeqa.mrc.processors.preprocessors.squad import SQUADPreprocessor

preprocessor = SQUADPreprocessor(
    stride=128,
    tokenizer=tokenizer,
)

# Train Feature Creation
with training_args.main_process_first(desc="train dataset map pre-processing"):
    train_examples, train_dataset = preprocessor.process_train(train_examples)

print(f"Preprocessing produced {train_dataset.num_rows} train features from {train_examples.num_rows} examples.")

# Validation Feature Creation
with training_args.main_process_first(desc="validation dataset map pre-processing"):
    eval_examples, eval_dataset = preprocessor.process_eval(eval_examples)

print(f"Preprocessing produced {eval_dataset.num_rows} eval features from {eval_examples.num_rows} examples.")





Preprocessing produced 100 train features from 100 examples.




Preprocessing produced 20 eval features from 20 examples.
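Here every example produced exactly one feature, so no context in this subset had to be split. When a question plus its context does not fit in the model's maximum sequence length, the context is cut into overlapping windows and each window becomes a separate feature; `stride=128` controls the overlap. We assume the preprocessor relies on the tokenizer's overflow handling, where `stride` is the number of tokens shared by consecutive windows. The hypothetical helper below sketches only the window arithmetic:

```python
def sliding_windows(num_tokens, max_len, stride):
    """Return (start, end) token ranges covering num_tokens, with `stride` tokens of overlap."""
    if not 0 <= stride < max_len:
        raise ValueError("stride must be non-negative and smaller than max_len")
    windows = []
    start = 0
    while True:
        end = min(start + max_len, num_tokens)
        windows.append((start, end))
        if end == num_tokens:
            break
        # Advance so that the next window re-reads the last `stride` tokens.
        start += max_len - stride
    return windows

print(sliding_windows(10, 4, 1))  # [(0, 4), (3, 7), (6, 10)]
print(sliding_windows(3, 4, 1))   # [(0, 3)]
```

In real preprocessing the per-window budget is the maximum sequence length minus the question and special tokens, and the overlap helps an answer near a window boundary appear whole in at least one window.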


In [6]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

# Based on https://github.com/huggingface/notebooks/blob/main/examples/question_answering.ipynb
def show_elements(dataset):
    df = pd.DataFrame(dataset)
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [7]:
import random

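def trim_document(example, max_len=500):
    # After preprocessing, 'context' holds a list of passages; display only the first one.
    example['context'] = example['context'][0]
    doc_len = len(example['context'])
    if doc_len > max_len:
        # Truncate long contexts for display, leaving room for the ellipsis.
        example['context'] = f"{example['context'][:max_len - 3]}..."
    return example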

random_idxs = random.sample(range(len(train_examples)), num_examples_to_show)
random_train_examples = train_examples.select(random_idxs).remove_columns(['passage_candidates'])
random_train_examples = random_train_examples.map(trim_document)

show_elements(random_train_examples)  # Show random train examples


index,title,context,question,example_id,target,language
0,University_of_Notre_Dame,"In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nu...",Which individual worked on projects at Notre Dame that eventually created neoprene?,5733b1da4776f4190066106b,"{'end_positions': [245], 'passage_indices': [0], 'start_positions': [222], 'yes_no_answer': ['NONE']}",UNKNOWN
1,University_of_Notre_Dame,"In 2014 the Notre Dame student body consisted of 12,179 students, with 8,448 undergraduates, 2,138 graduate and professional and 1,593 professional (Law, M.Div., Business, M.Ed.) students. Around 21–24% of students are children of alumni, and although 37% of students come from the Midwestern United States, the student body represents all 50 states and 100 countries. As of March 2007[update] The Princeton Review ranked the school as the fifth highest 'dream school' for parents to send their ch...",How many teams participate in the Notre Dame Bookstore Basketball tournament?,5733b5df4776f41900661107,"{'end_positions': [1454], 'passage_indices': [0], 'start_positions': [1446], 'yes_no_answer': ['NONE']}",UNKNOWN
2,University_of_Notre_Dame,"The library system of the university is divided between the main library and each of the colleges and schools. The main building is the 14-story Theodore M. Hesburgh Library, completed in 1963, which is the third building to house the main collection of books. The front of the library is adorned with the Word of Life mural designed by artist Millard Sheets. This mural is popularly known as ""Touchdown Jesus"" because of its proximity to Notre Dame Stadium and Jesus' arms appearing to make the s...",What is the name of the main library at Notre Dame?,5733ad384776f41900660fed,"{'end_positions': [173], 'passage_indices': [0], 'start_positions': [145], 'yes_no_answer': ['NONE']}",UNKNOWN
3,University_of_Notre_Dame,"Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend ""Venite Ad Me Omnes"". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary repu...",To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?,5733be284776f41900661182,"{'end_positions': [541], 'passage_indices': [0], 'start_positions': [515], 'yes_no_answer': ['NONE']}",UNKNOWN
4,University_of_Notre_Dame,"In 1919 Father James Burns became president of Notre Dame, and in three years he produced an academic revolution that brought the school up to national standards by adopting the elective system and moving away from the university's traditional scholastic and classical emphasis. By contrast, the Jesuit colleges, bastions of academic conservatism, were reluctant to move to a system of electives. Their graduates were shut out of Harvard Law School for that reason. Notre Dame continued to grow ov...",Those who attended a Jesuit college may have been forbidden from joining which Law School due to the curricula at the Jesuit institution?,57338724d058e614000b5ca0,"{'end_positions': [448], 'passage_indices': [0], 'start_positions': [430], 'yes_no_answer': ['NONE']}",UNKNOWN
5,University_of_Notre_Dame,All of Notre Dame's undergraduate students are a part of one of the five undergraduate colleges at the school or are in the First Year of Studies program. The First Year of Studies program was established in 1962 to guide incoming freshmen in their first year at the school before they have declared a major. Each student is given an academic advisor from the program who helps them to choose classes that give them exposure to any major in which they are interested. The program also includes a L...,What entity provides help with the management of time for new students at Notre Dame?,5733a70c4776f41900660f64,"{'end_positions': [520], 'passage_indices': [0], 'start_positions': [496], 'yes_no_answer': ['NONE']}",UNKNOWN
6,University_of_Notre_Dame,"This Main Building, and the library collection, was entirely destroyed by a fire in April 1879, and the school closed immediately and students were sent home. The university founder, Fr. Sorin and the president at the time, the Rev. William Corby, immediately planned for the rebuilding of the structure that had housed virtually the entire University. Construction was started on the 17th of May and by the incredible zeal of administrator and workers the building was completed before the fall s...",In what year was the Main Building at Notre Dame razed in a fire?,57338653d058e614000b5c81,"{'end_positions': [94], 'passage_indices': [0], 'start_positions': [90], 'yes_no_answer': ['NONE']}",UNKNOWN
7,University_of_Notre_Dame,"In 1882, Albert Zahm (John Zahm's brother) built an early wind tunnel used to compare lift to drag of aeronautical models. Around 1899, Professor Jerome Green became the first American to send a wireless message. In 1931, Father Julius Nieuwland performed early work on basic reactions that was used to create neoprene. Study of nuclear physics at the university began with the building of a nuclear accelerator in 1936, and continues now partly through a partnership in the Joint Institute for Nu...",In what year did Albert Zahm begin comparing aeronatical models at Notre Dame?,5733b1da4776f41900661068,"{'end_positions': [7], 'passage_indices': [0], 'start_positions': [3], 'yes_no_answer': ['NONE']}",UNKNOWN
8,University_of_Notre_Dame,"Notre Dame is known for its competitive admissions, with the incoming class enrolling in fall 2015 admitting 3,577 from a pool of 18,156 (19.7%). The academic profile of the enrolled class continues to rate among the top 10 to 15 in the nation for national research universities. The university practices a non-restrictive early action policy that allows admitted students to consider admission to Notre Dame as well as any other colleges to which they were accepted. 1,400 of the 3,577 (39.1%) we...",How many miles does the average student at Notre Dame travel to study there?,5733ae924776f41900661017,"{'end_positions': [637], 'passage_indices': [0], 'start_positions': [618], 'yes_no_answer': ['NONE']}",UNKNOWN
9,University_of_Notre_Dame,"The College of Engineering was established in 1920, however, early courses in civil and mechanical engineering were a part of the College of Science since the 1870s. Today the college, housed in the Fitzpatrick, Cushing, and Stinson-Remick Halls of Engineering, includes five departments of study – aerospace and mechanical engineering, chemical and biomolecular engineering, civil engineering and geological sciences, computer science and engineering, and electrical engineering – with eight B.S....",The College of Science began to offer civil engineering courses beginning at what time at Notre Dame?,5733a6424776f41900660f52,"{'end_positions': [164], 'passage_indices': [0], 'start_positions': [155], 'yes_no_answer': ['NONE']}",UNKNOWN


In [8]:
from primeqa.mrc.data_models.target_type import TargetType

def target_type_as_str(feature):
    feature['target_type'] = TargetType(feature['target_type']).name
    return feature

random_train_dataset = train_dataset.filter(lambda feature: feature['example_idx'] in random_idxs).remove_columns(['attention_mask', 'offset_mapping'])
show_elements(random_train_dataset.map(target_type_as_str))  # Show random train features



index,example_id,input_ids,example_idx,start_positions,end_positions,target_type
0,5733be284776f41900661182,"[0, 3972, 2661, 222, 5, 9880, 2708, 2346, 2082, 11, 504, 4432, 11, 226, 2126, 10067, 1470, 116, 2, 2, 37848, 37471, 28108, 6, 5, 334, 34, 10, 4019, 2048, 4, 497, 1517, 5, 4326, 6919, 18, 1637, 31346, 16, 10, 9030, 9577, 9, 5, 9880, 2708, 4, 29261, 11, 760, 9, 5, 4326, 6919, 8, 2114, 24, 6, 16, 10, 7621, 9577, 9, 4845, 19, 3701, 62, 33161, 19, 5, 7875, 22, 39043, 1459, 1614, 1464, 13292, 4977, 845, 4130, 7, 5, 4326, 6919, 16, 5, 26429, 2426, 9, 5, 25095, 6924, 4, 29261, 639, 5, 32394, 2426, 16, ...]",0,135,142,SPAN_ANSWER
1,5733a6424776f41900660f52,"[0, 133, 1821, 9, 4662, 880, 7, 904, 2366, 4675, 7484, 1786, 23, 99, 86, 23, 10579, 9038, 116, 2, 2, 133, 1821, 9, 9466, 21, 2885, 11, 18283, 6, 959, 6, 419, 7484, 11, 2366, 8, 12418, 4675, 58, 10, 233, 9, 5, 1821, 9, 4662, 187, 5, 41102, 29, 4, 2477, 5, 1564, 6, 15740, 11, 5, 20842, 6, 230, 11286, 6, 8, 312, 9554, 12, 31157, 1758, 41621, 9, 9466, 6, 1171, 292, 6522, 9, 892, 126, 15064, 8, 12418, 4675, 6, 4747, 8, 43963, 4104, 32188, 4675, 6, 2366, 4675, 8, 30694, 17874, 6, 3034, 2866, ...]",19,48,50,SPAN_ANSWER
2,5733a70c4776f41900660f64,"[0, 2264, 10014, 1639, 244, 19, 5, 1052, 9, 86, 13, 92, 521, 23, 10579, 9038, 116, 2, 2, 3684, 9, 10579, 9038, 18, 19555, 521, 32, 10, 233, 9, 65, 9, 5, 292, 19555, 8975, 23, 5, 334, 50, 32, 11, 5, 1234, 2041, 9, 9307, 586, 4, 20, 1234, 2041, 9, 9307, 586, 21, 2885, 11, 19515, 7, 4704, 11433, 19684, 11, 49, 78, 76, 23, 5, 334, 137, 51, 33, 2998, 10, 538, 4, 4028, 1294, 16, 576, 41, 5286, 11220, 31, 5, 586, 54, 2607, 106, 7, 2807, 4050, 14, 492, 106, 4895, 7, 143, 538, ...]",20,111,113,SPAN_ANSWER
3,5733ad384776f41900660fed,"[0, 2264, 16, 5, 766, 9, 5, 1049, 5560, 23, 10579, 9038, 116, 2, 2, 133, 5560, 467, 9, 5, 2737, 16, 6408, 227, 5, 1049, 5560, 8, 349, 9, 5, 8975, 8, 1304, 4, 20, 1049, 745, 16, 5, 501, 12, 6462, 26164, 256, 4, 32899, 24035, 5672, 6, 2121, 11, 18733, 6, 61, 16, 5, 371, 745, 7, 790, 5, 1049, 2783, 9, 2799, 4, 20, 760, 9, 5, 5560, 16, 29191, 19, 5, 15690, 9, 3126, 21281, 1887, 30, 3025, 5388, 1120, 264, 2580, 4, 152, 21281, 16, 1406, 352, 684, 25, 22, 40121, 3955, 5772, 113, ...]",35,43,48,SPAN_ANSWER
4,5733ae924776f41900661017,"[0, 6179, 171, 1788, 473, 5, 674, 1294, 23, 10579, 9038, 1504, 7, 892, 89, 116, 2, 2, 7199, 241, 9038, 16, 684, 13, 63, 2695, 18054, 6, 19, 5, 11433, 1380, 16914, 154, 11, 1136, 570, 13874, 155, 6, 36447, 31, 10, 3716, 9, 504, 6, 27915, 36, 1646, 4, 406, 23528, 20, 5286, 4392, 9, 5, 12751, 1380, 1388, 7, 731, 566, 5, 299, 158, 7, 379, 11, 5, 1226, 13, 632, 557, 6630, 4, 20, 2737, 3464, 10, 786, 12, 7110, 12127, 2088, 419, 814, 714, 14, 2386, 2641, 521, 7, 1701, 7988, 7, 10579, 9038, 25, ...]",43,147,150,SPAN_ANSWER
5,5733b1da4776f41900661068,"[0, 1121, 99, 76, 222, 8098, 21008, 119, 1642, 12818, 16482, 261, 36105, 3092, 23, 10579, 9038, 116, 2, 2, 1121, 504, 6551, 6, 8098, 21008, 119, 36, 10567, 21008, 119, 18, 2138, 43, 1490, 41, 419, 2508, 10615, 341, 7, 8933, 5258, 7, 8386, 9, 16482, 261, 26832, 3092, 4, 8582, 43130, 6, 6020, 15385, 1628, 1059, 5, 78, 470, 7, 2142, 10, 6955, 1579, 4, 96, 36332, 6, 9510, 20487, 234, 13627, 605, 1245, 3744, 419, 173, 15, 3280, 11012, 14, 21, 341, 7, 1045, 3087, 1517, 20962, 4, 13019, 9, 1748, 17759, 23, 5, 2737, 880, 19, ...]",54,21,22,SPAN_ANSWER
6,5733b1da4776f4190066106b,"[0, 32251, 1736, 1006, 15, 1377, 23, 10579, 9038, 14, 2140, 1412, 3087, 1517, 20962, 116, 2, 2, 1121, 504, 6551, 6, 8098, 21008, 119, 36, 10567, 21008, 119, 18, 2138, 43, 1490, 41, 419, 2508, 10615, 341, 7, 8933, 5258, 7, 8386, 9, 16482, 261, 26832, 3092, 4, 8582, 43130, 6, 6020, 15385, 1628, 1059, 5, 78, 470, 7, 2142, 10, 6955, 1579, 4, 96, 36332, 6, 9510, 20487, 234, 13627, 605, 1245, 3744, 419, 173, 15, 3280, 11012, 14, 21, 341, 7, 1045, 3087, 1517, 20962, 4, 13019, 9, 1748, 17759, 23, 5, 2737, 880, 19, 5, 745, ...]",57,68,73,SPAN_ANSWER
7,5733b5df4776f41900661107,"[0, 6179, 171, 893, 4064, 11, 5, 10579, 9038, 5972, 8005, 12610, 1967, 116, 2, 2, 1121, 777, 5, 10579, 9038, 1294, 809, 22061, 9, 316, 6, 26340, 521, 6, 19, 290, 6, 36246, 37340, 24597, 6, 132, 6, 25352, 5318, 8, 2038, 8, 112, 6, 39785, 2038, 36, 22532, 6, 256, 4, 37165, 482, 2090, 6, 256, 4, 5404, 1592, 521, 4, 8582, 733, 2383, 1978, 207, 9, 521, 32, 408, 9, 16132, 6, 8, 1712, 2908, 207, 9, 521, 283, 31, 5, 4079, 16507, 315, 532, 6, 5, 1294, 809, 3372, 70, 654, 982, 8, 727, 749, 4, ...]",75,306,307,SPAN_ANSWER
8,57338653d058e614000b5c81,"[0, 1121, 99, 76, 21, 5, 4326, 6919, 23, 10579, 9038, 910, 16314, 11, 10, 668, 116, 2, 2, 713, 4326, 6919, 6, 8, 5, 5560, 2783, 6, 21, 4378, 4957, 30, 10, 668, 11, 587, 504, 5220, 6, 8, 5, 334, 1367, 1320, 8, 521, 58, 1051, 184, 4, 20, 2737, 3787, 6, 4967, 4, 14405, 179, 8, 5, 394, 23, 5, 86, 6, 5, 7161, 4, 2897, 2812, 1409, 6, 1320, 1904, 13, 5, 13407, 9, 5, 3184, 14, 56, 15740, 8077, 5, 1445, 589, 4, 8911, 21, 554, 15, 5, 601, 212, 9, 392, 8, 30, 5, ...]",89,36,37,SPAN_ANSWER
9,57338724d058e614000b5ca0,"[0, 11195, 54, 2922, 10, 34466, 1564, 189, 33, 57, 27686, 31, 3736, 61, 2589, 835, 528, 7, 5, 39167, 5571, 23, 5, 34466, 7540, 2, 2, 1121, 35284, 9510, 957, 11247, 1059, 394, 9, 10579, 9038, 6, 8, 11, 130, 107, 37, 2622, 41, 5286, 7977, 14, 1146, 5, 334, 62, 7, 632, 2820, 30, 15059, 5, 10371, 2088, 467, 8, 1375, 409, 31, 5, 2737, 18, 2065, 8447, 1168, 11599, 8, 15855, 9723, 4, 870, 5709, 6, 5, 34466, 8975, 6, 25753, 2485, 9, 5286, 35717, 6, 58, 11923, 7, 517, 7, 10, 467, 9, 10371, 3699, 4, ...]",97,106,108,SPAN_ANSWER
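The examples table stores answers as character offsets into the context, while the features store token indices. For the Main Building fire question, the year answer spans characters 90 to 94 (four characters, so the end offset appears to be exclusive), and the matching feature stores tokens 36 and 37 (an inclusive end). The conversion uses the `offset_mapping` column that was dropped from the display above. A hypothetical sketch of that mapping, assuming question and special tokens are marked with `None`:

```python
def char_span_to_token_span(offsets, answer_start, answer_end):
    """Map a character span [answer_start, answer_end) to inclusive token indices.

    offsets: one (char_start, char_end) pair per token, or None for tokens outside
    the context (question and special tokens). Returns None if the answer is not
    fully inside this window.
    """
    start_tok = end_tok = None
    for i, span in enumerate(offsets):
        if span is None:
            continue
        tok_start, tok_end = span
        # First context token containing the answer's first character.
        if start_tok is None and tok_start <= answer_start < tok_end:
            start_tok = i
        # Token containing the answer's last character.
        if tok_start < answer_end <= tok_end:
            end_tok = i
    if start_tok is None or end_tok is None:
        return None
    return start_tok, end_tok

context = "Hello world foo"
offsets = [None, (0, 5), (6, 11), (12, 15), None]  # [CLS] Hello world foo [SEP]
print(context[6:15])                                 # world foo
print(char_span_to_token_span(offsets, 6, 15))       # (2, 3)
```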


# Fine-tuning

Here we fine-tune the model on the training set using PrimeQA's `MRCTrainer`.

In [9]:
from operator import attrgetter
import datasets
from transformers import DataCollatorWithPadding
from primeqa.mrc.data_models.eval_prediction_with_processing import EvalPredictionWithProcessing
from primeqa.mrc.processors.postprocessors.squad import SQUADPostProcessor
from primeqa.mrc.processors.postprocessors.scorers import SupportedSpanScorers
from primeqa.mrc.metrics.squad import squad

# With mixed precision, pad to a multiple of 64 for more efficient hardware acceleration
using_mixed_precision = any(attrgetter('fp16', 'bf16')(training_args))
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=64 if using_mixed_precision else None)

# noinspection PyProtectedMember
postprocessor = SQUADPostProcessor(
    k=3,
    n_best_size=20,
    max_answer_length=30,
    scorer_type=SupportedSpanScorers.WEIGHTED_SUM_TARGET_TYPE_AND_SCORE_DIFF,
    single_context_multiple_passages=preprocessor._single_context_multiple_passages,
)

def compute_metrics(p: EvalPredictionWithProcessing):
    eval_metrics = datasets.load_metric(squad.__file__)
    return eval_metrics.compute(predictions=p.processed_predictions, references=p.label_ids)

trainer = MRCTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    eval_examples=eval_examples if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    post_process_function=postprocessor.process_references_and_predictions,  # see QATrainer in Huggingface
    compute_metrics=compute_metrics,
)

train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too for easy upload

metrics = train_result.metrics
max_train_samples = max_train_samples or len(train_dataset)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

***** Running training *****
  Num examples = 100
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 8
  Total optimization steps = 1






Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to tmp
Configuration saved in tmp/config.json
Model weights saved in tmp/pytorch_model.bin
tokenizer config file saved in tmp/tokenizer_config.json
Special tokens file saved in tmp/special_tokens_map.json


***** train metrics *****
  epoch                    =       0.62
  total_flos               =     9154GF
  train_loss               =     4.3079
  train_runtime            = 0:00:29.60
  train_samples            =        100
  train_samples_per_second =      3.378
  train_steps_per_second   =      0.034
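`DataCollatorWithPadding` pads each batch only up to its longest sequence, and the fine-tuning cell asks it to round that length up to a multiple of 64 when fp16 or bf16 is enabled. A pure-Python sketch of the idea (`pad_batch` is hypothetical; the real collator delegates to the tokenizer's `pad` method and returns tensors). The pad id 1 in the example matches `padding_idx=1` in the printed RoBERTa embeddings:

```python
import math

def pad_batch(batch, pad_id, pad_to_multiple_of=None):
    """Pad token-id lists to a shared length and build the matching attention masks."""
    target = max(len(seq) for seq in batch)
    if pad_to_multiple_of:
        # Round the length up so tensor shapes suit fp16/bf16 hardware kernels.
        target = math.ceil(target / pad_to_multiple_of) * pad_to_multiple_of
    input_ids = [seq + [pad_id] * (target - len(seq)) for seq in batch]
    # 1 marks real tokens, 0 marks padding the model should ignore.
    attention_mask = [[1] * len(seq) + [0] * (target - len(seq)) for seq in batch]
    return input_ids, attention_mask

ids, mask = pad_batch([[0, 5, 2], [0, 7, 8, 9, 2]], pad_id=1)
print(ids)   # [[0, 5, 2, 1, 1], [0, 7, 8, 9, 2]]
print(mask)  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```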


# Evaluation

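Here we evaluate the model on the validation set. With only 100 training examples and a single optimization step, the scores below are expected to be low; set `max_train_samples = None` to train on the full dataset.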

In [10]:
metrics = trainer.evaluate()

max_eval_samples = max_eval_samples or len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** Running Evaluation *****
  Num examples = 20
  Batch size = 8




***** eval metrics *****
  epoch            =   0.62
  eval_exact_match =    0.0
  eval_f1          = 9.5732
  eval_samples     =     20
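`eval_exact_match` and `eval_f1` are percentages. We assume the metric script in `primeqa.mrc.metrics.squad` follows the official SQuAD v1.1 evaluation: answers are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), each prediction is scored against every gold answer with the best score kept, and EM and token-level F1 are averaged over questions. An exact match of 0.0 means none of the 20 predictions matched a gold answer after normalization, while an F1 of about 9.6 reflects some partial token overlap. A self-contained sketch:

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)

def normalize_answer(s):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def f1(prediction, ground_truth):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

def squad_scores(predictions, references):
    """Average best-over-references EM and F1, scaled to 0-100."""
    em_total = f1_total = 0.0
    for pred, golds in zip(predictions, references):
        em_total += max(exact_match(pred, g) for g in golds)
        f1_total += max(f1(pred, g) for g in golds)
    n = len(predictions)
    return {"exact_match": 100 * em_total / n, "f1": 100 * f1_total / n}

print(squad_scores(["the Eiffel Tower", "in Paris France"],
                   [["Eiffel Tower"], ["Paris"]]))  # {'exact_match': 50.0, 'f1': 75.0}
```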


# Predictions

Here we examine the model predictions.
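Each `prediction_text` is a span of the context selected from the model's start and end logits. The sketch below shows the generic n-best span search used in Hugging Face's question-answering examples, with the `n_best_size` and `max_answer_length` values passed to `SQUADPostProcessor` above. PrimeQA's postprocessor additionally uses `k=3` and the `WEIGHTED_SUM_TARGET_TYPE_AND_SCORE_DIFF` scorer, which rank candidates differently, so treat this as a simplified stand-in (`best_spans` is hypothetical):

```python
def best_spans(start_logits, end_logits, n_best_size=20, max_answer_length=30,
               context_mask=None):
    """Rank candidate (score, start, end) token spans by start_logit + end_logit."""
    def top_indices(logits):
        # Indices of the n_best_size largest logits, best first.
        return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:n_best_size]

    candidates = []
    for s in top_indices(start_logits):
        for e in top_indices(end_logits):
            if e < s or e - s + 1 > max_answer_length:
                continue  # reversed or overly long span
            if context_mask is not None and not (context_mask[s] and context_mask[e]):
                continue  # span starts or ends outside the context
            candidates.append((start_logits[s] + end_logits[e], s, e))
    candidates.sort(reverse=True)
    return candidates[:n_best_size]

start = [0.1, 2.0, 0.5, 0.3]
end = [0.0, 0.2, 3.0, 0.1]
print(best_spans(start, end, n_best_size=2)[0][1:])  # (1, 2)
```

The winning token indices would then be mapped back to character offsets through the feature's `offset_mapping` to produce the answer text.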

In [11]:
import json
import os
from pprint import pprint

with open(os.path.join(output_dir, 'eval_predictions_processed.json'), 'r') as f:
    predictions = json.load(f)

pprint(predictions)

[{'id': '573786b51c4567190057448d',
  'prediction_text': 'effects of gravity might be observed in different'},
 {'id': '56e1ddfce3433e14004231d8',
  'prediction_text': 'types of integer programming problems'},
 {'id': '56bf48cc3aeaaa14008c95af',
  'prediction_text': 'Broncos last wore matching white jerseys'},
 {'id': '5725c337271a42140099d164',
  'prediction_text': 'species, which live as parasites on the salps on which '
                     'adults of their species feed. In favorable '
                     'circumstances, ctenoph'},
 {'id': '5725e44238643c19005ace36',
  'prediction_text': 'F. Gordon, Jr. Conrad and Bean carried the first'},
 {'id': '571ccfbadd7acb1400e4c164',
  'prediction_text': 'liners in case of depressurization emergencies. Another '
                     'air separation technology involves forcing air to '
                     'dissolve through ceramic membranes based on zircon'},
 {'id': '56f827caa6d7ea1400e1743a',
  'prediction_text': "held to determine Luther