#### Finetuning BERT for `Extractive Question Answering` 

We will finetune a BERT model on the task of extractive QA, which involves taking a `factoid question` and a `context passage` of text and `labeling a span` of text from that passage which contains the `answer`. We can frame this as a `classification task` over token positions. First, we concatenate the question and the context passage into a single sequence, separated by a `[SEP]` token (BERT's input format is `[CLS] question [SEP] context [SEP]`). Then we compute the BERT encoding of this sequence. Next, we apply a linear transform to each output token's encoding vector to compute a scalar score. Passing the scores of all tokens through a softmax gives a `probability distribution` over the tokens in the sequence, which we interpret as the probability of each token being the start of the answer span. In practice, we compute two separate linear transforms of every token's encoding and pass each set of scores through its own softmax, yielding two probability distributions over tokens: one for the `start of span` and one for the `end of span`.
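A minimal sketch of the span-scoring head, assuming the encoder output comes from `BertModel(...).last_hidden_state` (hidden size 768 for `bert-base`). `SpanHead` is a hypothetical name, and random tensors stand in for real BERT outputs so the snippet runs without downloading weights. Packing both transforms into one `nn.Linear(hidden_size, 2)` is equivalent to two separate scalar projections; masking padding positions before the softmax is an added assumption so that padding tokens get zero probability.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpanHead(nn.Module):
    """Maps each token's encoder vector to a start score and an end score."""

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        # One linear layer with 2 outputs == two separate hidden_size -> 1 transforms
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, hidden_states, attention_mask=None):
        # hidden_states: (batch, seq_len, hidden_size)
        logits = self.qa_outputs(hidden_states)           # (batch, seq_len, 2)
        start_logits, end_logits = logits.unbind(dim=-1)  # each (batch, seq_len)
        if attention_mask is not None:
            # Padding positions must not receive any probability mass
            pad = attention_mask == 0
            start_logits = start_logits.masked_fill(pad, float("-inf"))
            end_logits = end_logits.masked_fill(pad, float("-inf"))
        return start_logits, end_logits


head = SpanHead(hidden_size=768)
hidden = torch.randn(2, 16, 768)  # stand-in for BERT outputs: 2 sequences of 16 tokens
mask = torch.ones(2, 16, dtype=torch.long)
mask[1, 12:] = 0                  # the second sequence ends with 4 padding tokens
start_logits, end_logits = head(hidden, mask)
p_start = F.softmax(start_logits, dim=-1)  # distribution over start positions
p_end = F.softmax(end_logits, dim=-1)      # distribution over end positions
```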

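We will train this model on the SQuAD dataset, which contains passages each paired with multiple question and answer-span pairs. (The counts printed below, 130,319 training and 11,873 dev questions, match SQuAD v2.0, which also includes unanswerable questions that have no answer span.) The training loss is the cross-entropy of the start distribution against the true start token plus the cross-entropy of the end distribution against the true end token. To make predictions, we add the log-probability of the `ith` token being the start to the log-probability of the `jth` token being the end for every pair with j ≥ i (a single-token answer has j = i), usually capping the span length, and declare the (i, j) with the highest total score the predicted span.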
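The training loss and the span search can be sketched in plain Python for a single example. This is a minimal illustration, not the notebook's training code: `span_loss` and `best_span` are hypothetical helpers, the loss averages the two cross-entropies (summing them only rescales it), and `max_answer_len=30` is an assumed cap. Because a log-softmax subtracts the same constant from every score in a distribution, ranking spans by summed raw scores (logits) gives the same argmax as ranking by summed log-probabilities.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of floats."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def span_loss(start_scores, end_scores, gold_start, gold_end):
    """Mean of the start and end cross-entropy losses for one example."""
    start_lp = log_softmax(start_scores)
    end_lp = log_softmax(end_scores)
    return -(start_lp[gold_start] + end_lp[gold_end]) / 2


def best_span(start_scores, end_scores, max_answer_len=30):
    """Return ((i, j), score) maximizing log P(start=i) + log P(end=j)
    over all i <= j < i + max_answer_len."""
    start_lp = log_softmax(start_scores)
    end_lp = log_softmax(end_scores)
    best, best_score = (0, 0), -math.inf
    for i, s in enumerate(start_lp):
        for j in range(i, min(i + max_answer_len, len(end_lp))):
            if s + end_lp[j] > best_score:
                best, best_score = (i, j), s + end_lp[j]
    return best, best_score


start_scores = [0.1, 2.0, 0.3, -1.0]
end_scores = [0.0, 0.5, 1.5, 3.0]
print(best_span(start_scores, end_scores)[0])                    # (1, 3)
print(best_span(start_scores, end_scores, max_answer_len=2)[0])  # (1, 2)
```

The double loop is O(n × max_answer_len) per example, which is cheap for sequences of a few hundred tokens.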





In [19]:
import torch
from transformers import BertTokenizerFast, BertModel
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import csv
from tqdm import tqdm
import psutil
import json
import wandb
wandb.login()

print(torch.cuda.is_available())

True


First, let's load the SQuAD train and dev JSON files, then set up PyTorch datasets.

In [2]:
# load the train and dev JSON documents
with open("train.json", "r") as train_file:
    squad_train = json.load(train_file)         
with open("dev.json", "r") as dev_file:
    squad_dev = json.load(dev_file) 

In [56]:
def get_passages(squad, num_titles=None):
    """Return the list of passages and, for each passage, its list of question dicts."""
    if num_titles is None:
        num_titles = len(squad['data'])
    # for each title (article), collect its passages and the questions asked about each one
    passages = []
    questions = []
    num_questions = 0
    for article in squad['data'][:num_titles]:
        for p in article['paragraphs']:
            passages.append(p['context'])
            # each entry of p['qas'] holds 'question', 'id' and 'answers' (a list of {'text', 'answer_start'})
            questions.append(list(p['qas']))
            num_questions += len(p['qas'])
    print(f"Number of passages: {len(passages)}")
    print(f"Number of questions: {num_questions}")
    return passages, questions

In [57]:
passages_train, questions_train = get_passages(squad_train)
passages_val, questions_val = get_passages(squad_dev)

Number of passages: 19035
Number of questions: 130319
Number of passages: 1204
Number of questions: 11873
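Because `get_passages` keeps the questions grouped per passage, a flat list of (context, question, answer) records is easier to index from a dataset. A minimal sketch with a hypothetical `flatten_examples` helper and toy data: if the files are SQuAD v2.0, some questions are unanswerable (`is_impossible` is true and `answers` is empty), and this sketch skips them because the span-labeling setup assumes every question has an answer span.

```python
def flatten_examples(passages, questions):
    """Pair every answerable question with its passage.

    Returns a list of dicts with keys: context, question, answer_text, answer_start.
    """
    examples = []
    for context, qas in zip(passages, questions):
        for qa in qas:
            # SQuAD v2.0 flags unanswerable questions and leaves 'answers' empty
            if qa.get("is_impossible", False) or not qa["answers"]:
                continue
            answer = qa["answers"][0]  # train questions have one answer; dev may list several
            examples.append({
                "context": context,
                "question": qa["question"],
                "answer_text": answer["text"],
                "answer_start": answer["answer_start"],
            })
    return examples


toy_passages = ["The Guardian named her Artist of the Decade."]
toy_questions = [[
    {"question": "Which paper named her Artist of the Decade?",
     "answers": [{"text": "The Guardian", "answer_start": 0}], "is_impossible": False},
    {"question": "In which year?", "answers": [], "is_impossible": True},
]]
examples = flatten_examples(toy_passages, toy_questions)
print(len(examples))  # 1
```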


In [49]:
passage_lengths = [len(p) for p in passages_train]
print(f"Max passage length: {max(passage_lengths)}, Avg passage length: {sum(passage_lengths)/len(passage_lengths)}")

Max passage length: 3706, Avg passage length: 735.5478854741266


Note that the context passages are long: about 735 characters on average and up to 3,706 characters (`len(p)` counts characters, not words). Together with the question, the longest passages can exceed BERT's limit of 512 tokens per sequence, and shorter inputs also make training faster. So instead, we take a shorter context window centered on the answer span for each question.

In [128]:
q = questions_train[50][0]
answer_start_pos = q['answers'][0]['answer_start']
answer_end_pos = answer_start_pos + len(q['answers'][0]['text'])
context = passages_train[50]

print(f"Question: {q['question']}")
print(f"Answer span: {context[answer_start_pos:answer_end_pos]}")

Question: Artist of the Decade was bestowed upon Beyonce from which magazine?
Answer span: The Guardian


In [133]:
window_size_chars = 500  # characters kept on each side of the answer's midpoint
answer_middle_pos = (answer_start_pos + answer_end_pos) // 2
window_start_pos = max(0, answer_middle_pos - window_size_chars)
window_end_pos = answer_middle_pos + window_size_chars
context_window = context[window_start_pos:window_end_pos]
print("Context window: ", context_window)

# answer position relative to the start of the window
answer_start_pos_window = answer_start_pos - window_start_pos
answer_end_pos_window = answer_start_pos_window + len(q['answers'][0]['text'])
answer_window = context_window[answer_start_pos_window:answer_end_pos_window]
print(f"Answer window: {answer_window}")

# Trim off stray partial words at the beginning and end. Slicing at the first and last
# space (rather than split() and join()) keeps every character offset unchanged.
# Note: this also drops a whole word when the window already starts or ends at a passage boundary.
left_trim_length = context_window.find(' ') + 1  # first word plus the space after it
right_trim_pos = context_window.rfind(' ')       # drop the last space and the last word
context_window_trimmed = context_window[left_trim_length:right_trim_pos]
print("Context window trimmed: ", context_window_trimmed)

start_pos_window_trimmed = answer_start_pos_window - left_trim_length
end_pos_window_trimmed = answer_end_pos_window - left_trim_length
answer_window_trimmed = context_window_trimmed[start_pos_window_trimmed:end_pos_window_trimmed]
print(f"Answer window trimmed: {answer_window_trimmed}")


Context window:  In The New Yorker music critic Jody Rosen described Beyoncé as "the most important and compelling popular musician of the twenty-first century..... the result, the logical end point, of a century-plus of pop." When The Guardian named her Artist of the Decade, Llewyn-Smith wrote, "Why Beyoncé? [...] Because she made not one but two of the decade's greatest singles, with Crazy in Love and Single Ladies (Put a Ring on It), not to mention her hits with Destiny's Child; and this was the decade when singles – particularly R&B singles – regained their status as pop's favourite medium. [...] [She] and not any superannuated rock star was arguably the greatest live performer of the past 10 years." In 2013, Beyoncé made th
Answer window: The Guardian
Context window trimmed:  The New Yorker music critic Jody Rosen described Beyoncé as "the most important and compelling popular musician of the twenty-first century..... the result, the logical end point, of a century-plus of pop." W
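To apply the windowing to every question, the steps can be wrapped in one function. A minimal sketch; `make_context_window` is a hypothetical helper. It computes the answer offset relative to the window (`answer_start - win_start`), trims partial words by moving the window boundaries to the nearest space (so character offsets stay exact and the answer is never cut), and leaves an edge alone when the window already starts or ends at the passage boundary or on whitespace.

```python
def make_context_window(context, answer_start, answer_text, half_width=500):
    """Cut a window of about 2 * half_width characters centered on the answer.

    Returns (window, start, end) such that window[start:end] == answer_text.
    """
    answer_end = answer_start + len(answer_text)
    middle = (answer_start + answer_end) // 2
    win_start = max(0, middle - half_width)
    win_end = min(len(context), middle + half_width)

    # Left edge cuts into a word: skip past the next space, but never past the answer
    if win_start > 0 and not context[win_start - 1].isspace():
        space = context.find(" ", win_start, answer_start)
        if space != -1:
            win_start = space + 1
    # Right edge cuts into a word: back up to the last space after the answer
    if win_end < len(context) and not context[win_end].isspace():
        space = context.rfind(" ", answer_end, win_end)
        if space != -1:
            win_end = space

    window = context[win_start:win_end]
    start = answer_start - win_start  # answer offset relative to the window
    return window, start, start + len(answer_text)


ctx = "alpha beta gamma The Guardian delta epsilon zeta"
window, s, e = make_context_window(ctx, ctx.index("The Guardian"), "The Guardian", half_width=10)
print(repr(window), window[s:e])  # 'The Guardian' The Guardian
```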

Since BERT uses WordPiece tokenization, we need to convert the character positions of the start and end of the answer span into subword token positions. The fast tokenizer's `char_to_token` method does this using the character offsets it records for each token.

In [130]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# encode the trimmed context window (the fast tokenizer records each token's character offsets)
context_encoded = tokenizer.encode_plus(context_window_trimmed, add_special_tokens=True, return_offsets_mapping=True)
print(context_encoded.keys())
# convert character positions in the window to subword token positions;
# the end position is exclusive, so look up the answer's last character instead
start_pos_enc = context_encoded.char_to_token(start_pos_window_trimmed)
end_pos_enc = context_encoded.char_to_token(end_pos_window_trimmed - 1)
# get the corresponding subword token span
answer_span_encoded = context_encoded['input_ids'][start_pos_enc:end_pos_enc+1]
# decode the span to check if it matches original answer span
print(f"Decoded subword token span: {tokenizer.decode(answer_span_encoded)}")

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])
Decoded subword token span: the guardian
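`char_to_token` is only available on fast tokenizers. The same mapping can be derived from the `offset_mapping` returned with `return_offsets_mapping=True`, where each token carries its `(start, end)` character span and special tokens such as `[CLS]` and `[SEP]` get `(0, 0)`. A minimal sketch with a hypothetical helper and hand-written offsets (not real `bert-base-uncased` output):

```python
def char_to_token_from_offsets(offsets, char_idx):
    """Return the index of the token whose [start, end) span contains char_idx.

    Special tokens have (0, 0) offsets and never match. Returns None when
    char_idx falls on whitespace or outside every token.
    """
    for token_idx, (start, end) in enumerate(offsets):
        if start <= char_idx < end:
            return token_idx
    return None


# Hypothetical offsets for the text "Beyoncé sang": [CLS], "bey", "##once", "sang", [SEP]
offsets = [(0, 0), (0, 3), (3, 7), (8, 12), (0, 0)]
answer_start, answer_end = 0, 7  # the answer "Beyoncé"; end is exclusive
start_tok = char_to_token_from_offsets(offsets, answer_start)
end_tok = char_to_token_from_offsets(offsets, answer_end - 1)
print(start_tok, end_tok)  # 1 2
```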


Now let's set up a PyTorch dataset.

In [None]:
class SquadDataset(Dataset):
    def __init__(self, data, max_length=256):
        self.data = data
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
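The `SquadDataset` stub above returns raw records. A fuller `__getitem__` would encode each (question, context window) pair and convert the answer's character span into start and end token labels. A minimal sketch of that conversion, assuming the pair is encoded with a fast tokenizer as `enc = tokenizer(question, context, truncation="only_second", max_length=self.max_length, return_offsets_mapping=True)`, whose `enc.sequence_ids()` marks question tokens with 0, context tokens with 1 and special tokens with `None`. `span_token_labels` is a hypothetical helper, and labeling answers that fall outside the encoded context with the `[CLS]` position `(0, 0)` is an assumed convention.

```python
def span_token_labels(offsets, sequence_ids, answer_start, answer_end):
    """Map a character span of the context to (start_token, end_token).

    offsets:      per-token (start, end) character offsets (offset_mapping)
    sequence_ids: per-token segment ids: None, 0 (question) or 1 (context)
    answer_end:   exclusive character end of the answer in the context
    Returns (0, 0), the [CLS] position, if the answer is not fully inside
    the encoded context (e.g. because it was truncated away).
    """
    start_tok = end_tok = None
    for i, seq_id in enumerate(sequence_ids):
        if seq_id != 1:
            continue  # question offsets index the question text, not the context
        tok_start, tok_end = offsets[i]
        if start_tok is None and tok_start <= answer_start < tok_end:
            start_tok = i
        if tok_start < answer_end <= tok_end:
            end_tok = i
    if start_tok is None or end_tok is None:
        return 0, 0
    return start_tok, end_tok


# Hand-written encoding of ("who?", "the guardian named"):
# [CLS] who ? [SEP] the guardian named [SEP]
offsets = [(0, 0), (0, 3), (3, 4), (0, 0), (0, 3), (4, 12), (13, 18), (0, 0)]
sequence_ids = [None, 0, 0, None, 1, 1, 1, None]
print(span_token_labels(offsets, sequence_ids, 0, 12))  # (4, 5)
```

Restricting the search to context tokens matters: the question token "who" also has offsets `(0, 3)` and would otherwise be mistaken for the answer start.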

