<a href="https://colab.research.google.com/github/susantaghosh1/nlp-notebooks/blob/develop/Fine_Tuning_Extractive_QA_with_BERT_and_Friends.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning BERT/RoBERTa/DeBERTa/ALBERT/DistilBERT for Extractive QA on the SQuAD Dataset

In this section we will fine-tune an encoder-only model for extractive question answering (QA) on the SQuAD dataset. Encoder-only models like BERT tend to be great at extracting answers to factoid questions like “Who invented the Transformer architecture?” but fare poorly when given open-ended questions like “Why is the sky blue?” In these more challenging cases, encoder-decoder models like T5 and BART are typically used to synthesize the information in a way that’s quite similar to text summarization.

All of the encoder-only models named in the title work in the same way: they add a linear layer on top of the base model, which produces a tensor of shape (batch_size, sequence_length, 2) holding the unnormalized scores (*logits*) for the start position and end position of the answer, for every token of every example in the batch.

Let's look briefly at the internal workings of the model:

1. The question and the context (in tokenized form) are passed to the model together as a pair. **Let's say the input to the model has shape (5,30), where 5 is the batch size and 30 is the sequence length (the number of tokens in each input).**
2. Vanilla BERT (or one of its friends) produces a contextualized embedding for every token in the sequence. The output of BERT has shape **(5,30,768), where 5 is the batch size, 30 is the sequence length, and 768 is the embedding dimension of each token.**
3. A linear head is then applied to every token: it takes the 768-dimensional embedding as input and outputs 2 values per token, which we call the start logit and the end logit. The output now has shape **(5,30,2)**.
4. This tensor is split along the last dimension into start_logits and end_logits, each of shape **(5,30,1)**.
5. Finally, the trailing singleton dimension is removed (squeezed) from start_logits and end_logits, so each of them has shape **(5,30)**.

**start_logits = tensor of shape (5,30)**
**end_logits = tensor of shape (5,30)**

6. The model takes the start_positions and end_positions of the answer as labels:

start_positions (`torch.LongTensor` of shape `(batch_size,)`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence are not taken into account for computing the loss.

end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence are not taken into account for computing the loss.

**start_positions = tensor of shape (5,)**
**end_positions = tensor of shape (5,)**

7. The cross-entropy loss is computed between **start_logits and start_positions** and between **end_logits and end_positions**.

8. The total loss is the average of these two losses; it is backpropagated through the model to compute the gradients and update the weights.
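
Steps 7 and 8 can be written out for a single example as follows. The symbols are introduced here for illustration only: $s, e \in \mathbb{R}^{T}$ are the start and end logits over the $T$ tokens of the input, and $y_s, y_e$ are the labelled start and end positions.

$$
\mathcal{L}_{\text{start}} = -\log \frac{\exp(s_{y_s})}{\sum_{t=1}^{T} \exp(s_t)}, \qquad
\mathcal{L}_{\text{end}} = -\log \frac{\exp(e_{y_e})}{\sum_{t=1}^{T} \exp(e_t)}, \qquad
\mathcal{L} = \frac{\mathcal{L}_{\text{start}} + \mathcal{L}_{\text{end}}}{2}
$$

Over a batch, each of the two terms is averaged over the examples whose label is not ignored.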

Pseudo code for a QA model with BERT:

from typing import Optional

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel, BertPreTrainedModel


class PseudoQA(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        # 2 labels for QA: one start score and one end score per token
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # last hidden state of BERT, shape: (batch_size, sequence_length, 768)
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # shape of logits: (batch_size, sequence_length, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        # shape of start_logits and end_logits: (batch_size, sequence_length, 1)
        start_logits = start_logits.squeeze(-1).contiguous()  # shape: (batch_size, sequence_length)
        end_logits = end_logits.squeeze(-1).contiguous()  # shape: (batch_size, sequence_length)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU, splitting the batch can add a dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs;
            # they are clamped to ignored_index and skipped by the loss
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return total_loss, start_logits, end_logits
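
To make the shapes from the walkthrough concrete, here is a minimal sketch that runs only the QA head and the loss on random data. The random `sequence_output` stands in for the encoder output, and the label values are hypothetical; no pretrained weights are involved, so the loss value itself is meaningless.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, seq_len, hidden_size = 5, 30, 768

# Stand-in for the encoder output (step 2); a real model would compute this.
sequence_output = torch.randn(batch_size, seq_len, hidden_size)

qa_outputs = nn.Linear(hidden_size, 2)               # step 3: linear head
logits = qa_outputs(sequence_output)                 # (5, 30, 2)
start_logits, end_logits = logits.split(1, dim=-1)   # step 4: each (5, 30, 1)
start_logits = start_logits.squeeze(-1)              # step 5: (5, 30)
end_logits = end_logits.squeeze(-1)                  # step 5: (5, 30)

# Hypothetical labels; the last example's answer lies outside the 30 tokens.
start_positions = torch.tensor([3, 7, 0, 12, 45])
end_positions = torch.tensor([5, 9, 0, 14, 47])

# Out-of-range positions are clamped to seq_len, which the loss then ignores.
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

print(logits.shape, start_logits.shape, end_logits.shape)
print(start_positions.tolist(), total_loss.item())
```

The last example contributes nothing to the loss: its clamped label equals `ignored_index`, so only the other four examples are averaged.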
  



Enough theory! Let's get our hands dirty.

In [1]:
%%capture
!pip install datasets transformers[sentencepiece]
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip install scipy scikit-learn

In [2]:
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [39]:
!nvidia-smi

Mon May 23 06:26:26 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
# load the dataset

from datasets import load_dataset

raw_datasets = load_dataset("squad")

Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.63 MiB, post-processed: Unknown size, total: 119.14 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453...


Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453. Subsequent calls will reuse this data.


In [5]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [6]:
print("Context: ", raw_datasets["train"][0]["context"])
print("Question: ", raw_datasets["train"][0]["question"])
print("Answer: ", raw_datasets["train"][0]["answers"])

Context:  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question:  To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer:  {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}


In [9]:
print(raw_datasets["train"][0]["answers"].keys())
print(type(raw_datasets["train"][0]["answers"]['text']))
print(raw_datasets["train"][0]["answers"]['text'][0])

dict_keys(['text', 'answer_start'])
<class 'list'>
Saint Bernadette Soubirous


In [12]:
answer = raw_datasets["train"][0]["answers"]['text'][0]
answer_start = raw_datasets["train"][0]["answers"]['answer_start'][0]
answer_end = answer_start + len(answer)
answer_from_context = raw_datasets["train"][0]["context"] [answer_start:answer_end]


In [13]:
answer_from_context

'Saint Bernadette Soubirous'
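
The slicing above can be turned into a small sanity check. The sketch below uses a hypothetical helper, `answer_span_is_consistent`, and a toy record shaped like a SQuAD example; it verifies that each answer text can be recovered from the context using its `answer_start` character offset.

```python
def answer_span_is_consistent(example):
    """Return True if every answer can be sliced out of the context at its answer_start."""
    context = example["context"]
    answers = example["answers"]
    return all(
        context[start:start + len(text)] == text
        for text, start in zip(answers["text"], answers["answer_start"])
    )


# Toy record with the same fields as a SQuAD example (not taken from the dataset).
toy_example = {
    "context": "The Grotto is a replica of the grotto at Lourdes, France.",
    "answers": {"text": ["Lourdes, France"], "answer_start": [41]},
}
# Same record with an off-by-one offset, which the check should reject.
broken_example = {
    "context": toy_example["context"],
    "answers": {"text": ["Lourdes, France"], "answer_start": [40]},
}

print(answer_span_is_consistent(toy_example))     # True
print(answer_span_is_consistent(broken_example))  # False
```

The same function could be applied to `raw_datasets["train"]` rows, since they carry the same `context` and `answers` fields.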

In the training set, each question has exactly one answer. We can double-check this by using the `Dataset.filter()` method:

In [14]:
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

For evaluation, however, there are several possible answers for each sample, which may be the same or different:

In [15]:
print(raw_datasets["validation"][0]["answers"])
print(raw_datasets["validation"][2]["answers"])

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}
{'text': ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."], 'answer_start': [403, 355, 355]}
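
When a sample has several acceptable answers, a prediction is typically scored against each of them and the best score is kept. The following is a minimal sketch of that idea using a simplified exact-match function; `exact_match` and `best_score` are hypothetical helpers, and the normalization here (strip and lowercase) is an assumption that is much cruder than what the official SQuAD metric does.

```python
def exact_match(prediction, reference):
    """Simplified exact match: 1 if the strings agree after strip + lowercase, else 0."""
    return int(prediction.strip().lower() == reference.strip().lower())


def best_score(prediction, references, metric=exact_match):
    """Score the prediction against every acceptable answer and keep the best result."""
    return max(metric(prediction, reference) for reference in references)


# Acceptable answers copied from the third validation sample shown above.
references = [
    "Santa Clara, California",
    "Levi's Stadium",
    "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.",
]

print(best_score("Levi's Stadium", references))  # 1
print(best_score("San Francisco", references))   # 0
```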


# Processing the training data

In [16]:
from transformers import AutoTokenizer

model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [17]:
tokenizer.is_fast

True

In [18]:
tokenizer.special_tokens_map

{'cls_token': '[CLS]',
 'mask_token': '[MASK]',
 'pad_token': '[PAD]',
 'sep_token': '[SEP]',
 'unk_token': '[UNK]'}

We can pass the question and the context to our tokenizer together, and it will insert the special tokens to form a sequence like this:

[CLS] question [SEP] context [SEP]

(Recall from the validation samples above that, at evaluation time, a predicted answer is compared to all the acceptable answers and the best score is taken.)

In [25]:
context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]

inputs = tokenizer(question, context,return_offsets_mapping=True)


In [27]:
len(inputs['input_ids'])

181

In [28]:
len(inputs['offset_mapping'])

181

In [26]:
inputs

{'input_ids': [101, 1706, 2292, 1225, 1103, 6567, 2090, 9273, 2845, 1107, 8109, 1107, 10111, 20500, 1699, 136, 102, 22182, 1193, 117, 1103, 1278, 1144, 170, 2336, 1959, 119, 1335, 4184, 1103, 4304, 4334, 112, 188, 2284, 10945, 1110, 170, 5404, 5921, 1104, 1103, 6567, 2090, 119, 13301, 1107, 1524, 1104, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 9538, 1110, 1103, 144, 10595, 2430, 117, 170, 14789, 1282, 1104, 8070, 1105, 9284, 119, 1135, 1110, 170, 16498, 1104, 1103, 176, 10595, 2430, 1120, 10111, 20500, 117, 1699, 1187, 1103, 6567, 2090, 25153, 1193, 1691, 1106, 2216, 17666, 6397, 3786, 1573, 25422, 13149, 1107, 8109, 119, 1335, 1103, 1322, 1104, 1103, 1514, 2797, 113, 1105, 1107, 170, 2904, 1413, 1115, 8200, 1194, 124, 11739, 1105, 11

In [30]:
question[3:7]

'whom'
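
Each entry of the offset mapping is a `(start, end)` pair of character positions, so slicing the original text with it gives back the text of the token. The sketch below illustrates the idea with a toy regex tokenizer, which is an assumption made for illustration: it splits on words and punctuation only, whereas the real BERT tokenizer also splits words into WordPiece sub-tokens.

```python
import re


def toy_offsets(text):
    """Return (token, (start, end)) pairs for words and punctuation marks in text."""
    return [(m.group(), (m.start(), m.end())) for m in re.finditer(r"\w+|[^\w\s]", text)]


question = "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?"
pairs = toy_offsets(question)

for token, (start, end) in pairs[:4]:
    # Slicing the original string with the offsets recovers the token text.
    print(token, (start, end), repr(question[start:end]))
```

The second pair is `whom` with offsets `(3, 7)`, which matches the `question[3:7]` lookup above.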

In [23]:
tokenizer.decode(inputs["input_ids"])


'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

In [24]:
tokenizer.convert_ids_to_tokens(inputs["input_ids"])

['[CLS]',
 'To',
 'whom',
 'did',
 'the',
 'Virgin',
 'Mary',
 'allegedly',
 'appear',
 'in',
 '1858',
 'in',
 'Lou',
 '##rdes',
 'France',
 '?',
 '[SEP]',
 'Architectural',
 '##ly',
 ',',
 'the',
 'school',
 'has',
 'a',
 'Catholic',
 'character',
 '.',
 'At',
 '##op',
 'the',
 'Main',
 'Building',
 "'",
 's',
 'gold',
 'dome',
 'is',
 'a',
 'golden',
 'statue',
 'of',
 'the',
 'Virgin',
 'Mary',
 '.',
 'Immediately',
 'in',
 'front',
 'of',
 'the',
 'Main',
 'Building',
 'and',
 'facing',
 'it',
 ',',
 'is',
 'a',
 'copper',
 'statue',
 'of',
 'Christ',
 'with',
 'arms',
 'up',
 '##rai',
 '##sed',
 'with',
 'the',
 'legend',
 '"',
 'V',
 '##eni',
 '##te',
 'Ad',
 'Me',
 'O',
 '##m',
 '##nes',
 '"',
 '.',
 'Next',
 'to',
 'the',
 'Main',
 'Building',
 'is',
 'the',
 'Basilica',
 'of',
 'the',
 'Sacred',
 'Heart',
 '.',
 'Immediately',
 'behind',
 'the',
 'b',
 '##asi',
 '##lica',
 'is',
 'the',
 'G',
 '##rot',
 '##to',
 ',',
 'a',
 'Marian',
 'place',
 'of',
 'prayer',
 'and',
 'refle

In this case the context is not too long, but some of the examples in the dataset have very long contexts that will exceed the maximum length we set (which is 384 in this case). We will deal with long contexts by creating several training features from one sample of our dataset, with a sliding window between them.

To see how this works using the current example, we can limit the length to 100 and use a sliding window of 50 tokens. As a reminder, we use:

- `max_length` to set the maximum length (here 100)
- `truncation="only_second"` to truncate the context (which is in the second position) when the question with its context is too long
- `stride` to set the number of overlapping tokens between two successive chunks (here 50)
- `return_overflowing_tokens=True` to let the tokenizer know we want the overflowing tokens
- `return_offsets_mapping=True` to get the positions of the tokens with respect to the input of the tokenizer [ here question+context+
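
The arithmetic behind the sliding window can be sketched in plain Python. This is an approximation for illustration, not the tokenizer's own implementation: `context_windows` is a hypothetical helper, and `n_special=3` assumes the BERT pair format with one `[CLS]` and two `[SEP]` tokens. The numbers used are those of the current example: 181 input IDs in total, of which 15 are question tokens and 3 are special tokens, leaving 163 context tokens.

```python
def context_windows(n_context, n_question, max_length, stride, n_special=3):
    """Return (start, end) context-token spans, one per feature, with `stride` tokens of overlap."""
    window = max_length - n_question - n_special  # context tokens that fit in one feature
    step = window - stride                        # how far the window advances each time
    if window <= 0 or step <= 0:
        raise ValueError("max_length is too small for this question and stride")
    spans = []
    start = 0
    while True:
        end = min(start + window, n_context)
        spans.append((start, end))
        if end == n_context:  # the window has reached the end of the context
            break
        start += step
    return spans


spans = context_windows(n_context=163, n_question=15, max_length=100, stride=50)
print(spans)       # [(0, 82), (32, 114), (64, 146), (96, 163)]
print(len(spans))  # 4
```

Each window holds at most 82 context tokens and successive windows share 50 of them, so the window advances by 32 tokens per feature.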

In [31]:
batch_encoding = tokenizer(question,context,max_length=100,truncation="only_second",stride=50,
                           return_overflowing_tokens=True,return_offsets_mapping=True)

In [32]:
batch_encoding.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [35]:
batch_encoding

{'input_ids': [[101, 1706, 2292, 1225, 1103, 6567, 2090, 9273, 2845, 1107, 8109, 1107, 10111, 20500, 1699, 136, 102, 22182, 1193, 117, 1103, 1278, 1144, 170, 2336, 1959, 119, 1335, 4184, 1103, 4304, 4334, 112, 188, 2284, 10945, 1110, 170, 5404, 5921, 1104, 1103, 6567, 2090, 119, 13301, 1107, 1524, 1104, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 102], [101, 1706, 2292, 1225, 1103, 6567, 2090, 9273, 2845, 1107, 8109, 1107, 10111, 20500, 1699, 136, 102, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171

In [33]:
batch_encoding['overflow_to_sample_mapping'] # one long context has been split into 4 features

[0, 0, 0, 0]

In [38]:
for tokens,positions in zip(batch_encoding['input_ids'][0],batch_encoding['offset_mapping'][0]):
  print(f"tokens :: {tokens} and decoded token :: {tokenizer.convert_ids_to_tokens(tokens)} and positions :: {positions}")  ## positions for special tokens will be (0,0)

tokens :: 101 and decoded token :: [CLS] and positions :: (0, 0)
tokens :: 1706 and decoded token :: To and positions :: (0, 2)
tokens :: 2292 and decoded token :: whom and positions :: (3, 7)
tokens :: 1225 and decoded token :: did and positions :: (8, 11)
tokens :: 1103 and decoded token :: the and positions :: (12, 15)
tokens :: 6567 and decoded token :: Virgin and positions :: (16, 22)
tokens :: 2090 and decoded token :: Mary and positions :: (23, 27)
tokens :: 9273 and decoded token :: allegedly and positions :: (28, 37)
tokens :: 2845 and decoded token :: appear and positions :: (38, 44)
tokens :: 1107 and decoded token :: in and positions :: (45, 47)
tokens :: 8109 and decoded token :: 1858 and positions :: (48, 52)
tokens :: 1107 and decoded token :: in and positions :: (53, 55)
tokens :: 10111 and decoded token :: Lou and positions :: (56, 59)
tokens :: 20500 and decoded token :: ##rdes and positions :: (59, 63)
tokens :: 1699 and decoded token :: France and positions :: (64, 70)
tokens :: 1

In [40]:
# let's try to encode a few more samples together

sample_question =  raw_datasets["train"][2:6]["question"] # list of size 4
sample_context =  raw_datasets["train"][2:6]["context"] # list of size 4

In [41]:
sample_question

['The Basilica of the Sacred heart at Notre Dame is beside to which structure?',
 'What is the Grotto at Notre Dame?',
 'What sits on top of the Main Building at Notre Dame?',
 'When did the Scholastic Magazine of Notre dame begin publishing?']

In [43]:
sample_encoding = tokenizer(sample_question,sample_context,max_length=100,truncation="only_second",stride=50,
                           return_overflowing_tokens=True,return_offsets_mapping=True)
sample_encoding

{'input_ids': [[101, 1109, 19349, 1104, 1103, 11373, 1762, 1120, 10360, 8022, 1110, 3148, 1106, 1134, 2401, 136, 102, 22182, 1193, 117, 1103, 1278, 1144, 170, 2336, 1959, 119, 1335, 4184, 1103, 4304, 4334, 112, 188, 2284, 10945, 1110, 170, 5404, 5921, 1104, 1103, 6567, 2090, 119, 13301, 1107, 1524, 1104, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 102], [101, 1109, 19349, 1104, 1103, 11373, 1762, 1120, 10360, 8022, 1110, 3148, 1106, 1134, 2401, 136, 102, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 1

In [44]:
for k,v in sample_encoding.items():
  print(f"shape of {k} :: {len(v)}")  # 4 inputs result in 19 features

shape of input_ids :: 19
shape of token_type_ids :: 19
shape of attention_mask :: 19
shape of offset_mapping :: 19
shape of overflow_to_sample_mapping :: 19
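
Since one sample can produce several features, `overflow_to_sample_mapping` is what links each feature back to the sample it came from. Here is a minimal sketch of grouping feature indices by sample; `features_per_sample` is a hypothetical helper, and the mapping below is made up for illustration rather than taken from `sample_encoding`.

```python
from collections import defaultdict


def features_per_sample(overflow_to_sample_mapping):
    """Group feature indices by the index of the sample they were created from."""
    groups = defaultdict(list)
    for feature_idx, sample_idx in enumerate(overflow_to_sample_mapping):
        groups[sample_idx].append(feature_idx)
    return dict(groups)


# Hypothetical mapping: sample 0 yields 4 features, sample 1 yields 3, sample 2 yields 2.
mapping = [0, 0, 0, 0, 1, 1, 1, 2, 2]
groups = features_per_sample(mapping)
print(groups)  # {0: [0, 1, 2, 3], 1: [4, 5, 6], 2: [7, 8]}
```

With the real encoding, passing `sample_encoding["overflow_to_sample_mapping"]` would show how the 19 features are distributed over the 4 inputs.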


In [46]:
len(sample_encoding['input_ids'][0])

100

In [None]:
# Let's build the labels: start_positions and end_positions, each of shape (batch_size,)

For each feature, the label pair is:

- (0, 0) if the answer is not in the corresponding span of the context
- (start_position, end_position) if the answer is in the corresponding span of the context, with start_position being the index of the token (in the input IDs) at the start of the answer and end_position being the index of the token (in the input IDs) where the answer ends
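
These two rules can be sketched for a single feature as follows. This is one possible implementation, not necessarily the one used later in this notebook: `find_answer_token_span` is a hypothetical helper, `sequence_ids` is assumed to hold `None` for special tokens, `0` for question tokens and `1` for context tokens, and the toy offsets are written by hand rather than produced by a tokenizer.

```python
def find_answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Map an answer's character span to (start_token, end_token), or (0, 0) if not covered."""
    context_indices = [i for i, seq_id in enumerate(sequence_ids) if seq_id == 1]
    ctx_start, ctx_end = context_indices[0], context_indices[-1]

    # The answer is not fully inside this feature's slice of the context.
    if offsets[ctx_start][0] > start_char or offsets[ctx_end][1] < end_char:
        return (0, 0)

    # Walk right to the last context token that starts at or before the answer start.
    idx = ctx_start
    while idx <= ctx_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # Walk left to the first context token that ends at or after the answer end.
    idx = ctx_end
    while idx >= ctx_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return (start_token, end_token)


# Toy feature: [CLS] Where ? [SEP] Paris is nice [SEP], context = "Paris is nice"
sequence_ids = [None, 0, 0, None, 1, 1, 1, None]
offsets = [(0, 0), (0, 5), (6, 7), (0, 0), (0, 5), (6, 8), (9, 13), (0, 0)]

print(find_answer_token_span(offsets, sequence_ids, 0, 5))   # (4, 4) -> "Paris"
print(find_answer_token_span(offsets, sequence_ids, 9, 13))  # (6, 6) -> "nice"

# A feature that only covers "Paris is" cannot contain the answer "nice".
short_ids = [None, 0, 0, None, 1, 1, None]
short_offsets = [(0, 0), (0, 5), (6, 7), (0, 0), (0, 5), (6, 8), (0, 0)]
print(find_answer_token_span(short_offsets, short_ids, 9, 13))  # (0, 0)
```

Index 0 is the `[CLS]` token in this layout, so the label (0, 0) points at `[CLS]` whenever the answer is not in the feature.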