<a href="https://colab.research.google.com/github/sudarshansivakumar/SQuAD_QuestionAnswering/blob/main/DistilBERT_QA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **Installing and importing libraries** 

In [None]:
!pip install transformers
!pip install datasets
!pip install tokenizers

In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from nltk.tokenize import word_tokenize
import nltk
from transformers import DistilBertTokenizer, DistilBertModel,PreTrainedTokenizer,PreTrainedTokenizerFast,BertModel,AutoTokenizer,DistilBertForQuestionAnswering
import torch
from tokenizers import BertWordPieceTokenizer
from tqdm.notebook import tqdm, trange
import os
from datasets import load_dataset


### **Loading dataset and dataloaders**

In [None]:
# Load SQuAD and keep a handle on each of the two splits used below.
dataset = load_dataset("squad")
train_ds, valid_ds = dataset["train"], dataset["validation"]

In [None]:
# NOTE(review): `model_checkpoint` was read here (and again when the model is
# loaded) but never assigned in this notebook, so this cell raised NameError.
# "distilbert-base-uncased" is assumed from the DistilBERT classes imported
# above -- confirm this is the checkpoint that was intended.
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
class BERTSQuADDataset(Dataset) :
  """Map-style dataset that tokenizes SQuAD records and labels the answer span.

  Each item pairs a question with its context, tokenizes them as a single
  padded sequence (only the context may be truncated) and looks up the token
  indices whose character offsets coincide with the first reference answer.

  Args:
    dset: indexable collection of records with "question", "context" and
      "answers" ({"text": [...], "answer_start": [...]}) entries.
    tokenizer: callable tokenizer exposing `cls_token_id` whose output
      provides "input_ids", "attention_mask", "offset_mapping" and
      `sequence_ids()`.
    max_length: length every example is padded / truncated to.
    doc_stride: stored but not used; long contexts are truncated rather than
      split into overlapping windows.
  """

  def __init__(self,dset,tokenizer,max_length = 512,doc_stride = 0) :
    self.dset = dset
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.doc_stride = doc_stride

  def __getitem__(self,idx) :
    """Return the tokenized example at `idx` as a dict.

    "Start_token" / "End_token" hold the answer's token span. Both fall back
    to the [CLS] position when the answer is not fully inside the (possibly
    truncated) context or when either boundary does not line up exactly with
    a token boundary, so a label never mixes [CLS] with a real token.
    """
    # Index the underlying dataset once; every field below comes from it.
    record = self.dset[idx]
    question = record["question"]
    context = record["context"]
    # Only the first reference answer is used as the training target.
    answer = record["answers"]["text"][0]
    answer_start_char = record["answers"]["answer_start"][0]
    answer_end_char = answer_start_char + len(answer)

    # Use the tokenizer given to this dataset, not a same-named global.
    tokenized_example = self.tokenizer(
      question,
      context,
      max_length=self.max_length,
      truncation="only_second",
      return_offsets_mapping=True,
      padding="max_length"
    )

    offset_mapping = tokenized_example["offset_mapping"]
    input_ids = tokenized_example["input_ids"]
    attention_mask = tokenized_example["attention_mask"]
    sequence_ids = tokenized_example.sequence_ids()
    cls_index = input_ids.index(self.tokenizer.cls_token_id)

    # Locate the first and last token that belong to the context (sequence 1).
    context_start_idx = 0
    while sequence_ids[context_start_idx] != 1 :
      context_start_idx += 1
    context_end_idx = len(offset_mapping) - 1
    while sequence_ids[context_end_idx] != 1 :
      context_end_idx -= 1

    start_token,end_token = cls_index,cls_index
    answer_in_window = (
      offset_mapping[context_start_idx][0] <= answer_start_char
      and offset_mapping[context_end_idx][1] >= answer_end_char
    )
    if answer_in_window :
      found_start,found_end = None,None
      for token_idx in range(context_start_idx,context_end_idx + 1) :
        token_start_char,token_end_char = offset_mapping[token_idx]
        if token_start_char == answer_start_char :
          found_start = token_idx
        if token_end_char == answer_end_char :
          found_end = token_idx
      # Keep the span only when both of its boundaries were matched.
      if found_start is not None and found_end is not None :
        start_token,end_token = found_start,found_end

    return {"Question" : question,
            "Context" : context,
            "Answer" : answer,
            "Input_IDs" : torch.tensor(input_ids),
            "Context_start_index" : context_start_idx,
            "Context_end_index" : context_end_idx,
            "Start_token" : start_token,
            "End_token" : end_token,
            "Offset_mapping" : torch.tensor(offset_mapping),
            "Attention_mask" : torch.tensor(attention_mask)
            }

  def __len__(self) :
    """Return the number of records in the wrapped dataset."""
    return len(self.dset)

