
<h1 align="center" style="color:green;font-size: 3em;" >Building a Question Answering System Using BERT</h1>




<a class="anchor" id="section2"></a>
<h2 style="color:green;font-size: 2em;">Simple Inference Pipeline on Pretrained Model</h2>

In [27]:
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import warnings
warnings.simplefilter("ignore")

weight_path = "kaporter/bert-base-uncased-finetuned-squad"
# loading tokenizer
tokenizer = BertTokenizer.from_pretrained(weight_path)
#loading the model
model = BertForQuestionAnswering.from_pretrained(weight_path)

Let's take an example.

```
question = "How many parameters does BERT-large have?"

context = "BERT-large is really big... it has 24-layers and an embedding size of 1,024, for a total of 340M parameters! Altogether it is 1.34GB, so expect it to take a couple minutes to download to your Colab instance."

answer = 340M

```



In [28]:
question = "How many parameters does BERT-large have?"
context = "BERT-large is really big... it has 24-layers and an embedding size of 1,024, for a total of 340M parameters! Altogether it is 1.34GB, so expect it to take a couple minutes to download to your Colab instance."

input_ids = tokenizer.encode(question, context)
print (f'We have about {len(input_ids)} tokens generated')

tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(" ")
print('Some examples of token-input_id pairs:')

for i, (token,inp_id) in enumerate(zip(tokens,input_ids)):
    print(token,":",inp_id)


We have about 70 tokens generated
 
Some examples of token-input_id pairs:
[CLS] : 101
how : 2129
many : 2116
parameters : 11709
does : 2515
bert : 14324
- : 1011
large : 2312
have : 2031
? : 1029
[SEP] : 102
bert : 14324
- : 1011
large : 2312
is : 2003
really : 2428
big : 2502
. : 1012
. : 1012
. : 1012
it : 2009
has : 2038
24 : 2484
- : 1011
layers : 9014
and : 1998
an : 2019
em : 7861
##bed : 8270
##ding : 4667
size : 2946
of : 1997
1 : 1015
, : 1010
02 : 6185
##4 : 2549
, : 1010
for : 2005
a : 1037
total : 2561
of : 1997
340 : 16029
##m : 2213
parameters : 11709
! : 999
altogether : 10462
it : 2009
is : 2003
1 : 1015
. : 1012
34 : 4090
##gb : 18259
, : 1010
so : 2061
expect : 5987
it : 2009
to : 2000
take : 2202
a : 1037
couple : 3232
minutes : 2781
to : 2000
download : 8816
to : 2000
your : 2115
cola : 15270
##b : 2497
instance : 6013
. : 1012
[SEP] : 102


In [29]:
sep_idx = tokens.index('[SEP]')

# segment ids: 0 for the question tokens up to and including the first [SEP], 1 for the rest (the context)
token_type_ids = [0 for i in range(sep_idx+1)] + [1 for i in range(sep_idx+1,len(tokens))]
print(token_type_ids)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


Now let's pass the input through the model and look at the output.

In [30]:
# Run our example through the model.
out = model(torch.tensor([input_ids]), # The tokens representing our input text.
                token_type_ids=torch.tensor([token_type_ids]))

start_logits,end_logits = out['start_logits'],out['end_logits']
# Find the tokens with the highest `start` and `end` scores.
answer_start = torch.argmax(start_logits)
answer_end = torch.argmax(end_logits)

ans = ''.join(tokens[answer_start:answer_end])
print('Predicted answer:', ans)

Predicted answer: 340
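
The prediction prints `340` rather than `340M`, most likely because the slice `tokens[answer_start:answer_end]` stops one token before `answer_end`, so the final `##m` word piece is dropped. Below is a minimal sketch of a helper that includes the end token and merges WordPiece continuations. The name `join_wordpieces` is hypothetical and not part of `transformers`, and the toy token list only mimics the tokenizer output shown earlier.

```python
def join_wordpieces(tokens, start, end):
    """Join tokens[start..end] (end inclusive), merging '##' continuation pieces."""
    answer = tokens[start]
    for tok in tokens[start + 1 : end + 1]:
        if tok.startswith("##"):
            answer += tok[2:]       # continuation piece: glue to the previous token
        else:
            answer += " " + tok     # new word: separate with a space
    return answer


# Toy token list mirroring part of the tokenizer output above.
toy_tokens = ["total", "of", "340", "##m", "parameters"]
print(join_wordpieces(toy_tokens, 2, 3))  # 340m
```

In the cell above it could be called as `join_wordpieces(tokens, int(answer_start), int(answer_end))`. The result is lowercase because the checkpoint is uncased.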


In [31]:
del model
del tokenizer

We have seen how to predict with a fine-tuned BERT (bert-base) model. Now let's train a model on the SQuAD dataset.

In [32]:
import transformers
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
import warnings
warnings.simplefilter("ignore")

**Loading dataset**

In [33]:
!pip install datasets



In [34]:
from datasets import load_dataset
dataset = load_dataset("rajpurkar/squad")
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

The train split has 87,599 examples and the validation split has 10,570.
Let's look at some samples.

In [35]:
# to make text bold
s_bold = '\033[1m'
e_bold = '\033[0;0m'

print(s_bold + 'Train Data Sample.....' + e_bold)
train_data = dataset["train"]
for data in train_data:
    print(' ')
    print(s_bold + 'ID -' + e_bold, data['id'])
    print(s_bold +'TITLE - '+ e_bold, data['title'])
    print(s_bold + 'CONTEXT - '+ e_bold,data['context'])
    print(s_bold + 'ANSWERS - ' + e_bold,data['answers']['text'])
    print(s_bold + 'ANSWERS START INDEX - ' + e_bold,data['answers']['answer_start'])
    print(' ')
    break

print('---'*30)
print(s_bold + 'Validation Data Sample.....' + e_bold)
val_data = dataset["validation"]
for data in val_data:
    print(' ')
    print(s_bold + 'ID -' + e_bold, data['id'])
    print(s_bold +'TITLE - '+ e_bold, data['title'])
    print(s_bold + 'CONTEXT - '+ e_bold,data['context'])
    print(s_bold + 'ANSWERS - ' + e_bold,data['answers']['text'])
    print(s_bold + 'ANSWERS START INDEX - ' + e_bold,data['answers']['answer_start'])
    print(' ')
    break



Train Data Sample.....
 
ID - 5733be284776f41900661182
TITLE -  University_of_Notre_Dame
CONTEXT -  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
ANSWERS -  ['Saint Bernadette Soubirous']
ANSWERS START INDEX -  [515]
 
-----------------------------------------------------------------------

Another important observation is that a validation sample can have multiple answers. Let's analyze this further.

In [36]:
dataset["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

In [37]:
dataset["validation"].filter(lambda x: len(x["answers"]["text"]) != 1)

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 10567
})

In the train split every sample has exactly one answer, but in the validation split 10,567 samples have multiple answers.

Before getting to the problem, let us understand how a question answering problem is solved.

In [38]:
# Take a subset of the data to reduce training time.
dataset["train"] = dataset["train"].select(range(8000))
dataset["validation"] = dataset["validation"].select(range(2000))
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 8000
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 2000
    })
})

**Data Preprocessing**



In [39]:
from transformers import AutoTokenizer

# model_checkpoint = "bert-base-cased"
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

trained_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)

context = dataset["train"][0]["context"]
question = dataset["train"][0]["question"]
answer = dataset["train"][0]["answers"]["text"]


inputs = tokenizer(
    question,
    context,
    max_length=160,
    truncation="only_second",  # truncate only the context, never the question
    stride=70,  # number of overlapping tokens between consecutive context pieces
    return_overflowing_tokens=True,  # return the overflowing context pieces as extra features
)


print(f"The single example gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

print('Question: ',question)
print(' ')
print('Context : ',context)
print(' ')
print('Answer: ', answer)
print('--'*25)

for i,ids in enumerate(inputs["input_ids"]):
    print('Context piece', i+1)
    # decode from the first [SEP] onward, i.e. only the context part
    print(tokenizer.decode(ids[ids.index(tokenizer.sep_token_id):]))
    print(' ')



The single example gave 2 features.
Here is where each comes from: [0, 0].
Question:  To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
 
Context :  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
 
Answer:  ['Saint Bernadette Soubirous']
--------------------------------------------------
Context piece 1
[SEP] architecturally, the s
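
The way `stride` produces overlapping context pieces can be illustrated without the tokenizer. The sketch below is a simplified stand-in, not the tokenizer's actual implementation: `sliding_windows` is a hypothetical helper, and `window` plays the role of the room left for context tokens once the question and special tokens are subtracted from `max_length`.

```python
def sliding_windows(context_tokens, window, stride):
    """Split context_tokens into chunks of at most `window` tokens,
    where consecutive chunks share `stride` tokens."""
    if not 0 <= stride < window:
        raise ValueError("stride must be non-negative and smaller than window")
    step = window - stride              # how far each new chunk advances
    chunks = []
    start = 0
    while True:
        chunks.append(context_tokens[start : start + window])
        if start + window >= len(context_tokens):
            break                       # this chunk reached the end of the context
        start += step
    return chunks


toy_context = list(range(10))           # stand-in for 10 context tokens
for chunk in sliding_windows(toy_context, window=6, stride=2):
    print(chunk)
# [0, 1, 2, 3, 4, 5]
# [4, 5, 6, 7, 8, 9]
```

The overlap matters because an answer that straddles a chunk boundary would otherwise never appear whole in any single chunk.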

In [40]:
from transformers import AutoTokenizer

del tokenizer
trained_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)

def train_data_preprocess(examples):

    """
    Tokenize question/context pairs and compute the start and end token
    positions of the answer within each context chunk.
    """

    def find_context_start_end_index(sequence_ids):
        """
        Return the token indexes at which the context starts and ends.
        """
        token_idx = 0
        while sequence_ids[token_idx] != 1:  # special tokens or question tokens
            token_idx += 1                   # loop ends at the first context token
        context_start_idx = token_idx

        while sequence_ids[token_idx] == 1:
            token_idx += 1
        context_end_idx = token_idx - 1
        return context_start_idx,context_end_idx


    questions = [q.strip() for q in examples["question"]]
    context = examples["context"]
    answers = examples["answers"]

    inputs = tokenizer(
        questions,
        context,
        max_length=512,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,  #returns id of base context
        return_offsets_mapping=True,  # returns (start_index,end_index) of each token
        padding="max_length"
    )


    start_positions = []
    end_positions = []


    for i,mapping_idx_pairs in enumerate(inputs['offset_mapping']):
        context_idx = inputs['overflow_to_sample_mapping'][i]

        # from main context
        answer = answers[context_idx]
        answer_start_char_idx = answer['answer_start'][0]
        answer_end_char_idx = answer_start_char_idx + len(answer['text'][0])


        # now we have to find it in sub contexts
        tokens = inputs['input_ids'][i]
        sequence_ids = inputs.sequence_ids(i)

        # finding the context start and end indexes wrt sub context tokens
        context_start_idx,context_end_idx = find_context_start_end_index(sequence_ids)

        # character start and end of this context chunk within the full context text
        context_start_char_index = mapping_idx_pairs[context_start_idx][0]
        context_end_char_index = mapping_idx_pairs[context_end_idx][1]


        #If the answer is not fully inside the context, label is (0, 0)
        if (context_start_char_index > answer_start_char_idx) or (
            context_end_char_index < answer_end_char_idx):
            start_positions.append(0)
            end_positions.append(0)

        else:

            # else its start and end token positions
            # here idx indicates index of token
            idx = context_start_idx
            while idx <= context_end_idx and mapping_idx_pairs[idx][0] <= answer_start_char_idx:
                idx += 1
            start_positions.append(idx - 1)


            idx = context_end_idx
            while idx >= context_start_idx and mapping_idx_pairs[idx][1] > answer_end_char_idx:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

train_sample = dataset["train"].select([i for i in range(200)])

train_dataset = train_sample.map(
    train_data_preprocess,
    batched=True,
    remove_columns=dataset["train"].column_names
)

len(dataset["train"]),len(train_dataset)

(8000, 200)
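
The core of `train_data_preprocess` is the conversion of a character span into a token span using the offset mapping. A minimal sketch of that step on toy data follows. The helper name `char_span_to_token_span` and the whitespace-style offsets are assumptions for illustration only; real offsets come from the tokenizer's `offset_mapping`.

```python
def char_span_to_token_span(offsets, answer_start, answer_end):
    """Return (first_token, last_token), both inclusive, covering
    the character span [answer_start, answer_end)."""
    # first token whose end lies beyond the answer's first character
    first = next(i for i, (s, e) in enumerate(offsets) if e > answer_start)
    # last token that starts before the answer's end
    last = max(i for i, (s, e) in enumerate(offsets) if s < answer_end)
    return first, last


text = "appeared to Saint Bernadette in 1858"
# Hypothetical whitespace tokenization with (start, end) character offsets.
offsets = [(0, 8), (9, 11), (12, 17), (18, 28), (29, 31), (32, 36)]
start_char = text.index("Saint")
end_char = start_char + len("Saint Bernadette")
print(char_span_to_token_span(offsets, start_char, end_char))  # (2, 3)
```

Because the end position is inclusive, recovering the answer tokens requires a slice of the form `[start : end + 1]`.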

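Note that the first number above is the size of the full training subset (8000), not of the 200-example sample that was tokenized. With `max_length=512` each of the 200 sampled contexts fits in a single window, so tokenization produced exactly 200 features; with a smaller `max_length` or longer contexts, one example can yield several features. Let's compare the values before and after tokenization by printing some of the questions, contexts, and answers and checking them against the originals.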

In [41]:
def print_context_and_answer(idx,mini_ds=dataset["train"]):

    print(idx)
    print('----')
    question = mini_ds[idx]['question']
    context = mini_ds[idx]['context']
    answer = mini_ds[idx]['answers']['text']
    print('Theoretical values :')
    print(' ')
    print('Question: ')
    print(question)
    print(' ')
    print('Context: ')
    print(context)
    print(' ')
    print('Answer: ')
    print(answer)
    print(' ')
    answer_start_char_idx = mini_ds[idx]['answers']['answer_start'][0]
    answer_end_char_idx = answer_start_char_idx + len(mini_ds[idx]['answers']['text'][0])
    print('Start and end index of text: ',answer_start_char_idx,answer_end_char_idx)
    print('----'*20)
    print('Values after tokenization:')


    #answer
    sep_tok_index = train_dataset[idx]['input_ids'].index(102) #get index for [SEP]
    question_ = train_dataset[idx]['input_ids'][:sep_tok_index+1]
    question_decoded = tokenizer.decode(question_)
    context_ = train_dataset[idx]['input_ids'][sep_tok_index+1:]
    context_decoded = tokenizer.decode(context_)
    start_idx = train_dataset[idx]['start_positions']
    end_idx = train_dataset[idx]['end_positions']
    answer_toks = train_dataset[idx]['input_ids'][start_idx:end_idx + 1]  # end position is inclusive
    answer_decoded = tokenizer.decode(answer_toks)
    print(' ')
    print('Question: ')
    print(question_decoded)
    print(' ')
    print('Context: ')
    print(context_decoded)
    print(' ')
    print('Answer: ')
    print(answer_decoded)
    print(' ')
    print('Start pos and end pos of tokens: ',train_dataset[idx]['start_positions'],train_dataset[idx]['end_positions'])
    print('____'*20)


print_context_and_answer(0)
print_context_and_answer(1)
print_context_and_answer(2)
print_context_and_answer(3)

0
----
Theoretical values :
 
Question: 
To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
 
Context: 
Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
 
Answer: 
['Saint Bernadette Soubirous']
 
Start and end index of text:  515 541
--------------------------------------------------------------------------------
Values after tok



<a class="anchor" id="section4"></a>
<h2 style="color:green;font-size: 2em;">Understanding Metrics needed for Evaluation</h2>

**How to evaluate the model?**

Let's take a small evaluation set. It needs little preprocessing. We will use the pretrained "distilbert-base-uncased" weights, which have not been fine-tuned for question answering, and see how the model performs.

* We set the offset to None for every token that is not part of the context (question tokens and special tokens), so that only context tokens can be selected as an answer.
* We also attach the ID of the original example (the base context ID) to each feature.


We will evaluate with this untuned DistilBERT model and look at its performance.

In [42]:
from transformers import AutoTokenizer

def preprocess_validation_examples(examples):
    """
    preprocessing validation data
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=512,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")

    base_ids = []

    for i in range(len(inputs["input_ids"])):

        # take the base id (ie in cases of overflow happens we get base id)
        base_context_idx = sample_map[i]
        base_ids.append(examples["id"][base_context_idx])

        # sequence id indicates the input. 0 for first input and 1 for second input
        # and None for special tokens by default
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        # for Question tokens provide offset_mapping as None
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["base_id"] = base_ids
    return inputs


# del tokenizer

trained_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)

data_val_sample = dataset["validation"].select([i for i in range(100)])
eval_set = data_val_sample.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=dataset["validation"].column_names,
)
len(eval_set)

100

In [43]:
import torch
from transformers import DistilBertForQuestionAnswering

# del tokenizer
# take a small sample

eval_set_for_model = eval_set.remove_columns(["base_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

checkpoint =  "distilbert-base-uncased"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}

model = DistilBertForQuestionAnswering.from_pretrained(checkpoint).to(
    device
)


with torch.no_grad():
    outputs = model(**batch)

start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

start_logits.shape,end_logits.shape

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


((100, 512), (100, 512))

We will evaluate the model with the Evaluate library, using two metrics:

1. Exact match
2. F1 score

**Exact Match**

For each question-answer pair, EM = 1 if the characters of the model's prediction exactly match the characters of the true answer, and 0 otherwise. When assessing against a negative example (a question with no answer), any predicted text automatically scores 0 for that example.

**F1 score**

The F1 score depends on precision and recall.
```
f1 score = 2 * (precision * recall) / (precision + recall)

```

The F1 score is based on the number of words shared between the true (theoretical) answer and the predicted answer. Precision is the ratio of the number of shared words to the total number of words in the prediction, and recall is the ratio of the number of shared words to the total number of words in the ground truth.
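
To make the two definitions concrete, here is a minimal sketch of both metrics for a single prediction and a single reference. It assumes SQuAD-style normalization (lowercasing, removing punctuation and the articles a/an/the). The function names are illustrative, and this is not the code that the `evaluate` library runs.

```python
import collections
import re
import string


def normalize(text):
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    return int(normalize(prediction) == normalize(truth))


def f1_score(prediction, truth):
    pred_tokens = normalize(prediction).split()
    truth_tokens = normalize(truth).split()
    # multiset intersection counts each shared word at most as often as it occurs in both
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    num_shared = sum(common.values())
    if num_shared == 0:
        return 0.0
    precision = num_shared / len(pred_tokens)
    recall = num_shared / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("340M", "340m"))                     # 1
print(round(f1_score("340M parameters", "340M"), 3))   # 0.667
```

In the second call one word is shared, so precision is 1/2 and recall is 1, giving F1 = 2 * 0.5 * 1 / 1.5. The `squad` metric used below reports both scores as percentages and, when an example has several reference answers, keeps the best score over them.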

In [44]:
!pip install evaluate



In [45]:
import numpy as np
import collections
import evaluate

def predict_answers_and_evaluate(start_logits,end_logits,eval_set,examples):
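    """
    Turn start/end logits into text predictions and score them.
    Args:
    start_logits: start-position prediction logits, one row per feature
    end_logits: end-position prediction logits, one row per feature
    eval_set: processed validation data (with base_id and offset_mapping)
    examples: unprocessed validation data with the context text
    """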
    # appending all id's corresponding to the base context id
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(eval_set):
        example_to_features[feature["base_id"]].append(idx)

    n_best = 20
    max_answer_length = 30
    predicted_answers = []

    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # looping through each sub contexts corresponding to a context and finding
        # answers
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = eval_set["offset_mapping"][feature_index]

            # indexes of the n_best highest-scoring start tokens and end tokens,
            # best first
            start_indexes = np.argsort(start_logit).tolist()[::-1][:n_best]
            end_indexes = np.argsort(end_logit).tolist()[::-1][:n_best]


            for start_index in start_indexes:
                for end_index in end_indexes:

                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length.
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                       ):
                        continue

                    answers.append({
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                        })


        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    metric = evaluate.load("squad")

    theoretical_answers = [
            {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]

    metric_ = metric.compute(predictions=predicted_answers, references=theoretical_answers)
    return predicted_answers,metric_
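
The span search inside `predict_answers_and_evaluate` can be seen in isolation on toy logits. The sketch below is a simplified pure-Python version under the same validity rules (end not before start, length at most `max_answer_length`). It omits the offset check, and `top_indices` and `best_span` are hypothetical helper names.

```python
def top_indices(scores, n):
    """Indexes of the n highest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]


def best_span(start_logit, end_logit, n_best=20, max_answer_length=30):
    """Return (score, start, end) of the highest-scoring valid span, or None."""
    best = None
    for s in top_indices(start_logit, n_best):
        for e in top_indices(end_logit, n_best):
            # skip spans that end before they start or are too long
            if e < s or e - s + 1 > max_answer_length:
                continue
            score = start_logit[s] + end_logit[e]
            if best is None or score > best[0]:
                best = (score, s, e)
    return best


start_logit = [0.1, 2.0, 0.3, 0.2]
end_logit = [3.0, 0.1, 1.5, 0.4]
print(best_span(start_logit, end_logit))  # (3.5, 1, 2)
```

Taking the two argmax positions independently would give start 1 and end 0 here, which is not a valid span. Scoring candidate pairs jointly avoids that.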



Let us evaluate the model. The metric expects the predicted answers as a list of dictionaries with one key for the ID of the example and one key for the predicted text, and the theoretical answers as a list of dictionaries with one key for the ID and one key for the reference answers.

In [46]:
pred_answers,metrics_ = predict_answers_and_evaluate(start_logits,end_logits,eval_set,data_val_sample)
metrics_

{'exact_match': 2.0, 'f1': 5.878787878787878}

The scores are very poor, as expected for a model whose question answering head is randomly initialized. Now we will fine-tune the model and check its performance on the validation set.


<a class="anchor" id="section5"></a>
<h2 style="color:green;font-size: 2em;">Training a Question Answering System based on BERT</h2>

Let's load the dataset again from scratch. We will sample a small portion of it for training. You can train on the full data if you have enough resources.

In [47]:
from datasets import load_dataset
dataset = load_dataset("rajpurkar/squad")

#lets sample a small dataset
dataset['train'] = dataset['train'].select([i for i in range(5000)])
dataset['validation'] = dataset['validation'].select([i for i in range(500)])

dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 5000
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 500
    })
})

Let us define a PyTorch dataset.

In [48]:
from torch.utils.data import DataLoader, Dataset


class DataQA(Dataset):
    def __init__(self, dataset,mode="train"):
        self.mode = mode


        if self.mode == "train":
            # sampling
            self.dataset = dataset["train"]
            self.data = self.dataset.map(train_data_preprocess,
                                                      batched=True,
                            remove_columns= dataset["train"].column_names)

        else:
            self.dataset = dataset["validation"]
            self.data = self.dataset.map(preprocess_validation_examples,
            batched=True,remove_columns = dataset["validation"].column_names,
               )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        out = {}
        example = self.data[idx]
        out['input_ids'] = torch.tensor(example['input_ids'])
        out['attention_mask'] = torch.tensor(example['attention_mask'])


        if self.mode == "train":

            out['start_positions'] = torch.unsqueeze(torch.tensor(example['start_positions']),dim=0)
            out['end_positions'] = torch.unsqueeze(torch.tensor(example['end_positions']),dim=0)

        return out


In [49]:
trained_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)


train_dataset = DataQA(dataset,mode="train")
val_dataset = DataQA(dataset,mode="validation")



for i,d in enumerate(train_dataset):
    for k in d.keys():
        print(k + ' : ', d[k].shape)
    print('--'*40)

    if i == 3:
        break

print('__'*50)

for i,d in enumerate(val_dataset):
    for k in d.keys():
        print(k + ' : ', len(d[k]))
    print('--'*40)

    if i == 3:
        break


input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
____________________________________________________________________________________________________
input_ids :  512
attention_mask :  

Let us load the data in batches

In [50]:
from transformers import default_data_collator
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=2,
)
eval_dataloader = DataLoader(
    val_dataset, collate_fn=default_data_collator, batch_size=2
)




for batch in train_dataloader:
   print(batch['input_ids'].shape)
   print(batch['attention_mask'].shape)
   print(batch['start_positions'].shape)
   print(batch['end_positions'].shape)
   break

print('---'*20)

for batch in eval_dataloader:
   print(batch['input_ids'].shape)
   print(batch['attention_mask'].shape)
   break

torch.Size([2, 512])
torch.Size([2, 512])
torch.Size([2, 1])
torch.Size([2, 1])
------------------------------------------------------------
torch.Size([2, 512])
torch.Size([2, 512])


**Define Model**

In [51]:
from transformers import DistilBertForQuestionAnswering
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Available device: {device}')

checkpoint =  "distilbert-base-uncased"
model = DistilBertForQuestionAnswering.from_pretrained(checkpoint)
model = model.to(device)

Available device: cuda


Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


**Model Training**

In [52]:
from torch.optim import AdamW  # transformers.AdamW is deprecated; note torch's default weight_decay is 0.01
from tqdm.notebook import tqdm
import datetime
import numpy as np
import collections
import evaluate

optimizer = AdamW(model.parameters(), lr=2e-5)

epochs = 2

# Total number of training steps is [number of batches] x [number of epochs].
# (Note that this is not the same as the number of training samples).
total_steps = len(train_dataloader) * epochs
print(total_steps)


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))

    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))





5010


We need processed validation data at the time of evaluation to get offsets for each context

In [53]:
# we need processed validation data to get offsets at the time of evaluation
validation_processed_dataset = dataset["validation"].map(preprocess_validation_examples,
            batched=True,remove_columns = dataset["validation"].column_names,
               )

In [54]:
import random,time
import numpy as np

# to reproduce results
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)


#storing all training and validation stats
stats = []


#to measure total training time
total_train_time_start = time.time()

for epoch in range(epochs):
    print(' ')
    print(f'=====Epoch {epoch + 1}=====')
    print('Training....')

    # ===============================
    #    Train
    # ===============================
    # measure how long training epoch takes
    t0 = time.time()

    training_loss = 0
    # loop through train data
    model.train()
    for step,batch in enumerate(train_dataloader):

        # print the elapsed training time every 40 batches
        if step%40 == 0 and not step == 0:
              elapsed_time = format_time(time.time() - t0)
              # Report progress.
              print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed_time))



        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)



        #set gradients to zero
        model.zero_grad()

        result = model(input_ids = input_ids,
                        attention_mask = attention_mask,
                        start_positions = start_positions,
                        end_positions = end_positions,
                        return_dict=True)

        loss = result.loss

        #accumulate the loss over batches so that we can calculate avg loss at the end
        training_loss += loss.item()

        # perform backpropagation
        loss.backward()

        # update the model weights
        optimizer.step()

    # calculate avg loss
    avg_train_loss = training_loss/len(train_dataloader)

    # calculates training time
    training_time = format_time(time.time() - t0)


    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))


    # ===============================
    #    Validation
    # ===============================

    print("")
    print("Running Validation...")

    t0 = time.time()

    # Put the model in evaluation mode--the dropout layers behave differently
    # during evaluation.
    model.eval()


    start_logits,end_logits = [],[]
    for step,batch in enumerate(eval_dataloader):


        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)


        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():
             result = model(input_ids = input_ids,
                        attention_mask = attention_mask,return_dict=True)



        start_logits.append(result.start_logits.cpu().numpy())
        end_logits.append(result.end_logits.cpu().numpy())


    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    # start_logits = start_logits[: len(val_dataset)]
    # end_logits = end_logits[: len(val_dataset)]




    # calculating metrics
    answers,metrics_ = predict_answers_and_evaluate(start_logits,end_logits,validation_processed_dataset,dataset["validation"])
    print(f'Exact match: {metrics_["exact_match"]}, F1 score: {metrics_["f1"]}')


    print('')
    # Measure how long the validation run took.
    validation_time = format_time(time.time() - t0)

    print("  Validation took: {:}".format(validation_time))

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_train_time_start)))


 
=====Epoch 1=====
Training....
  Batch    40  of  2,505.    Elapsed: 0:00:04.
  Batch    80  of  2,505.    Elapsed: 0:00:09.
  Batch   120  of  2,505.    Elapsed: 0:00:13.
  Batch   160  of  2,505.    Elapsed: 0:00:18.
  Batch   200  of  2,505.    Elapsed: 0:00:22.
  Batch   240  of  2,505.    Elapsed: 0:00:27.
  Batch   280  of  2,505.    Elapsed: 0:00:31.
  Batch   320  of  2,505.    Elapsed: 0:00:36.
  Batch   360  of  2,505.    Elapsed: 0:00:40.
  Batch   400  of  2,505.    Elapsed: 0:00:44.
  Batch   440  of  2,505.    Elapsed: 0:00:49.
  Batch   480  of  2,505.    Elapsed: 0:00:53.
  Batch   520  of  2,505.    Elapsed: 0:00:58.
  Batch   560  of  2,505.    Elapsed: 0:01:02.
  Batch   600  of  2,505.    Elapsed: 0:01:06.
  Batch   640  of  2,505.    Elapsed: 0:01:11.
  Batch   680  of  2,505.    Elapsed: 0:01:15.
  Batch   720  of  2,505.    Elapsed: 0:01:20.
  Batch   760  of  2,505.    Elapsed: 0:01:24.
  Batch   800  of  2,505.    Elapsed: 0:01:29.
  Batch   840  of  2,505.  

**Note: Training for more epochs on the full dataset should give better results.**