In [1]:
# colab resource monitor
from urllib.request import urlopen
exec(urlopen("http://colab-monitor.smankusors.com/track.py").read())
_colabMonitor = ColabMonitor().start()

Now live at : http://colab-monitor.smankusors.com/609929dd2e69d


In [2]:
!pip install transformers
!pip install datasets



In [3]:
!nvidia-smi -L

GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-ad40b1f3-6157-c110-5a1f-00133b36c2ce)


# Fine-tuning a model on a question-answering task
This notebook will show how to fine-tune one of the 🤗 Transformers model to a question answering task, which is the task of extracting the answer to a question from a given context.
<br><br>
**Note** : This notebook finetunes models that answer question by taking a substring of a context, not by generating new text.

In [4]:
# set main parameters
squad_v2_flag = False
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

# check execution time for whole code
import time
s_time = time.time()

In [5]:
import datasets

import pandas as pd
import numpy as np
import random
import collections
import tqdm

from IPython.display import display, HTML

import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import default_data_collator

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# datasets : 1.6.1  |  pd : 1.1.5  |  np : 1.19.5  |  tqdm : 4.41.1  |  transformers : 4.5.1  |  torch : 1.8.1+cu101
print(f'datasets : {datasets.__version__}  |  pd : {pd.__version__}  |  np : {np.__version__}  |  tqdm : {tqdm.__version__}  |  transformers : {transformers.__version__}  |  torch : {torch.__version__}')
print('device :', device)

datasets : 1.6.2  |  pd : 1.1.5  |  np : 1.19.5  |  tqdm : 4.41.1  |  transformers : 4.5.1  |  torch : 1.8.1+cu101
device : cuda


