# MobileBERT for Question Answering on the SQuAD dataset

### 3. Evaluating the fine-tuned model on the validation set 

In these notebooks we are going to use [MobileBERT implemented by HuggingFace](https://huggingface.co/docs/transformers/model_doc/mobilebert) for extractive question answering on [The Stanford Question Answering Dataset (SQuAD)](https://rajpurkar.github.io/SQuAD-explorer/). The data is composed of a set of questions and paragraphs that contain the answers. The model will be trained to locate the answer in the context by predicting the positions where the answer starts and ends.

In this notebook we are going to evaluate the model from a checkpoint we obtained in the fine-tuning step.

More info from HuggingFace docs:
- [Question Answering](https://huggingface.co/tasks/question-answering)
- [Glossary](https://huggingface.co/transformers/glossary.html#model-inputs)
- [Question Answering chapter of NLP course](https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt)

In [None]:
import evaluate
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, MobileBertForQuestionAnswering
from datasets import load_dataset
from torch.utils.data import DataLoader

In [None]:
from datasets import disable_caching
from datasets.utils import disable_progress_bar

# Keep the notebook output free of progress bars from `filter`/`map`, and
# make those transforms recompute their results instead of reusing cache files.
disable_progress_bar()
disable_caching()

In [None]:
# Hugging Face Hub id of the base checkpoint: both the model architecture and
# the tokenizer are loaded from it before the fine-tuned weights are applied.
hf_model_checkpoint = "google/mobilebert-uncased"

In [None]:
# Build the MobileBERT question-answering model from the base checkpoint and
# put it in evaluation mode (inference only in this notebook).
model = MobileBertForQuestionAnswering.from_pretrained(hf_model_checkpoint)
model.eval()

# Replace its weights with the state dict saved by the fine-tuning step.
# `map_location` remaps every stored tensor to the CPU, so a checkpoint saved
# from a GPU run can be restored on a CPU-only machine.
model.load_state_dict(
    torch.load('mobilebertqa_ft', map_location=torch.device('cpu'))
)

In [None]:
# Tokenizer from the same checkpoint as the model, so token ids match the
# vocabulary the model's embeddings were trained on.
tokenizer = AutoTokenizer.from_pretrained(
    hf_model_checkpoint
)

In [None]:
# SQuAD from the Hugging Face Hub; only its 'validation' split is evaluated here.
hf_dataset = load_dataset("squad")

In [None]:
# --- Data preprocessing ------------------------------------------------------
# More details about these steps are in the notebook that explores the dataset.

# Length (in tokens) of every encoded `[CLS] context [SEP] question [SEP]`
# sequence: shorter inputs are padded up to it, longer contexts truncated.
MAX_SEQ_LEN = 300

def tokenize_dataset(squad_example, tokenizer=tokenizer):
    """Encode one SQuAD example and locate its first answer in token space.

    The context and question are encoded as a pair, padded to `MAX_SEQ_LEN`
    and, if too long, truncated in the context only (``'only_first'``).

    Two extra fields are added to the returned encoding:

    * ``start_token_idx`` - input position where the answer begins;
    * ``end_token_idx``   - exclusive end position (start + number of answer
      tokens), used as a slice bound ``input_ids[start:end]``.

    Only the first of the example's reference answers is used. Token counts
    come from tokenizing text slices in isolation, which may differ slightly
    from the in-context tokenization.
    """
    ctx = squad_example['context']
    first_answer = squad_example['answers']
    char_start = first_answer['answer_start'][0]
    answer_text = first_answer['text'][0]

    encoded = tokenizer(
        ctx,
        squad_example['question'],
        max_length=MAX_SEQ_LEN,
        padding='max_length',
        truncation='only_first',
    )

    # The context is tokenized up to AND including the answer's first
    # character. This presumably yields one token more than the text before
    # the answer, which would offset the leading [CLS] token.
    # NOTE(review): confirm this for answers that start mid-word.
    start_idx = len(tokenizer.tokenize(ctx[:char_start + 1]))
    encoded['start_token_idx'] = start_idx
    encoded['end_token_idx'] = start_idx + len(tokenizer.tokenize(answer_text))
    return encoded


def filter_samples_by_max_seq_len(squad_example):
    """Filter out the samples whose answer would be cut off by truncation.

    `tokenize_dataset` encodes each example as
    ``[CLS] context [SEP] question [SEP]`` of exactly `MAX_SEQ_LEN` tokens with
    ``truncation='only_first'``, i.e. only the context is shortened.

    The context therefore keeps at most
    ``MAX_SEQ_LEN - len(question tokens) - number of special tokens`` tokens.
    The answer must end within that budget, not merely within `MAX_SEQ_LEN`.
    Otherwise the reference span produced by `tokenize_dataset` would point
    at question tokens or padding instead of the answer.

    Token counts come from tokenizing text slices in isolation, so they
    approximate the in-context tokenization.

    Returns True to keep the example, False to drop it.
    """
    max_len = MAX_SEQ_LEN
    answer_start = squad_example['answers']['answer_start'][0]
    answer = squad_example['answers']['text'][0]
    # Number of context tokens before the answer, and the exclusive index
    # where the answer ends. Both are counted within the context alone.
    token_start = len(tokenizer.tokenize(squad_example['context'][:answer_start]))
    token_end = len(tokenizer.tokenize(answer)) + token_start

    # Tokens left for the context once the question and the special tokens
    # the tokenizer adds to a pair (e.g. [CLS] and two [SEP]) are subtracted.
    question_len = len(tokenizer.tokenize(squad_example['question']))
    n_special = tokenizer.num_special_tokens_to_add(pair=True)
    context_budget = max_len - question_len - n_special
    return token_end <= context_budget

# Keep only validation examples whose answer fits in the encoded window...
dataset_filtered = hf_dataset['validation'].filter(filter_samples_by_max_seq_len, num_proc=24)

# ...then replace the raw text columns with the tokenized model inputs and
# the answer token positions, returned as PyTorch tensors.
dataset_tok = dataset_filtered.map(
    tokenize_dataset,
    num_proc=12,
    remove_columns=hf_dataset['validation'].column_names,
)
dataset_tok.set_format('pt')

In [None]:
# Batches of 8 tokenized validation examples. Shuffling makes every run of
# the sample-evaluation cell show different predictions.
eval_dataloader = DataLoader(dataset_tok, batch_size=8, shuffle=True)

In [None]:
# SQuAD metric: scores predicted answers against the references
# (exact match and F1).
squad_metric = evaluate.load('squad')

In [None]:
# Evaluate a few random samples: run the fine-tuned model on the first
# (shuffled) batch and, for every example in it, print the context, the
# question, the predicted and reference answers, and the SQuAD metrics for
# that single prediction.


def _decode(token_ids):
    """Return the text for `token_ids`, dropping special tokens ([CLS], [SEP], [PAD])."""
    return tokenizer.decode(token_ids, skip_special_tokens=True,
                            clean_up_tokenization_spaces=True)


for batch in eval_dataloader:
    # evaluate the model
    with torch.no_grad():
        outputs = model(
            input_ids=batch['input_ids'],
            token_type_ids=batch['token_type_ids'],
            attention_mask=batch['attention_mask'],
        )

    # Softmax over the last dimension turns the start/end logits into
    # probability distributions over token positions.
    start_distr = F.softmax(outputs.start_logits, dim=-1)
    end_distr = F.softmax(outputs.end_logits, dim=-1)

    # loop over the examples of the batch
    for context_tokens, token_types, start_ref, end_ref, start_probs, end_probs in zip(
            batch['input_ids'], batch['token_type_ids'],
            batch['start_token_idx'], batch['end_token_idx'],
            start_distr, end_distr):
        # Only tokenized inputs are kept in the dataset, so all text is
        # recovered by decoding tokens. This also matters for the metric:
        # decoded text can differ from the original (spaces around
        # punctuation, certain symbols). Both prediction and reference are
        # therefore compared in decoded form.
        # The first segment of the pair (context) has token_type_id 0 and the
        # second (question) has 1. Decoding the whole input would mix the two.
        context_text = _decode(context_tokens[token_types == 0])
        question_text = _decode(context_tokens[token_types == 1])

        # Most probable start / end positions. The end position is used as an
        # exclusive slice bound, like `end_token_idx` for the reference.
        start_pred = torch.argmax(start_probs)
        end_pred = torch.argmax(end_probs)

        # predicted and reference answers
        answer_text = _decode(context_tokens[start_pred:end_pred])
        answer_text_ref = _decode(context_tokens[start_ref:end_ref])

        # Approximate character offset of the REFERENCE answer in the decoded
        # text: the decoded prefix plus one separating space, or 0 when
        # nothing precedes the answer. The offset is taken from the
        # reference start, not the prediction.
        ref_prefix = _decode(context_tokens[:start_ref])
        start_text_ref = len(ref_prefix) + 1 if ref_prefix else 0

        # metrics
        predictions = [{'prediction_text': answer_text, 'id': 'xxx'}]
        references = [{'answers': {'answer_start': [start_text_ref],
                                   'text': [answer_text_ref]},
                       'id': 'xxx'}]
        results = squad_metric.compute(predictions=predictions, references=references)

        print(f'* {context_text}\n')
        print(f'[ question] {question_text}')
        print(f'[  model  ] {answer_text}')
        print(f'[   ref   ] {answer_text_ref}')
        print(f'[ metrics ] {results}')
        print('\n---\n')

    # Run only the first batch
    break