# MobileBERT for Question Answering on the SQuAD dataset

### 2. Fine-tuning the model

In these notebooks we are going to use [MobileBERT implemented by HuggingFace](https://huggingface.co/docs/transformers/model_doc/mobilebert) for extractive question answering on the [Stanford Question Answering Dataset (SQuAD)](https://rajpurkar.github.io/SQuAD-explorer/). The data consists of questions paired with paragraphs (contexts) that contain the answers. The model is trained to locate the answer in the context by predicting the positions where the answer starts and ends.
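As a rough illustration of the start/end formulation: the model produces one start score and one end score per token, and a simple way to turn them into an answer is to pick the pair with the highest combined score such that the start does not come after the end. The sketch below uses made-up tokens and logits. `best_span` and `max_answer_len` are hypothetical names, and practical decoding usually also restricts the span to context tokens, which this sketch does not do.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end) maximizing start_logits[s] + end_logits[e],
    subject to s <= e < s + max_answer_len."""
    best_score = float("-inf")
    best = (0, 0)
    for s, s_score in enumerate(start_logits):
        # Only consider end positions at or after the start, within the length limit.
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best


# Made-up example: tokens of "[CLS] question [SEP] context [SEP]" and toy scores.
tokens = ["[CLS]", "who", "wrote", "it", "?", "[SEP]", "it", "was", "jane", "austen", "[SEP]"]
start_logits = [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.1, 3.0, 0.5, 0.0]
end_logits = [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 1.0, 2.5, 0.0]

s, e = best_span(start_logits, end_logits)
print(tokens[s:e + 1])  # ['jane', 'austen']
```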

In this notebook we fine-tune the model.

More information in the HuggingFace docs:
- [Question Answering](https://huggingface.co/tasks/question-answering)
- [Glossary](https://huggingface.co/transformers/glossary.html#model-inputs)
- [Question Answering chapter of NLP course](https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt)

In [None]:
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, MobileBertForQuestionAnswering
from datasets import load_dataset
from torch.utils.data import DataLoader

In [None]:
from datasets.utils import disable_progress_bar
from datasets import disable_caching


disable_progress_bar()
disable_caching()

In [None]:
# Load the tokenizer that was used to pretrain the model
# We want to use https://huggingface.co/google/mobilebert-uncased
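Calling a BERT-style tokenizer on a (question, context) pair returns the `input_ids`, `token_type_ids` and `attention_mask` that the training loop reads from each batch. The pure-Python sketch below builds these by hand from already-tokenized ids, to show the layout `[CLS] question [SEP] context [SEP]` followed by padding. `build_qa_inputs` is hypothetical, and the special-token ids 101/102/0 are those of the BERT uncased vocabulary, which we assume `google/mobilebert-uncased` shares (check `tokenizer.cls_token_id`, `tokenizer.sep_token_id` and `tokenizer.pad_token_id` on the real tokenizer).

```python
def build_qa_inputs(question_ids, context_ids, max_length, cls_id=101, sep_id=102, pad_id=0):
    """Pack token ids as [CLS] question [SEP] context [SEP], then pad to max_length."""
    input_ids = [cls_id] + question_ids + [sep_id] + context_ids + [sep_id]
    if len(input_ids) > max_length:
        # A real tokenizer would truncate the context instead (e.g. truncation="only_second").
        raise ValueError("question + context do not fit in max_length")
    # Segment 0: [CLS], the question and the first [SEP]; segment 1: the context and the last [SEP].
    token_type_ids = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)
    attention_mask = [1] * len(input_ids)  # 1 = real token, 0 = padding
    n_pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [pad_id] * n_pad,
        "token_type_ids": token_type_ids + [0] * n_pad,
        "attention_mask": attention_mask + [0] * n_pad,
    }


# Toy ids standing in for a tokenized question (2 tokens) and context (3 tokens).
enc = build_qa_inputs([2040, 2001], [3000, 3001, 3002], max_length=10)
for key, value in enc.items():
    print(key, value)
```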

In [None]:
# Instantiate the model
# Use the MobileBertForQuestionAnswering class imported in the first cell
# Use this as reference
# https://huggingface.co/docs/transformers/model_doc/mobilebert#transformers.MobileBertForPreTraining.forward.example
model = ...
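`MobileBertForQuestionAnswering` puts a span-prediction head on top of the encoder. In the Hugging Face implementation this is a linear layer called `qa_outputs` that maps every token's hidden state to two scores, one for "the answer starts here" and one for "the answer ends here". The toy module below mimics that head in plain PyTorch. `ToyQAHead` is hypothetical, and the sizes (hidden size 512, sequence length 384) are illustrative assumptions.

```python
import torch
import torch.nn as nn


class ToyQAHead(nn.Module):
    """Hypothetical stand-in for a span-prediction head: one linear layer, 2 scores per token."""

    def __init__(self, hidden_size):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size)
        logits = self.qa_outputs(hidden_states)             # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)  # two tensors of (batch, seq_len, 1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)


torch.manual_seed(0)
head = ToyQAHead(hidden_size=512)
hidden = torch.randn(4, 384, 512)  # random "encoder output" for 4 examples
start_logits, end_logits = head(hidden)
print(start_logits.shape, end_logits.shape)
```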

### Question

When instantiating the model, a message appears in red. What does it mean?

In [None]:
# load the dataset

In [None]:
# Preprocess the data
# Include all the preprocessing done in the notebook that explores the dataset
# and apply it with the dataset's `filter` and `map` methods
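The preprocessing from the dataset-exploration notebook is not reproduced here. One step such a pipeline usually needs is converting SQuAD's character-level answer (`answer_start` plus the answer text) into the token indices stored as `start_token_idx` and `end_token_idx`. The sketch below assumes the tokenizer was called with `return_offsets_mapping=True` (available for fast tokenizers) and that per-token sequence ids are available (e.g. from `sequence_ids()`); `char_to_token_span` and the toy offsets are hypothetical. In a `datasets` pipeline, a function like this could be applied with `map`, and examples where it returns `None` (answer truncated away) dropped with `filter`.

```python
def char_to_token_span(offsets, sequence_ids, answer_start, answer_text):
    """Map a character-level answer to (start_token_idx, end_token_idx).

    offsets: (char_start, char_end) per token; sequence_ids: 0 = question,
    1 = context, None = special token. Returns None if the answer is not
    fully inside the context tokens.
    """
    answer_end = answer_start + len(answer_text)  # exclusive character end
    start_idx = end_idx = None
    for i, seq_id in enumerate(sequence_ids):
        if seq_id != 1:  # only context tokens can hold the answer
            continue
        char_start, char_end = offsets[i]
        if start_idx is None and char_start <= answer_start < char_end:
            start_idx = i
        if char_start < answer_end <= char_end:
            end_idx = i
    if start_idx is None or end_idx is None:
        return None
    return start_idx, end_idx


# Toy encoding of question "who wrote it?" and context "it was jane austen".
offsets = [(0, 0), (0, 3), (4, 9), (10, 12), (12, 13), (0, 0),
           (0, 2), (3, 6), (7, 11), (12, 18), (0, 0)]
sequence_ids = [None, 0, 0, 0, 0, None, 1, 1, 1, 1, None]

context = "it was jane austen"
print(char_to_token_span(offsets, sequence_ids, context.index("jane"), "jane austen"))  # (8, 9)
```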



In [None]:
# Define a PyTorch DataLoader for the training set
# Use a batch size of 256 for faster training
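A minimal sketch of the DataLoader, assuming each preprocessed example is a dict of equal-length tensors (with a Hugging Face `Dataset`, `set_format("torch", columns=[...])` produces such dicts). PyTorch's default collate function stacks each key, so every batch is a dict with the keys used in the training loop. The toy data and sequence length are made up, and `shuffle=True` is our choice rather than something the notebook specifies.

```python
import torch
from torch.utils.data import DataLoader

torch.manual_seed(0)
seq_len = 16  # toy length; real SQuAD inputs are much longer

# Hypothetical toy data: 1000 examples with the columns the training loop reads.
examples = [
    {
        "input_ids": torch.randint(0, 30522, (seq_len,)),
        "token_type_ids": torch.zeros(seq_len, dtype=torch.long),
        "attention_mask": torch.ones(seq_len, dtype=torch.long),
        "start_token_idx": torch.tensor(3),
        "end_token_idx": torch.tensor(5),
    }
    for _ in range(1000)
]

# The default collate function stacks each key across the batch.
train_dataloader = DataLoader(examples, batch_size=256, shuffle=True)

batch = next(iter(train_dataloader))
print({k: tuple(v.shape) for k, v in batch.items()})
```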

In [None]:
# Move the model to GPU 0 (the training loop below expects a `device` variable)

In [2]:
# Ensure model is in training mode
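The training loop below sends every batch to a `device` variable, so these cells need to define it. A sketch with a small `nn.Linear` standing in for the QA model (an assumption, to keep it self-contained), falling back to the CPU when no GPU is present:

```python
import torch
import torch.nn as nn

# Use GPU 0 when available, otherwise fall back to the CPU so the sketch still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(8, 2)  # hypothetical stand-in for the MobileBERT QA model
model.to(device)         # nn.Module.to moves the parameters in place and returns the module
model.train()            # enables training-time behaviour such as dropout

print(device, next(model.parameters()).device, model.training)
```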

## Training

We are going to train for two epochs, using a different learning rate in each epoch:
 - epoch 1: `lr = 2e-4` (to take large steps through the parameter space and reduce the loss quickly)
 - epoch 2: `lr = 2e-5` (to avoid jumping around and start converging towards a minimum)

We will do this manually:
 - Run epoch one
 - Redefine the optimizer with the new learning rate and run the training again

We should aim for loss values around 0.6, which is usually enough for "decent" predictions.
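The two-phase schedule can be sketched as follows, with a toy `nn.Linear` standing in for the QA model. Redefining the optimizer, as described above, also discards AdamW's running moment estimates; updating `lr` in `optimizer.param_groups` is an alternative that keeps them. Which is preferable here is a judgment call, so both are shown. `set_lr` is a hypothetical helper.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)  # hypothetical stand-in for the MobileBERT QA model

# Epoch 1: larger learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# ... run the first epoch of the training loop here ...

# Epoch 2, option A (as described above): redefine the optimizer with the smaller rate.
# The new optimizer starts from fresh state, so AdamW's moment estimates are reset.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# ... run the training loop again ...


def set_lr(optimizer, lr):
    """Option B (hypothetical helper): change the learning rate in place, keeping optimizer state."""
    for group in optimizer.param_groups:
        group["lr"] = lr


optimizer_b = torch.optim.AdamW(model.parameters(), lr=2e-4)
set_lr(optimizer_b, 2e-5)
print(optimizer.param_groups[0]["lr"], optimizer_b.param_groups[0]["lr"])
```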

In [1]:
# Define optimizer using "AdamW" (Adam with decoupled weight decay)

In [None]:
def log(loss):
    """Utility function for plotting"""

    return loss.cpu().detach().numpy()

In [None]:
history = []

for epoch in range(1):
    for i, batch in enumerate(train_dataloader):
        ...  # reset the gradients accumulated in the previous step
        # run the forward pass, passing the target answer positions (start_token_idx and end_token_idx) as labels
        outputs = model(input_ids=batch['input_ids'].to(device),
                        token_type_ids=batch['token_type_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        start_positions=batch['start_token_idx'].to(device),
                        end_positions=batch['end_token_idx'].to(device))
        loss = outputs[0]          # the loss is the first output when labels are passed (specific to HuggingFace's API)
        history.append(log(loss))  # [not part of the training] keep values for plotting later
        ...    # backpropagate the loss
        ...    # update the weights with the gradients
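The update step can be sketched end to end on a toy span head with random data (all sizes and tensors are made up). The loss averages the cross-entropy of the start and end positions, which is how the Hugging Face QA models compute `outputs[0]` when `start_positions` and `end_positions` are passed; the calls `zero_grad`, `backward` and `step` correspond to the three `...` placeholders in the loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, hidden_size = 8, 32, 16

# Hypothetical toy model: a linear span head over fixed random "hidden states".
head = nn.Linear(hidden_size, 2)
hidden = torch.randn(batch_size, seq_len, hidden_size)
start_positions = torch.randint(0, seq_len, (batch_size,))
end_positions = torch.randint(0, seq_len, (batch_size,))

optimizer = torch.optim.AdamW(head.parameters(), lr=1e-2)
history = []

for step in range(50):
    optimizer.zero_grad()                                   # reset gradients from the previous step
    start_logits, end_logits = head(hidden).unbind(dim=-1)  # each (batch, seq_len)
    # Average of the start and end cross-entropies over token positions.
    loss = (F.cross_entropy(start_logits, start_positions)
            + F.cross_entropy(end_logits, end_positions)) / 2
    history.append(loss.item())                             # keep values for plotting
    loss.backward()                                         # backpropagate the loss
    optimizer.step()                                        # update the weights

print(f"first loss {history[0]:.3f}, last loss {history[-1]:.3f}")
```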

In [None]:
plt.plot(history, 'r-')
plt.ylabel('Loss')
plt.xlabel('Steps')
plt.grid()
plt.show()

In [None]:
# Save the model to disk
torch.save(model.state_dict(), 'mobilebertqa_ft')
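Since only the `state_dict` (the parameters) is saved, reloading requires building the same architecture first and then calling `load_state_dict`. The round trip below uses a toy `nn.Linear` and a temporary directory instead of the real model and path (both assumptions); for the real model one would presumably instantiate `MobileBertForQuestionAnswering` again before loading `mobilebertqa_ft`. `map_location="cpu"` lets weights saved from a GPU load on a machine without one.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(8, 2)  # hypothetical stand-in for the fine-tuned QA model
path = os.path.join(tempfile.mkdtemp(), "mobilebertqa_ft")

torch.save(model.state_dict(), path)  # saves only the parameters, not the architecture

# Rebuild the same architecture, then load the saved weights into it.
restored = nn.Linear(8, 2)
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before evaluation
```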

## Evaluating the model

When your model is trained, run the notebook `3_mobilebert-squad-testing.ipynb` to test it on the validation set.