#### Finetuning BERT for `Extractive Question Answering` 

We will finetune a BERT model on the task of extractive QA, which involves taking a `factoid question` and a `context passage` of text and `labeling the span` of text in that passage that contains the `answer`. We can frame this as a `classification task` over token positions. First, we concatenate the context passage and the question into a single sequence, separated by a `[SEP]` token. Then we compute the BERT encoding of this sequence and apply a linear transform to each output token's encoding vector to compute a scalar score. Passing the scores of all tokens through a softmax gives a `probability distribution` over the tokens in the sequence, which we can interpret as the probability of each token being the start of the span. In fact, we compute two separate linear transforms of every token's encoding and pass each set of scores through its own softmax, yielding two probability distributions over tokens: one for the `start of span` and one for the `end of span`.
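Written out, the two heads described above compute the following. This is a restatement of the prose with notation introduced here: $h_k \in \mathbb{R}^d$ is BERT's final-layer vector for token $k$ of the $T$-token sequence `[CLS] context [SEP] question [SEP]`, $d = 768$ for BERT-base, and $w_s, b_s$ and $w_e, b_e$ are the weights and biases of the start and end heads.

$$P_{\text{start}}(i) = \frac{\exp(w_s^\top h_i + b_s)}{\sum_{k=1}^{T} \exp(w_s^\top h_k + b_s)}, \qquad P_{\text{end}}(j) = \frac{\exp(w_e^\top h_j + b_e)}{\sum_{k=1}^{T} \exp(w_e^\top h_k + b_e)}$$

For a gold span starting at token $i^*$ and ending at token $j^*$, the training loss is the sum of the two cross-entropies (averaged over the examples in a batch):

$$\mathcal{L} = -\log P_{\text{start}}(i^*) - \log P_{\text{end}}(j^*)$$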

We will train this model on the SQuAD dataset, in which each passage comes with multiple question and answer-span pairs (the counts printed below, 130,319 training and 11,873 development questions, correspond to SQuAD 2.0, which also contains unanswerable questions). We will use the cross-entropy loss on each of the two softmax outputs and sum them. To make predictions, we add the score of the `ith` token being the start and the `jth` token being the end for all i and j ≥ i, then declare the (i, j) pair with the highest total score as the predicted span.
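The span search described above can be sketched in plain Python. This is a minimal sketch: the helper `best_span` and its `max_answer_len` cap (a common practical limit on answer length) are assumptions, not part of this notebook. Because the softmax normalizers do not depend on (i, j), summing raw start and end logits selects the same span as summing log-probabilities.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return (i, j, score) maximizing start_scores[i] + end_scores[j] over i <= j < i + max_answer_len."""
    best = (0, 0, float('-inf'))
    for i, s in enumerate(start_scores):
        # only consider end positions at or after the start, within the (hypothetical) length cap
        for j in range(i, min(i + max_answer_len, len(end_scores))):
            score = s + end_scores[j]
            if score > best[2]:
                best = (i, j, score)
    return best

# made-up per-token scores for a 4-token sequence
start_scores = [0.1, 2.0, 0.3, -1.0]
end_scores = [0.0, 0.5, 1.5, 3.0]
print(best_span(start_scores, end_scores))                    # (1, 3, 5.0)
print(best_span(start_scores, end_scores, max_answer_len=2))  # (1, 2, 3.5)
```

The double loop is O(T · max_answer_len), which is cheap for sequences of a few hundred tokens.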





In [187]:
import torch
from transformers import BertTokenizerFast, BertModel
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import csv
import random
from tqdm import tqdm
import psutil
import json
import wandb
wandb.login()

print(torch.cuda.is_available())

True


First, let's load the data from file and then set up PyTorch datasets.

In [2]:
# load the train and dev JSON documents
with open("train.json", "r") as train_file:
    squad_train = json.load(train_file)         
with open("dev.json", "r") as dev_file:
    squad_dev = json.load(dev_file) 
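For reference, `get_passages` below walks the SQuAD JSON layout: a top-level `data` list of articles, each with `paragraphs`, each holding a `context` string and a `qas` list whose entries carry a `question` and a list of `answers`, each given as `text` plus a character offset `answer_start` into the context. The toy document here is made up and trimmed to just those fields.

```python
# a tiny, made-up document in the SQuAD JSON layout (only the fields this notebook uses)
toy_squad = {
    "data": [
        {
            "title": "Beyonce",
            "paragraphs": [
                {
                    "context": "Beyonce rose to fame in the late 1990s.",
                    "qas": [
                        {"question": "When did Beyonce rise to fame?",
                         "answers": [{"text": "in the late 1990s", "answer_start": 21}]},
                        # an unanswerable question has an empty answers list
                        {"question": "What was Beyonce's first car?", "answers": []},
                    ],
                }
            ],
        }
    ]
}

# walk the hierarchy: articles -> paragraphs -> questions -> answers
for article in toy_squad["data"]:
    for paragraph in article["paragraphs"]:
        context = paragraph["context"]
        for qa in paragraph["qas"]:
            for answer in qa["answers"]:
                # answer_start is a character offset into the context string
                start = answer["answer_start"]
                print(qa["question"], "->", context[start:start + len(answer["text"])])
# prints: When did Beyonce rise to fame? -> in the late 1990s
```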

In [56]:
def get_passages(squad, num_titles=None):
    if num_titles is None:
        num_titles = len(squad['data'])
    # for each title, collect its passages and the questions asked about each passage
    passages = []
    questions = []
    num_questions = 0
    for i in range(num_titles):
        #print(f"Title# {i}: {squad['data'][i]['title']}, Number of passages: {len(squad['data'][i]['paragraphs'])}")
        for p in squad['data'][i]['paragraphs']:
            # store the passage text and, at the same index, the list of its questions
            passages.append(p['context'])
            questions.append(list(p['qas']))
            num_questions += len(p['qas'])
    print(f"Number of passages: {len(passages)}")
    print(f"Number of questions: {num_questions}")
    return passages, questions

In [57]:
passages_train, questions_train = get_passages(squad_train)
passages_val, questions_val = get_passages(squad_dev)

Number of passages: 19035
Number of questions: 130319
Number of passages: 1204
Number of questions: 11873


In [49]:
passage_lengths = [len(p) for p in passages_train]
print(f"Max passage length: {max(passage_lengths)}, Avg passage length: {sum(passage_lengths)/len(passage_lengths)}")

Max passage length: 3706, Avg passage length: 735.5478854741266


Note that these lengths are measured in characters, not words: the context passages average about 735 characters, and the longest is 3,706 characters, which is likely more than fits into BERT's limit of 512 tokens per sequence. To keep every input short, we will instead take a shorter, fixed-size context window around the answer for each question.

In [197]:
q = questions_train[0][0]
answer_start_pos = q['answers'][0]['answer_start']
answer_end_pos = answer_start_pos + len(q['answers'][0]['text'])
context = passages_train[0]

print(f"Question: {q['question']}")
print(f"Answer span: {context[answer_start_pos:answer_end_pos]}")

Question: When did Beyonce start becoming popular?
Answer span: in the late 1990s


In [217]:
window_size_chars = 500
# pick a random context window around the answer (try to keep at least 20% of the characters in window on the left side of the answer)
answer_middle_pos = int((answer_start_pos+answer_end_pos)/2) 
a = max(0,answer_end_pos-window_size_chars)
b = max(a, int(answer_start_pos - 0.2*window_size_chars))  # int() because random.randint needs integers; max(a, ...) keeps the range non-empty
random_window_start_pos = random.randint(a, b)
#window_start_pos = max(0,answer_middle_pos-window_size_chars)
#window_end_pos = answer_middle_pos+window_size_chars
window_start_pos = random_window_start_pos
window_end_pos = window_start_pos+window_size_chars

context_window = context[window_start_pos:window_end_pos]
print("Context window: ", context_window)

answer_start_pos_window = answer_start_pos - window_start_pos
answer_end_pos_window = answer_start_pos_window + len(q['answers'][0]['text'])
answer_window = context_window[answer_start_pos_window:answer_end_pos_window]
print(f"Answer window: {answer_window}")

# Trim off stray partial words at the beginning and end
context_window_words = context_window.split()
context_window_trimmed = ' '.join(context_window_words[1:-1])
left_trim_length = len(context_window_words[0]) + 1 # add 1 for the white space between the stray partial first word and the next word
print("Context window trimmed: ", context_window_trimmed)

start_pos_window_trimmed = answer_start_pos_window - left_trim_length
end_pos_window_trimmed = answer_end_pos_window - left_trim_length
answer_window_trimmed = context_window_trimmed[start_pos_window_trimmed:end_pos_window_trimmed]
print(f"Answer window trimmed: {answer_window_trimmed}")


Context window:  arter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love 
Answer window: in the late 1990s
Context window trimmed:  (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their

Since we will use WordPiece tokenization, we need to be careful when converting the character positions of the start and end of the answer span to subword token positions.
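The `char_to_token` calls below do this lookup for us. Conceptually, each subword token records the `(start, end)` character span it came from (what `return_offsets_mapping=True` exposes), and we look for the token whose span contains a given character. Here is a plain-Python sketch of that idea, with a hypothetical helper `char_to_token_index` and made-up offsets. It assumes a single sequence; in a pair encoding the question's offsets restart at 0, which is why `char_to_token` also accepts a `sequence_index`.

```python
def char_to_token_index(offsets, char_index):
    """Return the index of the token whose [start, end) character span contains char_index, or None."""
    for token_index, (start, end) in enumerate(offsets):
        # special tokens such as [CLS] and [SEP] have empty (0, 0) spans, so they never match
        if start <= char_index < end:
            return token_index
    return None

context = "rose to fame"
# hypothetical per-token offsets for "[CLS] rose to fame [SEP]"
offsets = [(0, 0), (0, 4), (5, 7), (8, 12), (0, 0)]
answer_start, answer_end = 5, 12  # the answer "to fame"; answer_end is exclusive
start_token = char_to_token_index(offsets, answer_start)
end_token = char_to_token_index(offsets, answer_end - 1)  # look up the answer's last character
print(start_token, end_token)  # 2 3
```

Looking up `answer_end - 1` rather than `answer_end` matters: the exclusive end position usually falls on whitespace or the next word, not on the answer's last token.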

In [None]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [218]:
# encode the (context window, question) pair
context_encoded = tokenizer.encode_plus(context_window_trimmed, q['question'], add_special_tokens=True, return_offsets_mapping=True)
print(context_encoded.keys())
# convert character positions in the trimmed context window to subword token positions
start_pos_enc = context_encoded.char_to_token(start_pos_window_trimmed)
end_pos_enc = context_encoded.char_to_token(end_pos_window_trimmed-1)  # end position is exclusive, so look up its last character
# get the corresponding subword token span
answer_span_encoded = context_encoded['input_ids'][start_pos_enc:end_pos_enc+1]
# decode the span to check if it matches original answer span
print(f"Decoded subword token span: {tokenizer.decode(answer_span_encoded)}")
print(f"Decoded sentence pair: {tokenizer.decode(context_encoded['input_ids'])}")

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])
Decoded subword token span: in the late 1990s
Decoded sentence pair: [CLS] ( / biːˈjɒnseɪ / bee - yon - say ) ( born september 4, 1981 ) is an american singer, songwriter, record producer and actress. born and raised in houston, texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of r & b girl - group destiny's child. managed by her father, mathew knowles, the group became one of the world's best - selling girl groups of all time. their hiatus saw the release of beyonce's debut album, dangerously in [SEP] when did beyonce start becoming popular? [SEP]


Now let's create (question, trimmed context window, answer span) instances. Note: some questions have an empty `answers` list; they produce no instances and are skipped.
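The core of instance construction is choosing a window that is guaranteed to contain the answer. As a standalone sketch (`sample_window` is a hypothetical helper, not part of the original notebook), the sampling bounds can be written and checked in isolation. It assumes the answer is no longer than the window and tries to leave at least `left_frac` of the window to the left of the answer.

```python
import random

def sample_window(answer_start, answer_end, window_size, left_frac=0.2, rng=random):
    """Pick a window [start, start + window_size) that contains [answer_start, answer_end).

    Tries to leave at least left_frac * window_size characters to the left of the
    answer; falls back to the tightest valid start when that is impossible.
    """
    if answer_end - answer_start > window_size:
        raise ValueError("answer is longer than the window")
    # the window must reach past the end of the answer
    lo = max(0, answer_end - window_size)
    # leave room on the left when possible; int() because randint needs integers
    hi = max(lo, int(answer_start - left_frac * window_size))
    start = rng.randint(lo, hi)
    return start, start + window_size
```

The returned end may run past the passage; Python slicing simply stops at the end of the string.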

In [219]:
# function for creating ((context window, question), answer start, answer end) instances
def get_instances(passages, questions, window_size_chars=300):
    triplets = []
    for i, context in enumerate(passages):
        for q in questions[i]:
            # questions with an empty answers list produce no instances
            for ans in q['answers']:
                # get start and end character positions of the answer span in the passage
                answer_start_pos = ans['answer_start']
                answer_end_pos = answer_start_pos + len(ans['text'])
                if answer_end_pos - answer_start_pos > window_size_chars:
                    continue  # the answer cannot fit inside a window
                # pick a random window containing the answer (keep ~20% of the window left of the answer when possible)
                lo = max(0, answer_end_pos - window_size_chars)
                hi = max(lo, int(answer_start_pos - 0.2*window_size_chars))
                window_start_pos = random.randint(lo, hi)
                window_end_pos = min(len(context), window_start_pos + window_size_chars)
                # trim stray partial words at the window edges by moving each boundary to the nearest space,
                # never past the answer; slicing (instead of split/join) keeps the character offsets exact
                if window_start_pos > 0 and not context[window_start_pos-1].isspace():
                    space = context.find(' ', window_start_pos, answer_start_pos)
                    if space != -1:
                        window_start_pos = space + 1
                if window_end_pos < len(context) and not context[window_end_pos].isspace():
                    space = context.rfind(' ', answer_end_pos, window_end_pos)
                    if space != -1:
                        window_end_pos = space
                context_window_trimmed = context[window_start_pos:window_end_pos]
                # offset the answer span positions into the trimmed window
                start_pos_window_trimmed = answer_start_pos - window_start_pos
                end_pos_window_trimmed = answer_end_pos - window_start_pos
                triplets.append(((context_window_trimmed, q['question']), start_pos_window_trimmed, end_pos_window_trimmed))
    return triplets

In [220]:
triplets_train = get_instances(passages_train, questions_train)
triplets_val = get_instances(passages_val, questions_val)

Now let's set up a PyTorch dataset.

In [172]:
class SquadDataset(Dataset):
    def __init__(self, data, max_length=256):
        self.data = data
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # get the data instance: ((context window, question), answer start char, answer end char)
        (context, question), start_pos_char, end_pos_char = self.data[idx]
        # encode the pair, truncating only the context and padding to max_length so batches can be stacked
        encoded = self.tokenizer(context, question, add_special_tokens=True, max_length=self.max_length,
                                 truncation='only_first', padding='max_length')
        # convert start and end character positions in the context (sequence 0) to subword token positions
        start_pos_enc = encoded.char_to_token(start_pos_char, sequence_index=0)
        end_pos_enc = encoded.char_to_token(end_pos_char - 1, sequence_index=0)
        if start_pos_enc is None or end_pos_enc is None:
            # answer was truncated away or starts/ends on whitespace; point both labels at [CLS] (index 0)
            start_pos_enc, end_pos_enc = 0, 0
        # convert to tensors
        input_idx = torch.tensor(encoded['input_ids'])
        attn_mask = torch.tensor(encoded['attention_mask'])
        start_pos_enc = torch.tensor(start_pos_enc)
        end_pos_enc = torch.tensor(end_pos_enc)
        return input_idx, start_pos_enc, end_pos_enc, attn_mask

In [185]:
train_dataset = SquadDataset(triplets_train)
val_dataset = SquadDataset(triplets_val)
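`__getitem__` returns one example at a time, and unless every example is padded to the same length, the `DataLoader`'s default collation cannot stack them into a batch. A minimal collate function handles variable-length examples by right-padding them. This is a sketch: `squad_collate` is a hypothetical name, and the pad id 0 is the `[PAD]` id of `bert-base-uncased`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def squad_collate(batch, pad_token_id=0):
    """Stack (input_idx, start, end, attn_mask) tuples, right-padding variable-length sequences."""
    input_idx, starts, ends, masks = zip(*batch)
    # pad token ids with [PAD] and attention masks with 0 so padded positions are ignored
    input_idx = pad_sequence(list(input_idx), batch_first=True, padding_value=pad_token_id)
    masks = pad_sequence(list(masks), batch_first=True, padding_value=0)
    # start/end labels are scalar tensors, so they stack into shape (batch_size,)
    return input_idx, torch.stack(starts), torch.stack(ends), masks
```

It would be passed as, e.g., `DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=squad_collate)` (the batch size here is an arbitrary example). If the examples are already padded to a fixed length, it simply stacks them.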

#### Now define the classification model.

In [None]:
class BERTExtractiveQA(torch.nn.Module):
    def __init__(self, block_size, hidden_size=768, dropout_rate=0.1, finetune=False):
        super().__init__()
        # block_size (max sequence length) is kept for compatibility; the heads score each token
        # independently, so they do not depend on it
        self.block_size = block_size
        # load pretrained BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(dropout_rate)
        # two classifier heads, each mapping a token's encoding to a single scalar score:
        # one for the start of the span and one for the end of the span
        self.classifier_head_start_span = torch.nn.Linear(hidden_size, 1)
        self.classifier_head_end_span = torch.nn.Linear(hidden_size, 1)
        # make BERT's parameters trainable only if we're finetuning; otherwise freeze them
        for param in self.bert.parameters():
            param.requires_grad = finetune

    def forward(self, input_idx, labels_start, labels_end, attn_mask):
        # compute BERT encodings
        bert_output = self.bert(input_idx, attention_mask=attn_mask)
        bert_output = self.dropout(bert_output.last_hidden_state)  # shape: (batch_size, sequence_length, hidden_size)
        # compute one score per token for each head, dropping the trailing singleton dimension
        logits_start = self.classifier_head_start_span(bert_output).squeeze(-1)  # shape: (batch_size, sequence_length)
        logits_end = self.classifier_head_end_span(bert_output).squeeze(-1)  # shape: (batch_size, sequence_length)
        # exclude padding tokens from the softmax over positions
        pad_mask = attn_mask == 0
        logits_start = logits_start.masked_fill(pad_mask, -1e4)
        logits_end = logits_end.masked_fill(pad_mask, -1e4)
        # cross entropy over token positions for both the start and end distributions
        loss = F.cross_entropy(logits_start, labels_start) + F.cross_entropy(logits_end, labels_end)
        return logits_start, logits_end, loss
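The scoring described at the top of this notebook (one scalar per token, then a softmax over positions) can be sanity-checked without downloading BERT. This sketch substitutes a random tensor for `last_hidden_state`; the batch, mask, and labels are made up, and the `-1e4` fill value for padding positions is an assumed choice of "large negative" score.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, hidden_size = 2, 6, 768
# random stand-in for BERT's last_hidden_state
hidden = torch.randn(batch_size, seq_len, hidden_size)
# the second example ends with two padding tokens
attn_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                          [1, 1, 1, 1, 0, 0]])
start_head = torch.nn.Linear(hidden_size, 1)
end_head = torch.nn.Linear(hidden_size, 1)
# one score per token: (batch, seq, 1) -> (batch, seq); padding pushed to a large negative score
logits_start = start_head(hidden).squeeze(-1).masked_fill(attn_mask == 0, -1e4)
logits_end = end_head(hidden).squeeze(-1).masked_fill(attn_mask == 0, -1e4)
labels_start = torch.tensor([1, 2])
labels_end = torch.tensor([3, 2])
loss = F.cross_entropy(logits_start, labels_start) + F.cross_entropy(logits_end, labels_end)
# softmax over token positions: padding positions receive (numerically) zero probability
probs_start = logits_start.softmax(dim=-1)
```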

