# Finetuning Question Answering on BERT
## This notebook outlines the concepts behind finetuning Question Answering on SQuAD dataset using BERT model

**Fine_tuning BERT Extracitve Question Answering in PyTorch**

Please read from [here](https://huggingface.co/transformers/model_doc/bert.html#bertforquestionanswering) or most online articles about BERT for quesion answering.
- What I am trying to answer is the how. How hugging face fine_tuned this model.

In [1]:
import torch
torch.cuda.empty_cache

<function torch.cuda.memory.empty_cache>

Install transformers Library

In [2]:
!pip install -q transformers datasets

[K     |████████████████████████████████| 2.3MB 30.2MB/s 
[K     |████████████████████████████████| 235kB 48.5MB/s 
[K     |████████████████████████████████| 901kB 55.9MB/s 
[K     |████████████████████████████████| 3.3MB 44.5MB/s 
[K     |████████████████████████████████| 112kB 60.3MB/s 
[K     |████████████████████████████████| 245kB 60.0MB/s 
[?25h

Import libraries

In [3]:
import numpy as np
import torch
from torch.optim import Adam
from transformers import BertForQuestionAnswering, BertTokenizerFast, get_linear_schedule_with_warmup


**1. Instantiate model**

- I will be inheriting from the bert_base_uncased and the BertQuestion answering framework.

In [4]:
# We are using BertTokenizerFast because other python tokens do not have char_to_token functionality we will need later.
MODEL = "bert-base-uncased"

tokenizer = BertTokenizerFast.from_pretrained(MODEL)
model = BertForQuestionAnswering.from_pretrained(MODEL, return_dict = True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

**2. Data**

- We will be using **S**tanford **Qu**estion**A**nswering **D**ataset (**SQuAD**)
- SQuAD is a pre_cleaned question answering dataset but we will apply a few changes to get correct answer alignments

- You can explore the dataset [here](https://rajpurkar.github.io/SQuAD-explorer/explore/1.1/dev/), download on tfds, huggingface datasets or Kaggle.
* The goal is to find, for each question, a span of text in a paragraph that answers that question.

### Load the dataset

In [None]:
from datasets import load_dataset

### Load and split dataset, using small datasets for the sake of model training

In [5]:
train_data, valid_data = load_dataset('squad', split='train[:1%]'), load_dataset('squad', split='validation[:3%]')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1877.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=955.0, style=ProgressStyle(description_…


Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.75 MiB, post-processed: Unknown size, total: 119.27 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4fffa6cf76083860f85fa83486ec3028e7e32c342c218ff2a620fc6b2868483a...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=8116577.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1054280.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4fffa6cf76083860f85fa83486ec3028e7e32c342c218ff2a620fc6b2868483a)


Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4fffa6cf76083860f85fa83486ec3028e7e32c342c218ff2a620fc6b2868483a. Subsequent calls will reuse this data.


### Checking the features of the answers 

In [6]:
train_data.shape, valid_data.shape

((876, 5), (317, 5))

In [65]:
train_data[0]

{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'id': '5733be284776f41900661182',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'title': 'University_of_Notre_Dame'}

### Getting correct answer text alignment and tokenizing the dataset

In [7]:
# Dataset cleaning and tokenization
# BertTokenizerFast because python tokenizer do not have char_to_token functionality

def correct_alignment(context, answer):

    """ Description: This functions corrects the alignment of answers in the squad dataset that are sometimes off by one or 2 values also adds end_postion index.
    
    inputs: list of contexts and answers
    outputs: Updated list that contains answer_end positions """
    
    start_text = answer['text'][0]
    start_idx = answer['answer_start'][0]
    end_idx = start_idx + len(start_text)

    # When alignment is okay
    if context[start_idx:end_idx] == start_text:
      return start_idx, end_idx    
      # When alignment is off by 1 character
    elif context[start_idx-1:end_idx-1] == start_text:
      return start_idx-1, end_idx-1  
      # when alignment is off by 2 characters
    elif context[start_idx-2:end_idx-2] == start_text:
      return start_idx-2, end_idx-2
    else:
      raise ValueError()

### Tokenize our training dataset

In [None]:
def convert_to_features(example_batch):
  """ Description: This functions tokenizes the context and questions then appends encoded start_positions and end_positions from the above function.
    
    inputs: list of contexts, questions and answers
    outputs: Updated list that contains answer_end positions """

    # Tokenize contexts and questions (as pairs of inputs)
  encodings = tokenizer(example_batch['context'], example_batch['question'], truncation=True)

    # Compute start and end tokens for labels using Transformers's fast tokenizers alignement methods.
  start_positions, end_positions = [], []
  for i, (context, answer) in enumerate(zip(example_batch['context'], example_batch['answers'])):
    start_idx, end_idx = correct_alignment(context, answer)
    start_positions.append(encodings.char_to_token(i, start_idx))
    end_positions.append(encodings.char_to_token(i, end_idx-1))
    # update encodings   
  encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

  return encodings

### Map the dataset to the convert_function, faster than using for loops.


In [8]:
Training_encoded = train_data.map(convert_to_features, batched=True)
Validation_encoded = valid_data.map(convert_to_features, batched = True)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




### Encoded features
- Our encoded dataset has some columns we don't need

In [9]:
Training_encoded.features

{'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'context': Value(dtype='string', id=None),
 'end_positions': Value(dtype='int64', id=None),
 'id': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'question': Value(dtype='string', id=None),
 'start_positions': Value(dtype='int64', id=None),
 'title': Value(dtype='string', id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

### Format our encided datasets to outputs torch.Tensor to train our pytorch model

In [10]:
columns = ['input_ids', 'attention_mask', 'token_type_ids', 'start_positions', 'end_positions']
Training_encoded.set_format(type='torch', columns=columns)
Validation_encoded.set_format(type='torch', columns=columns)

In [11]:
column_names =['answers', 'context', 'id', 'question', 'title']

Validation_encoded.remove_columns_(column_names=column_names)
Training_encoded.remove_columns_(column_names=column_names)


  This is separate from the ipykernel package so we can avoid doing imports until


### Loading the tensor data into dataloader.

In [12]:
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Instantiate a PyTorch Dataloader around our dataset
# Let's do dynamic batching (pad on the fly with our own collate_fn)
def collate_fn(examples):
    return tokenizer.pad(examples, return_tensors='pt')

### Dataloaders for training and validation

In [None]:
dataloader_val = DataLoader(Validation_encoded, collate_fn=collate_fn, batch_size= 4, sampler=SequentialSampler(Validation_encoded))
dataloader = DataLoader(Training_encoded, collate_fn=collate_fn, batch_size =4, sampler= RandomSampler(Training_encoded))

### Setting the seed for generating random numbers

In [13]:
import random

seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

**3. Training and evaluating the model**

**Inputs/parameters**. Here are the [explanations](https://huggingface.co/transformers/glossary.html#attention-mask) of what these paramenters represent.

*  input_ids - Ids of word embeddings
*  attention_masks - Values to point inputs that should be attended to, i.e inputs that are not paddings.
*  input_type_ids - Classification and separation tokens.
*  segment_ids - Whether the segment is a question or an answer.
- start_positions and end_positions - Tokens representing the start and end of an answer.

**outputs**
* Start_logits - probabilities that the start value is an input_id x. (torch.FloatTensor of shape (batch_size, sequence_length)) – Span-start scores (before SoftMax)
* End_logits - Probabilities that the end value is an input_id x. (torch.FloatTensor of shape (batch_size, sequence_length)) – Span-start scores (before SoftMax)
* Other return values are loss (cross enhropy loss). Hidden states and attention heads when specified.
- Start_Loss is calculated by comparing the correct start_posistions with the start_logits from the QuestionAnswering class. 
- Then  end_Loss is calculated by comparing the correct end_posistions with the end_logits from the QuestionAnswering class.
- The two losses are added then devided by two.

In [18]:
# Validation function for the model

def model_validation(dataloader_val):

    model.eval().to(device)
    val_total_loss = 0
    
    for batch in dataloader_val:      
        # batch = tuple(b.to(device) for b in batch)
        batch.to(device)

        with torch.no_grad():        
            outputs = model(**batch)
            
        loss = outputs.loss
        val_total_loss += loss.item()
    return val_total_loss



### Optimizer

In [None]:
from transformers import Adam
optimizer = Adam(model.parameters(), lr=1e-5)

### Scheduler

In [None]:
from transformers import get_linear_schedule_with_warmup

scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_val)*epochs) 

### Training Loop

In [21]:
from tqdm.notebook import tqdm

#Clear cache before running model
torch.cuda.empty_cache

epochs = 10


device = 'cuda' if torch.cuda.is_available() else 'cpu'

for epoch in tqdm(range(1, epochs+1)):
    
    model.train().to(device)
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:

        model.zero_grad()
        
        batch.to(device)      

        outputs = model(**batch)
        
        loss = outputs.loss
        loss_train_total += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})
         
        
    torch.save(model.state_dict(), f'finetuned_BERT_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader)            
    tqdm.write(f'Training loss: {round(loss_train_avg, 2)}')
    
    val_loss = model_validation(dataloader_val)
    val_loss_avg = val_loss/len(dataloader_val)
    tqdm.write(f'Validation loss: {round(val_loss_avg, 2)}')

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=219.0, style=ProgressStyle(description_widt…


Epoch 1
Training loss: 0.74
Validation loss: 2.14


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=219.0, style=ProgressStyle(description_widt…


Epoch 2
Training loss: 0.51
Validation loss: 2.34


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=219.0, style=ProgressStyle(description_widt…


Epoch 3
Training loss: 0.37
Validation loss: 2.43


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=219.0, style=ProgressStyle(description_widt…


Epoch 4
Training loss: 0.32
Validation loss: 2.47


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=219.0, style=ProgressStyle(description_widt…


Epoch 5
Training loss: 0.31
Validation loss: 2.47


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=219.0, style=ProgressStyle(description_widt…


Epoch 6
Training loss: 0.32
Validation loss: 2.47


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=219.0, style=ProgressStyle(description_widt…


Epoch 7
Training loss: 0.3
Validation loss: 2.47


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=219.0, style=ProgressStyle(description_widt…


Epoch 8
Training loss: 0.29
Validation loss: 2.47


HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=219.0, style=ProgressStyle(description_widt…


Epoch 9
Training loss: 0.3
Validation loss: 2.47


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=219.0, style=ProgressStyle(description_wid…


Epoch 10
Training loss: 0.33
Validation loss: 2.47