In [None]:
# Wrap both splits with identical tokenization settings.
BERTSQuAD_train, BERTSQuAD_valid = (
    BERTSQuADDataset(split, tokenizer, max_length=512, doc_stride=128)
    for split in (train_ds, valid_ds)
)

In [None]:
# Batches of 8 examples; both loaders are shuffled.
BERTSQuAD_train_loader, BERTSQuAD_valid_loader = (
    DataLoader(split, batch_size=8, shuffle=True)
    for split in (BERTSQuAD_train, BERTSQuAD_valid)
)

### **Model Training** 

In [None]:
# `device` was used below (and by the training cells) without ever being
# assigned in this notebook; pick the GPU when one is visible, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its device first so the optimizer is built over the
# parameters that will actually be trained.
myPreTrainedQAModel = DistilBertForQuestionAnswering.from_pretrained(model_checkpoint).to(device)
myPreTrainedOptimizer = torch.optim.AdamW(myPreTrainedQAModel.parameters(), lr = 5e-5)

In [None]:
# exist_ok lets this cell be re-run without raising FileExistsError.
os.makedirs("PreTrained_QA_Model", exist_ok=True)

In [None]:
def _mean_validation_loss(model, valid_loader, device):
    """Return the mean batch loss of `model` over `valid_loader` without gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data in valid_loader:
            model_outputs = model(
                input_ids=data["Input_IDs"].to(device),
                attention_mask=data["Attention_mask"].to(device),
                start_positions=data["Start_token"].to(device),
                end_positions=data["End_token"].to(device),
            )
            total_loss += model_outputs[0].item()
    return total_loss / len(valid_loader)


def DistilPreTrainFT(model,train_loader,valid_loader,optimizer,num_epochs = 3,save_freq = 1,model_name = "DistilBERT_SQuAD",epoch_offset = 0,device = None,save_location = "PreTrained_QA_Model"):
    """Fine-tune a question-answering model, reporting losses and saving checkpoints.

    Args:
        model: model called with `input_ids`, `attention_mask`,
            `start_positions` and `end_positions`; the first element of its
            output is used as the loss.
        train_loader: iterable of batches with "Input_IDs", "Attention_mask",
            "Start_token" and "End_token" tensors.
        valid_loader: loader with the same batch layout, or None to skip
            validation. Validation is best-effort: a failure is reported and
            training continues.
        optimizer: optimizer over `model`'s parameters.
        num_epochs: number of epochs to run in this call.
        save_freq: a checkpoint is written whenever the zero-based epoch
            index of this call is a multiple of `save_freq`.
        model_name: file-name prefix of the fallback `.pt` checkpoint.
        epoch_offset: epochs already completed in earlier calls; added to
            the epoch number in logs and checkpoint names so that a resumed
            run does not overwrite earlier checkpoints.
        device: device the batches are moved to. Defaults to the device of
            the model's first parameter.
        save_location: directory that receives the `save_pretrained` output.

    Returns:
        None.
    """
    if device is None:
        # Follow the model instead of depending on a global `device` name
        # existing when this function is defined.
        device = next(model.parameters()).device

    for epoch in range(num_epochs):
        epoch_number = epoch + epoch_offset + 1
        model.train()
        running_loss = 0.0
        bar = tqdm(enumerate(train_loader), total=len(train_loader))

        for batch_idx, data in bar:
            optimizer.zero_grad()

            model_outputs = model(
                input_ids=data["Input_IDs"].to(device),
                attention_mask=data["Attention_mask"].to(device),
                start_positions=data["Start_token"].to(device),
                end_positions=data["End_token"].to(device),
            )
            batch_loss = model_outputs[0]

            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()
            bar.set_description(str({'epoch': epoch_number,
                                     'Running loss': round(running_loss / (batch_idx + 1), 4),
                                     }))

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        epoch_training_loss = running_loss / max(len(train_loader), 1)

        epoch_valid_loss = None
        if valid_loader is not None:
            try:
                epoch_valid_loss = _mean_validation_loss(model, valid_loader, device)
            except Exception as err:
                # Best-effort: say why validation was skipped, keep training.
                print(f"Validation failed for epoch {epoch_number}: {err!r}")

        if epoch_valid_loss is not None:
            print(f"Epoch {epoch_number}, Epoch_training_loss : {epoch_training_loss}, Epoch_valid_loss : {epoch_valid_loss}")
        else:
            print(f"Epoch {epoch_number}, Epoch_training_loss : {epoch_training_loss} ")

        if epoch % save_freq == 0:
            try:
                model.save_pretrained(os.path.join(save_location, "PreTrained_Model" + str(epoch_number)))
            except Exception as err:
                # Fall back to a plain state-dict checkpoint. EM / F1 are not
                # computed during training, so they are stored as None.
                print(f"save_pretrained failed ({err!r}); writing a state-dict checkpoint instead")
                torch.save({"params": model.state_dict(),
                            "Epoch_loss": epoch_training_loss,
                            "Epoch_valid_loss": epoch_valid_loss,
                            "Epoch_em": None,
                            "Epoch_F1": None},
                           model_name + 'epoch' + str(epoch + epoch_offset) + '.pt')


In [None]:
# Run three epochs of fine-tuning, checkpointing after every epoch.
DistilPreTrainFT(
    model=myPreTrainedQAModel,
    train_loader=BERTSQuAD_train_loader,
    valid_loader=BERTSQuAD_valid_loader,
    optimizer=myPreTrainedOptimizer,
    num_epochs=3,
    save_freq=1,
    model_name="DistilBERT_SQuAD",
    epoch_offset=0,
    device=device,
)

### **Evaluating Model performance**

In [None]:
def normalize_text(s):
    """Return `s` lower-cased, without ASCII punctuation or the articles
    a/an/the, and with every run of whitespace collapsed to a single space."""
    import re
    import string

    lowered = s.lower()
    # Deleting via a translation table drops every string.punctuation character.
    without_punctuation = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation, flags=re.UNICODE)
    return " ".join(without_articles.split())


# Takes 2 lists of strings and compares corresponding elements to check if they are exact matches. 
def exact_match(preds,answer) :
  """Return one bool per prediction: True when the prediction and the answer
  at the same index are equal after `normalize_text`."""
  return [
    normalize_text(pred) == normalize_text(answer[i])
    for i,pred in enumerate(preds)
  ]

# Takes 2 lists of strings and calculates the F1 scores between corresponding elements in both lists
def f1_score(preds,answer) :
  """Return the token-level F1 score of each prediction against its answer.

  Both strings are passed through `normalize_text` and split on whitespace.
  The overlap is counted as a multiset, so a word that occurs twice in both
  strings counts twice; a plain set intersection would under-count repeated
  words while precision and recall still divide by the full token counts.

  Args:
    preds: list of predicted answer strings.
    answer: list of reference answer strings, aligned with `preds`.

  Returns:
    A list with one score per prediction; 0 when no token is shared
    (including when either side is empty after normalization).
  """
  from collections import Counter

  f1_scores = []
  for i,pred in enumerate(preds) :
    pred_words = normalize_text(pred).split()
    answer_words = normalize_text(answer[i]).split()
    num_shared = sum((Counter(pred_words) & Counter(answer_words)).values())
    if num_shared == 0 :
      # Also covers empty inputs, so no division by zero can occur below.
      f1_scores.append(0)
      continue
    precision = num_shared / len(pred_words)
    recall = num_shared / len(answer_words)
    f1_scores.append(2 * (precision * recall) / (precision + recall))
  return f1_scores

In [None]:
# Getting the predicted answers for a given batch
def getBatchPreds(example_batch,batch_start_probs,batch_end_probs,batch_size = None) :
  """Return the highest-scoring answer string for every example of a batch.

  For each example, every (start, end) token pair with start <= end inside
  the context window is scored as start_prob * end_prob; the best pair is
  mapped back to characters through the offset mapping and sliced out of
  the context string. On ties the earliest pair wins.

  Args:
    example_batch: batch with "Context", "Offset_mapping",
      "Context_start_index" and "Context_end_index" entries.
    batch_start_probs: per-example sequence of start scores per token.
    batch_end_probs: per-example sequence of end scores per token.
    batch_size: optional upper bound on how many examples to decode. The
      number of contexts in the batch is always respected, so a short
      final batch no longer raises IndexError.

  Returns:
    A list of predicted answer strings, one per decoded example.
  """
  def as_floats(row) :
    # One bulk conversion keeps the pair search below on plain Python
    # floats instead of indexing a tensor on every iteration.
    return row.tolist() if hasattr(row,"tolist") else [float(p) for p in row]

  contexts = example_batch["Context"]
  num_examples = len(contexts) if batch_size is None else min(batch_size,len(contexts))

  pred_answers = []
  for i in range(num_examples) :
    start_probs = as_floats(batch_start_probs[i])
    end_probs = as_floats(batch_end_probs[i])
    context_start = int(example_batch["Context_start_index"][i])
    context_end = int(example_batch["Context_end_index"][i])
    offset_maps = example_batch["Offset_mapping"][i]

    best_start,best_end = context_start,context_start
    best_prob = start_probs[context_start] * end_probs[context_start]
    for j in range(context_start,context_end + 1) :
      start_prob = start_probs[j]
      for k in range(j,context_end + 1) :
        current_prob = start_prob * end_probs[k]
        if current_prob > best_prob :
          best_start,best_end,best_prob = j,k,current_prob

    start_char = int(offset_maps[best_start][0])
    end_char = int(offset_maps[best_end][1])
    pred_answers.append(contexts[i][start_char:end_char])

  return pred_answers

In [None]:
# Calculating overall exact match and F1 score of the entire model on the given dataloader
def getAccuracies(loader,model) :
  """Return the (exact match, F1) averages of `model` over every example in `loader`.

  Scores are accumulated per example rather than per batch, so a shorter
  final batch does not get the same weight as a full one. The real batch
  length is passed to `getBatchPreds`, so batches of any size are decoded
  completely.

  Args:
    loader: loader yielding batches that carry an "Answer" list plus the
      fields read by `getOutputs` and `getBatchPreds`.
    model: model forwarded to `getOutputs`.

  Returns:
    Tuple (average exact match, average F1); (0.0, 0.0) for an empty loader.
  """
  total_em,total_f1,num_examples = 0.0,0.0,0
  bar = tqdm(loader, total=len(loader))
  for batch in bar :
    batch_start_probs,batch_end_probs = getOutputs(batch,model)
    actual_answers = batch["Answer"]
    pred_answers = getBatchPreds(batch,batch_start_probs,batch_end_probs,batch_size = len(actual_answers))
    total_em += sum(exact_match(pred_answers,actual_answers))
    total_f1 += sum(f1_score(pred_answers,actual_answers))
    num_examples += len(pred_answers)
    seen = max(num_examples,1)
    bar.set_description(str({'running_em' : total_em/seen,
                             'running_f1' : total_f1/seen}))
  if num_examples == 0 :
    return 0.0,0.0
  return total_em/num_examples,total_f1/num_examples

In [None]:
# Score the fine-tuned model on the training split, then on the validation split.
train_em, train_f1 = getAccuracies(loader=BERTSQuAD_train_loader, model=myPreTrainedQAModel)
valid_em, valid_f1 = getAccuracies(loader=BERTSQuAD_valid_loader, model=myPreTrainedQAModel)