# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We are going to fine-tune [BERT implemented by HuggingFace](https://huggingface.co/bert-base-uncased) for the text-extraction task, using the questions and answers of the [SQuAD (The Stanford Question Answering Dataset)](https://rajpurkar.github.io/SQuAD-explorer/) dataset.
The data is composed of a set of questions and the corresponding paragraphs (contexts) that contain the answers.
The model will be trained to locate the answer in the context by predicting the positions where the answer starts and ends.

In this notebook we train the model on multiple GPUs in parallel, using PyTorch `DistributedDataParallel` across the engines of an MPI-backed IPython cluster.

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

More info:
- [Glossary - HuggingFace docs](https://huggingface.co/transformers/glossary.html#model-inputs)
- [BERT NLP — How To Build a Question Answering Bot](https://towardsdatascience.com/bert-nlp-how-to-build-a-question-answering-bot-98b1d1594d7b)

In [None]:
import ipcmagic

In [None]:
%ipcluster start -n 2 --mpi

In [None]:
%%px
import numpy as np
import os
import json
import dataset_utils as du
import eval_utils as eu
import torch
import torch.distributed as dist
from datetime import datetime
from transformers import BertTokenizer, BertForQuestionAnswering, AdamW
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm
from tqdm.notebook import tqdm

In [None]:
%%px
# Hugging Face checkpoint to fine-tune, and the local directory (relative to
# the notebook's working directory) where downloads and artifacts are cached.
hf_model = "bert-base-uncased"
bert_cache = os.path.join(os.getcwd(), "cache")

In [None]:
%%px
# Download (or reuse from cache) the slow tokenizer. Here it is only used to
# write its vocabulary to disk so the fast tokenizer can be built from it.
slow_tokenizer = BertTokenizer.from_pretrained(
    hf_model,
    cache_dir=os.path.join(bert_cache, f'_{hf_model}-tokenizer')
)
save_path = os.path.join(bert_cache, f'{hf_model}-tokenizer')
vocab_path = os.path.join(save_path, 'vocab.txt')

# `%%px` runs this cell on every engine at the same time. A plain
# "exists? then makedirs" check races: another engine can create the
# directory in between, making `makedirs` raise FileExistsError.
# `exist_ok=True` removes that failure. Keying on the vocab file rather than
# the directory also recovers from an earlier save that was interrupted.
# NOTE(review): engines may still write the same files concurrently on a cold
# cache; pre-populating the cache before starting the cluster avoids that.
if not os.path.isfile(vocab_path):
    os.makedirs(save_path, exist_ok=True)
    slow_tokenizer.save_pretrained(save_path)

# Load the fast tokenizer from the saved vocabulary file.
tokenizer = BertWordPieceTokenizer(vocab_path, lowercase=True)

In [None]:
%%px
# Pre-trained BERT with a question-answering head, cached under `bert_cache`.
_qa_cache_dir = os.path.join(bert_cache, f'{hf_model}_qa')
model = BertForQuestionAnswering.from_pretrained(hf_model, cache_dir=_qa_cache_dir)

In [None]:
%%px
# SQuAD v1.1 train/dev splits. They are expected to already exist under
# `<bert_cache>/data`; this cell does not download them.
train_path = os.path.join(bert_cache, 'data', 'train-v1.1.json')
eval_path = os.path.join(bert_cache, 'data', 'dev-v1.1.json')

# JSON text is UTF-8 by specification. Without `encoding=`, `open()` uses the
# platform locale encoding, which can fail to decode non-ASCII text on
# non-UTF-8 locales.
with open(train_path, encoding='utf-8') as f:
    raw_train_data = json.load(f)

with open(eval_path, encoding='utf-8') as f:
    raw_eval_data = json.load(f)

In [None]:
%%px
# Maximum sequence length handed to the example builder.
max_len = 384

# Training split: raw JSON -> examples -> (inputs, targets). Shuffled with a
# fixed seed; presumably deterministic, so every engine builds the same order
# before DistributedSampler splits it by index. Confirm in dataset_utils.
train_squad_examples = du.create_squad_examples(raw_train_data, max_len, tokenizer)
x_train, y_train = du.create_inputs_targets(
    train_squad_examples,
    shuffle=True,
    seed=42,
)
print(f"{len(train_squad_examples)} training points created.")

# Evaluation split, kept in its original order.
# NOTE(review): if create_inputs_targets drops examples it cannot encode, the
# counts printed here may exceed len(x_train[0]) / len(x_eval[0]).
eval_squad_examples = du.create_squad_examples(raw_eval_data, max_len, tokenizer)
x_eval, y_eval = du.create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

In [None]:
%%px
class SquadDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-encoded SQuAD features.

    ``x`` holds three parallel, indexable input columns and ``y`` two target
    columns. Item ``idx`` is the 5-tuple of tensors
    ``(x[0][idx], x[1][idx], x[2][idx], y[0][idx], y[1][idx])``. In this
    notebook's training loop they are consumed as input_ids, token_type_ids,
    attention_mask, start_positions and end_positions, in that order.
    """

    def __init__(self, x, y):
        self.x = x  # three input columns
        self.y = y  # two target columns

    def __len__(self):
        # The first input column defines the number of examples.
        return len(self.x[0])

    def __getitem__(self, idx):
        inputs, targets = self.x, self.y
        columns = (inputs[0], inputs[1], inputs[2], targets[0], targets[1])
        return tuple(torch.tensor(column[idx]) for column in columns)

In [None]:
%%px
from pt_distr_env import setup_distr_env

# Create the default process group only once. Re-running this cell would
# otherwise fail, because the default process group already exists.
# setup_distr_env presumably exports the rendezvous environment variables
# consumed by init_process_group; confirm in pt_distr_env.
if not dist.is_initialized():
    setup_distr_env()
    dist.init_process_group(backend="nccl")

rank = dist.get_rank()
world_size = dist.get_world_size()

In [None]:
%%px
batch_size = 16  # per-process (per-engine) batch size

train_set = SquadDataset(x_train, y_train)

# Each rank iterates over its own shard of the dataset; the sampler pads so
# all ranks get the same number of samples. Sampler shuffling is off because
# the training inputs were already shuffled when they were built.
train_sampler = DistributedSampler(
    train_set,
    num_replicas=world_size,
    rank=rank,
    shuffle=False,
    seed=42,
)

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,  # must stay False when an explicit sampler is supplied
    sampler=train_sampler,
)

In [None]:
%%px
# NOTE(review): every engine uses CUDA device 0. That is right with one GPU
# per engine/node. If several engines share a node, they would all land on
# the same GPU; confirm the layout, or derive the index from the local rank.
device = 0
# nn.Module.to returns the module itself, so it can be moved and wrapped in one step.
model = DistributedDataParallel(model.to(device), device_ids=[device])
model.train()

model.training

In [None]:
%%px
# `transformers.AdamW` is deprecated in favour of `torch.optim.AdamW`.
# `eps` and `weight_decay` are pinned to the defaults `transformers.AdamW`
# used (1e-6 and 0.0), because torch's own defaults (1e-8 and 0.01) differ.
optim = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

In [None]:
%%px
# Number of passes over each rank's shard of the training data.
num_epochs = 1
# Optional cap on batches per epoch (e.g. 10 for a quick smoke test); None
# trains on the full shard. Keep it identical on every rank: ranks doing a
# different number of backward passes can hang in DDP's gradient all-reduce.
max_steps = None

for epoch in range(num_epochs):
    # DistributedSampler needs set_epoch() at the start of every epoch so
    # that, whenever shuffling is enabled, each epoch gets a new ordering
    # that is consistent across ranks. It has no effect while shuffle=False.
    train_sampler.set_epoch(epoch)
    for step, batch in enumerate(train_loader):
        if max_steps is not None and step >= max_steps:
            break
        optim.zero_grad()
        # Batch layout follows SquadDataset.__getitem__.
        input_ids, token_type_ids, attention_mask, start_positions, end_positions = (
            tensor.to(device) for tensor in batch
        )
        outputs = model(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        # With start/end positions supplied, outputs[0] is the training loss.
        loss = outputs[0]
        loss.backward()  # DDP all-reduces gradients across ranks during backward
        optim.step()

In [None]:
%%px --target 0
# Timestamp that gives each saved checkpoint a unique file name.
model_hash = datetime.now().strftime("%Y-%m-%d-%H%M%S")
# world_size is the number of engines started with `%ipcluster start -n`.
model_path_name = os.path.join(
    bert_cache, f'model_trained_{world_size}_nodes_{model_hash}'
)

# `model` is wrapped in DistributedDataParallel. Saving the state_dict of the
# wrapped `model.module` stores plain BertForQuestionAnswering keys, so the
# checkpoint can later be loaded into an unwrapped model on any number of
# nodes.
torch.save(model.module.state_dict(), model_path_name)

# Build a fresh CPU instance of the same architecture, because the trained
# model lives on the GPU.
model_cpu = BertForQuestionAnswering.from_pretrained(
    hf_model,
    cache_dir=os.path.join(bert_cache, f'{hf_model}_qa')
)

# Load the trained weights; map_location remaps the saved GPU tensors to CPU.
model_cpu.load_state_dict(
    torch.load(model_path_name,
               map_location=torch.device('cpu'))
)

# To reload the weights into the GPU model instead, load into the wrapped
# module (the checkpoint keys have no DDP prefix):
# model.module.load_state_dict(torch.load(model_path_name))

In [None]:
%%px --target 0
# Switch to inference mode (disables dropout). `model_cpu` is the model the
# evaluation cell actually uses, so set it explicitly as well rather than
# relying on `from_pretrained` having already put it in eval mode.
model.eval()
model_cpu.eval()

model.training, model_cpu.training

In [None]:
%%px --target 0
# Evaluate on a random subset of the dev set. A seeded generator makes the
# subset reproducible across runs. The size is capped at the number of
# available rows, because sampling without replacement cannot draw more.
num_eval_samples = 50
n_eval = len(x_eval[0])
rng = np.random.default_rng(42)
samples = rng.choice(n_eval, size=min(num_eval_samples, n_eval), replace=False)

# NOTE(review): `samples` indexes rows of `x_eval` and is also used to index
# `eval_squad_examples`. That is only aligned if create_inputs_targets keeps
# every example (skips none), and it only works if eval_squad_examples
# supports NumPy fancy indexing (e.g. is an ndarray). Confirm in dataset_utils.
eu.EvalUtility(
    (x_eval[0][samples], x_eval[1][samples], x_eval[2][samples]),
    model_cpu,
    eval_squad_examples[samples]
).results()

In [None]:
%ipcluster stop