### **Installing and importing libraries** 

In [1]:
!pip install transformers
!pip install datasets
!pip install tokenizers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 30.9 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 62.1 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 76.3 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.8.0-py3-none-any.whl (452 kB)
[K     |████████████████████████████████| 452 kB 3

In [2]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from nltk.tokenize import word_tokenize
import nltk
from transformers import DistilBertTokenizer, DistilBertModel,PreTrainedTokenizer,PreTrainedTokenizerFast,BertModel,AutoTokenizer,DistilBertForQuestionAnswering
import torch
from tokenizers import BertWordPieceTokenizer
from tqdm.notebook import tqdm, trange
import os
from datasets import load_dataset


### **Loading dataset and dataloaders**

In [3]:
dataset = load_dataset("cjlovering/natural-questions-short")
train_ds = dataset['train']
valid_ds = dataset['validation']
print(dataset) 
print(valid_ds[0]) 
print(valid_ds[0]['questions'][0]['input_text'])
print(valid_ds[0]["answers"][0]["span_text"])
print(valid_ds[0]["answers"][0]["span_start"])
print(valid_ds[0]['contexts']) 
#nq
#['name', 'id', 'questions', 'answers', 'has_correct_context', 'contexts']


Downloading readme:   0%|          | 0.00/28.0 [00:00<?, ?B/s]



Downloading and preparing dataset json/cjlovering--natural-questions-short to /root/.cache/huggingface/datasets/cjlovering___json/cjlovering--natural-questions-short-63df990b626b5a72/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/14.1M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/889k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/cjlovering___json/cjlovering--natural-questions-short-63df990b626b5a72/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['name', 'id', 'questions', 'answers', 'has_correct_context', 'contexts'],
        num_rows: 13933
    })
    validation: Dataset({
        features: ['name', 'id', 'questions', 'answers', 'has_correct_context', 'contexts'],
        num_rows: 871
    })
})
{'name': 'Mandalay Bay', 'id': '7811140318762480311', 'questions': [{'input_text': 'who is the owner of the mandalay bay in vegas'}], 'answers': [{'candidate_id': 0, 'input_text': 'short', 'span_end': 601, 'span_start': 576, 'span_text': 'MGM Resorts International'}], 'has_correct_context': True, 'contexts': "Mandalay Bay Location Paradise , Nevada , U.S. Address 3950 South Las Vegas Boulevard Opening date March 2 , 1999 ; 18 years ago ( March 2 , 1999 ) Theme Tropical No. of rooms 3,309 Total gaming space 135,000 sq ft ( 12,500 m ) Permanent shows Michael Jackson : One Signature attractions Mandalay Bay Convention Center Mandalay Bay Events Center Shark Reef House of Blues Mandala

In [5]:
model_checkpoint = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [6]:
class bertnqdata(Dataset) :
  def __init__(self,dset,tokenizer,max_length = 512,doc_stride = 0) :
    self.dset = dset
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.doc_stride = doc_stride

  def __getitem__(self,idx) :

    #define 
    self.name = self.dset[idx]['name'] 
    self.question = self.dset[idx]['questions'][0]['input_text']
    self.context = self.dset[idx]['contexts'] 
    self.answer = self.dset[idx]["answers"][0]["span_text"]

    #tokenize
    tokenized_example = tokenizer(
      self.question,
      self.context,
      max_length=self.max_length,
      truncation="only_second",
      return_offsets_mapping=True,      
      padding="max_length"
    )

    
    offset_mapping = tokenized_example["offset_mapping"]
    input_ids = tokenized_example["input_ids"]
    attention_mask = tokenized_example["attention_mask"]
    answer_start_char = self.dset[idx]["answers"][0]["span_start"] 
    answer_end_char = answer_start_char + len(self.answer)


    
    context_start_idx = 0
    context_end_idx = len(offset_mapping) - 1
    sequence_ids = tokenized_example.sequence_ids()
    cls_index = input_ids.index(tokenizer.cls_token_id)
    
    while(sequence_ids[context_start_idx] != 1) :
      context_start_idx += 1
    while(sequence_ids[context_end_idx] != 1) :
      context_end_idx -= 1
    
    if not (offset_mapping[context_start_idx][0] <= answer_start_char and offset_mapping[context_end_idx][1] >= answer_end_char) :
      tokenized_example["start_token"] = (cls_index)
      tokenized_example["end_token"] = (cls_index)
    
    else :
      current_token = context_start_idx
      gotStart,gotEnd = False,False

      for start_char,end_char in (offset_mapping[context_start_idx : context_end_idx  + 1]) :  
        if (start_char == answer_start_char) :
          tokenized_example["start_token"] = current_token
          gotStart = True
        if (end_char == answer_end_char) : 
          tokenized_example["end_token"] = current_token
          gotEnd = True
        current_token += 1

      if(gotStart == False) :
        tokenized_example["start_token"] = (cls_index)
      if(gotEnd == False) :
        tokenized_example["end_token"] = (cls_index)
       

    return {"Question" : self.question, 
            "Context" : self.context, 
            "Answer" : self.answer,
            "Input_IDs" : torch.tensor(tokenized_example["input_ids"]),
            "Context_start_index" : (context_start_idx),
            "Context_end_index" : (context_end_idx),
            "Start_token" : (tokenized_example["start_token"]),
            "End_token" : (tokenized_example["end_token"]),
            "Offset_mapping" : torch.tensor(tokenized_example["offset_mapping"]),
            "Attention_mask" : torch.tensor(tokenized_example["attention_mask"])
            }

  def __len__(self) :
    return len(self.dset)

In [7]:
BERTnq_train = bertnqdata(train_ds,tokenizer,max_length = 512,doc_stride = 128)
BERTnq_valid = bertnqdata(valid_ds,tokenizer,max_length = 512,doc_stride = 128)

In [8]:
#dataloader
BERTnq_train_loader = DataLoader(BERTnq_train,batch_size = 8,shuffle = True)
BERTnq_valid_loader = DataLoader(BERTnq_valid,batch_size = 8,shuffle = True)
# print(BERTnq_valid_loader)


### **Model Training** 

In [9]:
myPreTrainedQAModel = DistilBertForQuestionAnswering.from_pretrained(model_checkpoint)
myPreTrainedOptimizer = torch.optim.AdamW(myPreTrainedQAModel.parameters(), lr = 5e-5)
device =  torch.device("cuda" if torch.cuda.is_available() else "cpu")
myPreTrainedQAModel = myPreTrainedQAModel.to(device)

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

In [10]:
os.mkdir("PreTrained_QA_Model")

In [11]:
def bertPreTrain(model,train_loader,valid_loader,optimizer,num_epochs = 5,save_freq = 1,model_name = "distilbert-base-uncased",epoch_offset = 0,device = device,save_location = "PreTrained_QA_Model"):
    for epoch in range(num_epochs):
        model.train()
        running_loss, running_f1,running_em = 0.0, 0.0,0.0
        
        bar = tqdm(enumerate(train_loader), total=len(train_loader))

        for batch_idx,data in bar:
          optimizer.zero_grad()

          input_ids = data["Input_IDs"].to(device)
          attention_mask = data["Attention_mask"].to(device)
          start_positions = data["Start_token"].to(device)
          end_positions = data["End_token"].to(device)
          model_outputs = model(input_ids = input_ids,attention_mask = attention_mask,start_positions=start_positions,
                        end_positions=end_positions)
          
          #store batch loss and model outputs
          batch_loss = model_outputs[0]
          start_logits = model_outputs[1]
          end_logits = model_outputs[2]
 
          start_labels = torch.unsqueeze(data['Start_token'],1)
          end_labels = torch.unsqueeze(data['End_token'],1)
          start_labels = start_labels.to(device)
          end_labels = end_labels.to(device)  
    
      
          answers = data["Answer"]
          context_start_indices = data["Context_start_index"]
          context_end_indices = data["Context_end_index"]   
            
          #backpropogation and update gradient 
          batch_loss.backward()
          optimizer.step()
     
          
          running_loss += batch_loss.item()
        
          bar.set_description(str({'epoch':epoch+1, 
                                   'Running loss': round((running_loss)/(batch_idx + 1),4), 
                                   }))

        epoch_training_loss = (running_loss)/len(train_loader)
        running_valid_loss = 0
        try : 
          for data in valid_loader : 
            model.eval() 
            with torch.no_grad() : 
              input_ids = data["Input_IDs"].to(device)
              attention_mask = data["Attention_mask"].to(device)
              start_positions = data["Start_token"].to(device)
              end_positions = data["End_token"].to(device)
              model_outputs = model(input_ids = input_ids,attention_mask = attention_mask,start_positions=start_positions,
                        end_positions=end_positions)
              batch_loss = model_outputs[0]
     

              running_valid_loss += batch_loss.item()
          epoch_valid_loss = running_valid_loss/len(valid_loader)
          #print(num_epochs)

          print(f"Epoch {epoch + epoch_offset + 1}, Epoch_training_loss : {epoch_training_loss}, Epoch_valid_loss : {epoch_valid_loss}")
        except : 
          print(f"Epoch {epoch + epoch_offset + 1}, Epoch_training_loss : {epoch_training_loss} ")
        if(epoch%save_freq == 0) : 
          try : 
            model.save_pretrained(save_location + "/PreTrained_Model" + str(epoch + 1))
          except : 
            torch.save({"params": model.state_dict(),"Epoch_loss" : epoch_loss,"Epoch_em" : epoch_em,"Epoch_F1" : epoch_f1}, model_name + 'epoch'+str(epoch + epoch_offset)+'.pt')
          




In [12]:
bertPreTrain(model = myPreTrainedQAModel,train_loader = BERTnq_train_loader,valid_loader =BERTnq_valid_loader, optimizer = myPreTrainedOptimizer,num_epochs = 5,save_freq = 1,model_name = "distilbert-base-uncased",epoch_offset = 0,device = device)

  0%|          | 0/1742 [00:00<?, ?it/s]

Epoch 1, Epoch_training_loss : 1.9987300399318766, Epoch_valid_loss : 1.4602135079442908


  0%|          | 0/1742 [00:00<?, ?it/s]

Epoch 2, Epoch_training_loss : 1.0857554064073465, Epoch_valid_loss : 1.500034054621644


  0%|          | 0/1742 [00:00<?, ?it/s]

Epoch 3, Epoch_training_loss : 0.665429288833632, Epoch_valid_loss : 1.7446969219850839


  0%|          | 0/1742 [00:00<?, ?it/s]

Epoch 4, Epoch_training_loss : 0.4466717938061731, Epoch_valid_loss : 2.035160688905541


  0%|          | 0/1742 [00:00<?, ?it/s]

Epoch 5, Epoch_training_loss : 0.3475527171121459, Epoch_valid_loss : 1.965729272693669


### **Evaluating Model performance**

In [13]:
from collections import Counter
def normalize_text(s):
    """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
    import string, re

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
#calculate f1 precision and recal
def f1_score(preds,answer) : 
  f1_scores = []
  for i in range(len(preds)) : 
    shared_words = 0
    pred_words = normalize_text(preds[i]).split()
    answer_words = normalize_text(answer[i]).split()
    shared_words = set(pred_words) & set(answer_words)
    try : 
      precision = (len(shared_words)/len(pred_words))
    except : 
      precision = 0
    try : 
      recall = (len(shared_words)/len(answer_words))
    except :
      recall = 0
    
    if(precision == 0 or recall == 0) : 
      f1_scores.append(0)
    else : 
      f1_scores.append(2 * (precision * recall)/ (precision + recall))
    #print(pred_words,answer_words,shared_words)
      # print(f1_scores)
  return precision, recall, f1_scores





In [14]:

# Get the predicted answers for a given batch
def getBatchPreds(example_batch,batch_start_probs,batch_end_probs,batch_size = 8) :
  pred_answers = []
  context_start_indices = example_batch["Context_start_index"]
  context_end_indices = example_batch["Context_end_index"]
  for i in range(batch_size) :
    instance_start_probs,instance_end_probs = batch_start_probs[i],batch_end_probs[i]
    context_start,context_end = context_start_indices[i],context_end_indices[i]
    offset_maps = example_batch["Offset_mapping"][i]
    context = example_batch["Context"][i]
    best_start,best_end,best_prob = context_start,context_start,instance_start_probs[context_start] * instance_end_probs[context_start]
    for j in range(context_start,context_end + 1) : 
      for k in range(j,context_end + 1) : 
        current_prob = instance_start_probs[j] * instance_end_probs[k]
        if(current_prob > best_prob) : 
          best_start,best_end,best_prob = j,k,current_prob
    start_char = offset_maps[best_start][0]
    end_char = offset_maps[best_end][1]
    ans = context[start_char:end_char]
    pred_answers.append(ans)

  return pred_answers
    # print(batch_size)
    # print(i)
    # print("===============================1")
    # print(batch_start_probs.size)
    # print(batch_start_probs.size)
    # print("===============================2")  
    # print(batch_start_probs[i])
    # print(batch_end_probs[i])


In [15]:
#get model outputs for different batch
def getOutputs(batch, model):
  with torch.no_grad():
    input_ids = batch["Input_IDs"].to(device)
    attention_mask = batch["Attention_mask"].to(device)
    start_positions = batch["Start_token"].to(device)
    end_positions = batch["End_token"].to(device)
    model_outputs = model(input_ids = input_ids,attention_mask = attention_mask,start_positions=start_positions,end_positions=end_positions)
    start_logits = model_outputs[1] 
    end_logits = model_outputs[2] 
    # print('model_outputs')
    # print(model_outputs)
    # print('start_logits')
    # print(start_logits)
    # print('end_logits')
    # print(end_logits)
    batch_start_probs=torch.softmax(start_logits,dim=1)
    batch_end_probs=torch.softmax(end_logits,dim=1) 
  return batch_start_probs,batch_end_probs

In [16]:
#Calculation of the overall F1 score of the entire model on the given dataloader
def getAccuracies(loader,model) :
  avg_f1, avg_precision, avg_recall = 0, 0, 0
  bar = tqdm(enumerate(loader), total=len(loader)) 
  with torch.no_grad():
    for batch_idx,batch in bar : 
      if batch_idx != len(loader) -1: 
        batch_start_probs,batch_end_probs = getOutputs(batch,model)
        pred_answers = getBatchPreds(batch,batch_start_probs,batch_end_probs)
        actual_answers = batch["Answer"]
        precision, recall, f1 = f1_score(pred_answers,actual_answers)
        f1 = np.mean(f1)
        precision = np.mean(precision)
        recall = np.mean(recall)
        avg_f1 += f1
        avg_precision += precision
        avg_recall += recall
        bar.set_description(str({'precision' : (avg_precision/(batch_idx + 1)),'recall' : (avg_recall/(batch_idx + 1)),'f1 score' : (avg_f1/(batch_idx + 1))}))
      else:
        avg_f1 = avg_f1
        precision = precision
        recall = recall
        bar.set_description(str({'precision' : (avg_precision/(batch_idx + 1)),'recall' : (avg_recall/(batch_idx + 1)),'f1 score' : (avg_f1/(batch_idx + 1))}))
  avg_f1 = (avg_f1)/(len(loader))
  avg_precision = (avg_precision)/(len(loader))
  avg_recall = (avg_recall)/(len(loader))

  return avg_precision, avg_recall, avg_f1

In [17]:


precision, recall, valid_f1 = getAccuracies(BERTnq_valid_loader,myPreTrainedQAModel)

  0%|          | 0/109 [00:00<?, ?it/s]