## 1. Loading the dataset & metric
- We will use the 🤗 Datasets library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`.

- 🤗 Datasets library also provide `list_datasets()` function to get the list of all available datasets. There are about 21 datasets related to QA task.
  - ref : https://huggingface.co/datasets/squad_kor_v1 (Korean squad_v1 by LG CNS)
  - ref : https://huggingface.co/datasets/squad_kor_v2 (Korean squad_v2 by LG CNS)

In [6]:
# check dataset list
dset_list = datasets.list_datasets()
qa_dset_list = [i for i in dset_list if 'quad' in i]

print('>>> Total No of provided datasets :', len(dset_list))
print('>>> No of QA datasets :', len(qa_dset_list))
print(np.array([i for i in dset_list if 'quad' in i]))

>>> Total No of provided datasets : 878
>>> No of QA datasets : 22
['fquad' 'iapp_wiki_qa_squad' 'lc_quad' 'squad' 'squad_adversarial'
 'squad_es' 'squad_it' 'squad_kor_v1' 'squad_kor_v2' 'squad_v1_pt'
 'squad_v2' 'squadshifts' 'thaiqa_squad' 'xquad' 'xquad_r'
 'lhoestq/custom_squad' 'lhoestq/squad' 'piEsposito/br-quad-2.0'
 'piEsposito/br_quad_20' 'piEsposito/squad_20_ptbr'
 'susumu2357/squad_v2_sv' 'vershasaxena91/squad_multitask']


In [7]:
# load dataset & metric
dset_dict = datasets.load_dataset('squad_v2' if squad_v2_flag else 'squad')
metric = datasets.load_metric("squad_v2" if squad_v2_flag else "squad")

# check dataset
print('\n>>> dataset object :')
display(dset_dict)
print('\n>>> sample data :')
display(dset_dict['train'][0])

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4fffa6cf76083860f85fa83486ec3028e7e32c342c218ff2a620fc6b2868483a)



>>> dataset object :


DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})


>>> sample data :


{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'id': '5733be284776f41900661182',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'title': 'University_of_Notre_Dame'}

In [8]:
# show random sample of a dataset
def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    random.seed(777)
    picks = random.sample(range(len(dataset)), k=num_examples)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

show_random_elements(dset_dict["train"], 2)

Unnamed: 0,answers,context,id,question,title
0,"{'answer_start': [39], 'text': ['52%']}","In absolute terms, the planet has lost 52% of its biodiversity since 1970 according to a 2014 study by the World Wildlife Fund. The Living Planet Report 2014 claims that ""the number of mammals, birds, reptiles, amphibians and fish across the globe is, on average, about half the size it was 40 years ago"". Of that number, 39% accounts for the terrestrial wildlife gone, 39% for the marine wildlife gone, and 76% for the freshwater wildlife gone. Biodiversity took the biggest hit in Latin America, plummeting 83 percent. High-income countries showed a 10% increase in biodiversity, which was canceled out by a loss in low-income countries. This is despite the fact that high-income countries use five times the ecological resources of low-income countries, which was explained as a result of process whereby wealthy nations are outsourcing resource depletion to poorer nations, which are suffering the greatest ecosystem losses.",570bc6466b8089140040fa30,What percentage of biodiversity has the planet lost since 1970,Biodiversity
1,"{'answer_start': [252], 'text': ['more than 100']}","Later the emphasis was on classical studies, dominated by Latin and Ancient History, and, for boys with sufficient ability, Classical Greek. From the latter part of the 19th century this curriculum has changed and broadened: for example, there are now more than 100 students of Chinese, which is a non-curriculum course. In the 1970s, there was just one school computer, in a small room attached to the science buildings. It used paper tape to store programs. Today, all boys must have laptop computers, and the school fibre-optic network connects all classrooms and all boys' bedrooms to the internet.",5727bd3d4b864d1900163c00,How many current students take Chinese courses at Eaton?,Eton_College


## 2. Preprocessing the data


In [9]:
# set parameters for tokenizer
max_length = 384  # The maximum length of a feature (question and context)
doc_stride = 128  # The authorized overlap between two part of the context when splitting it is needed.

# donwload and initialize pre-trained tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
pad_on_right = tokenizer.padding_side == 'right'

# check if fast tokenizer
print(tokenizer)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)  # raises error if False

PreTrainedTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [10]:
# check tokenizer output
tokenized = tokenizer("What is your name?", "My name is Sylvain.")
decoded = tokenizer.decode(tokenized['input_ids'])

for k, v in tokenized.items():
  print(f'>>> {k:<15} : {v}')
print(f'\n>>> decoded : "{decoded}"')

>>> input_ids       : [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102]
>>> attention_mask  : [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

>>> decoded : "[CLS] what is your name? [SEP] my name is sylvain. [SEP]"


- Now one specific thing for the preprocessing in question answering is how to deal with very long documents. We usually truncate them in other tasks, when they are longer than the model maximum sentence length, but here, removing part of the the context might result in losing the answer we are looking for.

- To deal with this, we will allow one (long) example in our dataset to give several input features, whose length shorter than the maximum length of the model (or the one we set as a hyper-parameter). Also, just in case the answer lies at the point we split a long context, we allow some overlap between the features we generate controlled by the hyper-parameter `doc_stride`:

In [11]:
# get one example longer than max_length
for i, example in enumerate(dset_dict['train']):
  input_len = len(tokenizer(example['context'], example['question'])['input_ids'])
  if input_len > max_length:
    print(f'>>> found {i+1}th example input with {input_len} tokens')
    break

example

>>> found 250th example input with 396 tokens


{'answers': {'answer_start': [30], 'text': ['over 1,600']},
 'context': "The men's basketball team has over 1,600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 NCAA tournaments. Former player Austin Carr holds the record for most points scored in a single game of the tournament with 61. Although the team has never won the NCAA Tournament, they were named by the Helms Athletic Foundation as national champions twice. The team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending UCLA's record 88-game winning streak in 1974. The team has beaten an additional eight number-one teams, and those nine wins rank second, to UCLA's 10, all-time in wins against the top team. The team plays in newly renovated Purcell Pavilion (within the Edmund P. Joyce Center), which reopened for the beginning of the 2009–2010 season. The team is coached by Mike Brey, who, as of the 2014–15 season, his fifteenth at Notre Dame, has ac

In [12]:
# check example after tokenization (with truncation, stride, return_overflowing_tokens)
tokenized_example = tokenizer(
    example["question" if pad_on_right else "context"],
    example["context" if pad_on_right else "question"],
    truncation='only_second' if pad_on_right else 'only_first',  # never truncate the question, only the context
    max_length=max_length,
    stride=doc_stride,  # No of tokens to overlap between truncated text & overflowing text
    return_overflowing_tokens=True,  # return all overflowing texts (nested list for input_ids when overflowing)
    return_offsets_mapping=True,  # return offset_mapping, the corresponding start and end character in the original text
    padding="max_length",
)

for k in tokenized_example:
  if isinstance(tokenized_example[k][0], list):
    print(f'>>> {k} :')
    for lst in tokenized_example[k]:
      print('\t', lst[:30] + ['.....'] + lst[-30:])
    print()
  else:
    print(f'>>> {k} :\n\t{tokenized_example[k]}\n')

>>> input_ids :
	 [101, 2129, 2116, 5222, 2515, 1996, 10289, 8214, 2273, 1005, 1055, 3455, 2136, 2031, 1029, 102, 1996, 2273, 1005, 1055, 3455, 2136, 2038, 2058, 1015, 1010, 5174, 5222, 1010, 2028, '.....', 6862, 3946, 1998, 6986, 9530, 2532, 18533, 2239, 1010, 1996, 3554, 3493, 3786, 1996, 9523, 2120, 3410, 3804, 2630, 13664, 3807, 2076, 1996, 2161, 1012, 1996, 3590, 5222, 2020, 102]
	 [101, 2129, 2116, 5222, 2515, 1996, 10289, 8214, 2273, 1005, 1055, 3455, 2136, 2031, 1029, 102, 2528, 1012, 1996, 2230, 1516, 2340, 2136, 5531, 2049, 3180, 2161, 4396, 2193, 2698, '.....', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

>>> attention_mask :
	 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, '.....', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
	 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, '.....', 0, 0, 0, 0, 0, 

In [13]:
# check truncated & overflowing texts
for lst in tokenized_example['input_ids']:
  print(len(lst))
  print(tokenizer.decode(lst), '\n')

384
[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 ncaa tournaments. former player austin carr holds the record for most points scored in a single game of the tournament with 61. although the team has never won the ncaa tournament, they were named by the helms athletic foundation as national champions twice. the team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending ucla's record 88 - game winning streak in 1974. the team has beaten an additional eight number - one teams, and those nine wins rank second, to ucla's 10, all - time in wins against the top team. the team plays in newly renovated purcell pavilion ( within the edmund p. joyce center ), which reopened for the beginning of the 2009 – 2010 season. the team is coached by mike brey, who, as of the 2014 – 15 season, his fifteenth at 

In [14]:
# check vocab & offset index
first_token_id, first_token_offset = tokenized_example['input_ids'][0][1], tokenized_example['offset_mapping'][0][1]
print('>>> token index in vocab :', first_token_id)
print('>>> token index in string :', first_token_offset)

decoded_tk = tokenizer.decode(first_token_id)
sliced_tk = example['question'][first_token_offset[0]:first_token_offset[1]]
print('\n>>> token from vocab :', decoded_tk)
print('>>> token from example string :', sliced_tk)

>>> token index in vocab : 2129
>>> token index in string : (0, 3)

>>> token from vocab : how
>>> token from example string : How


In [15]:
%%time

# function to preprocess texts in train dataset
def prepare_train_features(examples):
  tokenized_examples = tokenizer(
      examples["question" if pad_on_right else "context"],
      examples["context" if pad_on_right else "question"],
      truncation='only_second' if pad_on_right else 'only_first',  # never truncate the question, only the context
      max_length=max_length,
      stride=doc_stride,  # No of tokens to overlap between truncated text & overflowing text
      return_overflowing_tokens=True,  # return all overflowing texts (nested list for input_ids when overflowing)
      return_offsets_mapping=True,  # return offset_mapping, the corresponding start and end character in the original text
      padding="max_length",
  )

  sample_mapping = tokenized_examples.pop('overflow_to_sample_mapping')
  offset_mapping = tokenized_examples.pop('offset_mapping')

  tokenized_examples['start_positions'] = []
  tokenized_examples['end_positions'] = []

  for i, offsets in enumerate(offset_mapping):
    input_ids = tokenized_examples['input_ids'][i]
    cls_id = input_ids.index(tokenizer.cls_token_id)
    sequence_ids = tokenized_examples.sequence_ids(i)

    sample_id = sample_mapping[i]
    answers = examples['answers'][sample_id]

    if len(answers['answer_start'])==0:
      tokenized_examples['start_positions'].append(cls_id)
      tokenized_examples['end_positions'].append(cls_id)
    else:
      # get start & end character index of the answer in the text
      answer_start = answers['answer_start'][0]
      answer_end = answer_start + len(answers['text'][0])

      # start token index of the current span in the text
      token_start_id = 0
      while sequence_ids[token_start_id] != (1 if pad_on_right else 0):
        token_start_id += 1
      
      # end token index of the current span in the text
      token_end_id = len(input_ids)-1
      while sequence_ids[token_end_id] != (1 if pad_on_right else 0):
        token_end_id -= 1

      if answer_start < offsets[token_start_id][0] or offsets[token_end_id][1] < answer_end:
        # detect if the answer is out of the span (in which case this feature is labeled with the CLS index)
        tokenized_examples['start_positions'].append(cls_id)
        tokenized_examples['end_positions'].append(cls_id)
      else:
        # otherwise move the token_start_index and token_end_index to the two ends of the answer
        # NOTE : token_end_index can go after the last offset if the answer is the last word (edge case)
        while token_start_id < len(offsets) and offsets[token_start_id][0] <= answer_start:
          token_start_id += 1
        tokenized_examples['start_positions'].append(token_start_id - 1)
          
        while offsets[token_end_id][1] >= answer_end:
          token_end_id -= 1
        tokenized_examples['end_positions'].append(token_end_id + 1)
  
  return tokenized_examples


# apply function to preprocess texts in train dataset
## since the function changes the number of samples, we need to remove the old columns when applying it
tokenized_datasets = dset_dict.map(
    prepare_train_features,  # function to apply
    batched=True,  # encode the texts by batches together (fast tokenizer's multi-threading to treat the texts in a batch concurrently)
    remove_columns=dset_dict['train'].column_names  # remove the old columns
)

HBox(children=(FloatProgress(value=0.0, max=88.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))


CPU times: user 1min 32s, sys: 842 ms, total: 1min 33s
Wall time: 39.8 s


In [16]:
# check items in sample from original dataset
for k, v in dset_dict['train'][0].items():
  print(f'>>> {k} ({len(v)} items) :\n\t{v}\n') if isinstance(v, list) else  print(f'>>> {k} (1 items) :\n\t{v}\n')

>>> answers (1 items) :
	{'answer_start': [515], 'text': ['Saint Bernadette Soubirous']}

>>> context (1 items) :
	Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.

>>> id (1 items) :
	5733be284776f41900661182

>>> question (1 items) :
	To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?

>>> title (1 items) :
	University_of_Notre_

In [17]:
# check items in sample from tokenized dataset
for k, v in tokenized_datasets['train'][0].items():
  print(f'>>> {k} ({len(v)} items) :\n\t{v}\n') if isinstance(v, list) else  print(f'>>> {k} (1 items) :\n\t{v}\n')

>>> attention_mask (384 items) :
	[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [18]:
# check items in sample from tokenized dataset (5 samples)
for k, v in tokenized_datasets['train'][:5].items():
  print(f'{k} :\n\t{v}\n')

attention_mask :
	[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

## 3. Fine-tuning the model

In [19]:
# load pre-trained model
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint).to(device)

# # check layers in model
# print('\n', model)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode

In [20]:
%%time
# set training arguments
args = TrainingArguments(
    output_dir='test-squad',
    evaluation_strategy = 'epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

# set data collector
## >>> Data collators : objects that will form a batch by using a list of dataset elements as input.
## These elements are of the same type as the elements of train_dataset or eval_dataset.
## To be able to build batches, data collators may apply some processing (like padding).
## >>> default_data_collator : very simple data collator that simply collates batches of dict-like objects
## and performs special handling for potential keys
data_collator = default_data_collator

# initialize trainer
trainer = Trainer(
    model, args, 
    data_collator=data_collator, 
    train_dataset=tokenized_datasets['train'].select(range(10000)),  # use only 10000 examples for training
    eval_dataset=tokenized_datasets['validation'].select(range(5000)),   # use only 5000 examples for validation
    tokenizer=tokenizer, 
)

train_output = trainer.train()  # about 13 min (about 1 hour 50 min for whole dataset) with Tesla P100-PCIE-16GB device
trainer.save_model('test-squad-trained')

train_output

Epoch,Training Loss,Validation Loss,Runtime,Samples Per Second
1,2.8423,1.882277,17.0202,293.768
2,1.5278,1.674996,17.0327,293.552
3,1.2163,1.680981,17.0376,293.469


CPU times: user 9min 31s, sys: 5min 19s, total: 14min 51s
Wall time: 8min 2s


## 4. Evaluation

In [21]:
# get one sample batch from validation dataset
for batch in trainer.get_eval_dataloader():
  break
batch = {k:v.to(device) for k, v in batch.items()}

batch, type(batch)

({'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
  'end_positions': tensor([ 47,  58,  92,  44, 113, 110,  75,  37, 110,  36,  76,  42,  83,  92,
          158,  35], device='cuda:0'),
  'input_ids': tensor([[ 101, 2029, 5088,  ...,    0,    0,    0],
          [ 101, 2029, 5088,  ...,    0,    0,    0],
          [ 101, 2073, 2106,  ...,    0,    0,    0],
          ...,
          [ 101, 2054, 2103,  ...,    0,    0,    0],
          [ 101, 2065, 3142,  ...,    0,    0,    0],
          [ 101, 3565, 4605,  ...,    0,    0,    0]], device='cuda:0'),
  'start_positions': tensor([ 46,  57,  89,  43, 113, 107,  72,  35, 107,  34,  73,  41,  80,  91,
          156,  35], device='cuda:0')},
 dict)

In [22]:
# get the model output(logits)
with torch.no_grad():
  output = trainer.model(**batch)
output

QuestionAnsweringModelOutput([('loss', tensor(1.9543, device='cuda:0')),
                              ('start_logits',
                               tensor([[-0.2529, -6.4996, -6.5697,  ..., -7.5709, -7.4981, -7.5687],
                                       [-0.2846, -6.4965, -6.5142,  ..., -7.5709, -7.4987, -7.5700],
                                       [ 0.0379, -5.5674, -6.3375,  ..., -7.6365, -7.6000, -7.6042],
                                       ...,
                                       [-0.2555, -5.6719, -5.2728,  ..., -7.5319, -7.5519, -7.6245],
                                       [-1.0535, -6.5248, -4.6819,  ..., -7.5702, -7.5799, -7.5823],
                                       [-1.4466, -4.1197, -6.9786,  ..., -7.5698, -7.4913, -7.5582]],
                                      device='cuda:0')),
                              ('end_logits',
                               tensor([[-0.1478, -6.4939, -6.1921,  ..., -7.2403, -7.3213, -7.2773],
                          

In [23]:
# check shape of logits
output.start_logits.shape, output.end_logits.shape

(torch.Size([16, 384]), torch.Size([16, 384]))

In [24]:
# get index(position) with highest possibility
output.start_logits.argmax(dim=1), output.end_logits.argmax(dim=1)

(tensor([ 46,  57,  89,  43, 167, 162,  72,   9, 162, 159,  73,  52,  80,  91,
         156,  35], device='cuda:0'),
 tensor([ 47,  58,  92,  44, 171, 166,  75,  43, 166, 163,  76,  42,  83,  94,
         158,  35], device='cuda:0'))

In [25]:
# search for best combination of start & end logits
## 1) select n_best_size logits from both start & end logits
## 2) calculate sum of start & end logits pair
## 3) pick the pair with highest logit sum

n_best_size = 20  # generate n_best_size**2 pairs
start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()

# get the indices of the best start & end logits
start_indices = np.argsort(start_logits)[-1:-n_best_size-1:-1].tolist()
end_indices = np.argsort(end_logits)[-1:-n_best_size-1:-1].tolist()

# check each pair of start & end indices and save valid pairs
valid_answers = []
for start_idx in start_indices:
  for end_idx in end_indices:
    if start_idx <= end_idx:  # Also need to check the answer is inside the context
      valid_answers.append({
          'score':start_logits[start_idx] + end_logits[end_idx],
          'text':''  # Also need to get back the original substring
      })

# check the result
pd.DataFrame(valid_answers).sort_values('score', ascending=False).reset_index(drop=True).iloc[[0, 1, -2, -1]]

Unnamed: 0,score,text
0,13.823847,
1,11.123655,
242,-6.655677,
243,-6.865956,


In [26]:
%%time
# function to preprocess texts in validation dataset
## difference from train features : ID of example, the offset mapping (a map from token indices to character positions in the context)

def prepare_validation_features(examples, n_best_size=20):
  tokenized_examples = tokenizer(
      text=examples["question" if pad_on_right else "context"],
      text_pair=examples["context" if pad_on_right else "question"],
      truncation="only_second" if pad_on_right else "only_first",
      max_length=max_length,
      stride=doc_stride,
      return_overflowing_tokens=True,
      return_offsets_mapping=True,
      padding="max_length",
  )

  # a map from each feature to its corresponding sample(example)
  sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
  
  tokenized_examples['example_id'] = []
  for i in range(len(tokenized_examples['input_ids'])):
    sequence_ids = tokenized_examples.sequence_ids(i)
    context_index = 1 if pad_on_right else 0

    # get each feature's id(key) of original sample
    sample_index = sample_mapping[i]
    tokenized_examples['example_id'].append(examples["id"][sample_index])

    # set offset_mapping of tokens out of context to None
    ## it will be easier to determine whether a token is part of context
    tokenized_examples['offset_mapping'][i] = [
          (tup if sequence_ids[j]==context_index else None)
          for j, tup in enumerate(tokenized_examples['offset_mapping'][i])
    ]
  
  return tokenized_examples

validation_features = dset_dict['validation'].map(
    prepare_validation_features,
    batched=True,
    remove_columns=dset_dict['validation'].column_names,
)

print('\n', type(validation_features))
print(validation_features, '\n')

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))



 <class 'datasets.arrow_dataset.Dataset'>
Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping'],
    num_rows: 10784
}) 

CPU times: user 24.3 s, sys: 117 ms, total: 24.4 s
Wall time: 18.4 s


In [27]:
%%time

# predict with trainer
raw_predictions = trainer.predict(validation_features)

print('\n', type(raw_predictions))
print(raw_predictions, '\n')


 <class 'transformers.trainer_utils.PredictionOutput'>
PredictionOutput(predictions=(array([[-0.25293124, -6.4996495 , -6.5697393 , ..., -7.570923  ,
        -7.498069  , -7.568746  ],
       [-0.2845543 , -6.496452  , -6.514182  , ..., -7.5709305 ,
        -7.498713  , -7.569987  ],
       [ 0.03791589, -5.5674148 , -6.3375044 , ..., -7.636454  ,
        -7.5999928 , -7.6041727 ],
       ...,
       [-2.1070764 , -6.5935106 , -7.090799  , ..., -7.4234753 ,
        -7.503483  , -7.5091343 ],
       [-2.338358  , -6.6644692 , -6.8318963 , ..., -7.4991574 ,
        -7.4365034 , -7.3411474 ],
       [-1.976466  , -6.508093  , -7.12375   , ..., -7.4792237 ,
        -7.5481377 , -7.5691814 ]], dtype=float32), array([[-0.14780149, -6.4939456 , -6.192118  , ..., -7.2402787 ,
        -7.3213086 , -7.277305  ],
       [-0.20517026, -6.471022  , -6.1261435 , ..., -7.2399507 ,
        -7.3203683 , -7.275719  ],
       [ 0.15468048, -6.246056  , -6.369344  , ..., -7.17046   ,
        -7.2453246 ,

In [28]:
# set the format of dataset to keep columns that are not used by the model
## (here 'example_id' and 'offset_mapping', which we will need for our post-processing)
print('< original format >')
display(validation_features.format)

validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

print('\n< changed format >')
display(validation_features.format)

< original format >


{'columns': ['input_ids', 'attention_mask'],
 'format_kwargs': {},
 'output_all_columns': False,
 'type': None}


< changed format >


{'columns': ['attention_mask', 'example_id', 'input_ids', 'offset_mapping'],
 'format_kwargs': {},
 'output_all_columns': False,
 'type': None}

In [29]:
%%time
max_answer_length = 30

# search for best combination of start & end logits
## 1) select n_best_size logits from both start & end logits
## 2) calculate sum of start & end logits pair (exclude pairs that does not match conditions)
## 3) pick the pair with highest logit sum

# get the first output of the model
io_idx = 0
n_best_size = 20  # generate n_best_size**2 pairs
start_logits = output.start_logits[io_idx].cpu().numpy()
end_logits = output.end_logits[io_idx].cpu().numpy()

# get informations for post-processing
offset_mapping = validation_features['offset_mapping'][io_idx]  # io_idx -> example_id
context = dset_dict['validation']['context'][io_idx]

# get the indices of the best start & end logits
start_indices = np.argsort(start_logits)[-1:-n_best_size-1:-1].tolist()
end_indices = np.argsort(end_logits)[-1:-n_best_size-1:-1].tolist()

# check each pair of start & end indices and save valid pairs
valid_answers = []
for start_idx in start_indices:
  for end_idx in end_indices:
    
    # exlcude out-of-scope answers
    ## 1) the indices are out of bounds
    ## 2) the indices correspond part of the input_ids that are not in the context
    if (start_idx >= len(offset_mapping)
        or end_idx >= len(offset_mapping)
        or offset_mapping[start_idx] is None
        or offset_mapping[end_idx] is None) :
      continue
    
    # exclude answers with a length less than 0 or more than max_answer_length
    if end_idx < start_idx or end_idx - start_idx + 1 > max_answer_length:
      continue

    # get the score & text from model prediction
    else:
      start_char_idx = offset_mapping[start_idx][0]
      end_char_idx = offset_mapping[end_idx][1]
      valid_answers.append({
          'score':start_logits[start_idx] + end_logits[end_idx],
          'text':context[start_char_idx:end_char_idx]  # Also need to get back the original substring
      })

# check the result
print()
display(pd.DataFrame(valid_answers).sort_values('score', ascending=False).reset_index(drop=True).iloc[[0, 1, -2, -1]])

print('\n< ground truth >')
print(f"{dset_dict['validation']['answers'][0]}\n")




Unnamed: 0,score,text
0,13.823847,Denver Broncos
1,11.123655,Denver Broncos defeated the National Football ...
149,-6.655677,an American football game to determine the cha...
150,-6.865956,the National Football League (NFL) for the 201...



< ground truth >
{'answer_start': [177, 177, 177], 'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']}

CPU times: user 6.38 s, sys: 157 ms, total: 6.54 s
Wall time: 6.46 s


In [30]:
# function for post-processing model predictions
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30):
  start_logits_arr, end_logits_arr = raw_predictions.predictions
  
  # build a map of features to its corresponding example
  example_to_idx = {id_string:idx for idx, id_string in enumerate(examples['id'])}
  features_per_example = collections.defaultdict(list)
  for idx, feature in enumerate(features):
    features_per_example[example_to_idx[feature['example_id']]].append(idx)
  
  # logging
  print(f'>>> Post-processing {len(examples)} example predictions split into {len(features)} features...')

  # post-process predictions over all examples
  predictions = collections.OrderedDict()
  for exmaple_idx, example in enumerate(tqdm.auto.tqdm(examples)):
    min_null_score = None  # only used if squad_v2_flag is True
    context = example['context']
    
    feature_id_list = features_per_example[exmaple_idx]
    valid_answers = []
    # for each feature from example...
    for feature_id in feature_id_list:
      start_logits = start_logits_arr[feature_id]
      end_logits = end_logits_arr[feature_id]
      offset_mapping = features[feature_id]['offset_mapping']

      # update minimum null prediction
      cls_idx = features[feature_id]['input_ids'].index(tokenizer.cls_token_id)  # usually 0
      feature_null_score = start_logits[cls_idx] + end_logits[cls_idx]
      if min_null_score is None or min_null_score < feature_null_score:
        min_null_score = feature_null_score
      
      # get the indices of the highest start & end logits
      start_indices = np.argsort(start_logits)[-1:-n_best_size-1:-1].tolist()
      end_indices = np.argsort(end_logits)[-1:-n_best_size-1:-1].tolist()

      # for each combination of start & end idx pairs...
      for start_idx in start_indices:
        for end_idx in end_indices:
          
          # exlcude out-of-scope answers
          if (start_idx >= len(offset_mapping)
              or end_idx >= len(offset_mapping)
              or offset_mapping[start_idx] is None
              or offset_mapping[end_idx] is None) :
            continue
          
          # exclude answers with a length less than 0 or more than max_answer_length
          if end_idx < start_idx or end_idx - start_idx + 1 > max_answer_length:
            continue

          # get the score & text from model prediction
          else:
            start_char_idx = offset_mapping[start_idx][0]
            end_char_idx = offset_mapping[end_idx][1]
            valid_answers.append({
                'score':start_logits[start_idx] + end_logits[end_idx],
                'text':context[start_char_idx:end_char_idx]
            })
      
    if len(valid_answers) > 0:
      best_answer = sorted(valid_answers, key=lambda x:x['score'])[-1]
    # in very rare case of no single non-null prediction, we create a fake prediction to aviod failure
    else:
      best_answer = {'score':0, 'text':''}

    # select the final answer - the one with highest score or null answer (only for squad_v2)
    if not squad_v2_flag:
      predictions[example['id']] = best_answer['text']
    else:
      predictions[example['id']] = best_answer['text'] if best_answer['score'] > min_null_score else ''
  
  return predictions

In [31]:
%%time
# post-process raw predictions
final_predictions = postprocess_qa_predictions(dset_dict['validation'], validation_features, raw_predictions)

>>> Post-processing 10570 example predictions split into 10784 features...


HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))


CPU times: user 29.9 s, sys: 223 ms, total: 30.1 s
Wall time: 29.7 s


In [32]:
%%time
# compute metric
if squad_v2_flag:
  formatted_predictions = [{'id':k, 'prediction_text':v, 'no_answer_probability':0} for k, v in final_predictions.items()]
else:
  formatted_predictions = [{'id':k, 'prediction_text':v} for k, v in final_predictions.items()]

references = [{'id':ex['id'], 'answers':ex['answers']} for ex in dset_dict['validation']]
print(metric.compute(predictions=formatted_predictions, references=references))

{'exact_match': 60.54872280037843, 'f1': 71.40810356048306}
CPU times: user 2.91 s, sys: 42 ms, total: 2.95 s
Wall time: 2.91 s


In [33]:
# check execution time for whole code
e_time = time.time()
time_elapsed = e_time - s_time
print(f'Total time elapsed : {int(time_elapsed//60)} min {int(time_elapsed%60)} sec')

Total time elapsed : 10 min 29 sec
