# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We are going to fine-tune BERT for the text-extraction task with the SQuAD v1.1 dataset of questions and answers. The data consists of a set of questions and the corresponding paragraphs (contexts) that contain the answers. The model is trained to locate the answer in the context by predicting the token positions where the answer starts and ends.
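To make the start/end formulation concrete: `BertForQuestionAnswering` produces a start logit and an end logit for every token, and an answer is read off by choosing a start index and an end index. The sketch below shows one common decoding rule: pick the pair with the highest combined score, subject to start ≤ end and a bounded answer length. The helper name `best_span` and the `max_answer_len` limit are illustrative assumptions, not part of this notebook's `eval_utils`.

```python
from typing import List, Tuple


def best_span(start_scores: List[float], end_scores: List[float],
              max_answer_len: int = 30) -> Tuple[int, int]:
    """Return the (start, end) token indices that maximize start + end score.

    Hypothetical helper: only spans with start <= end and a length of at
    most max_answer_len tokens are considered.
    """
    best = (0, 0)
    best_score = float("-inf")
    for i, s in enumerate(start_scores):
        # The end index may not precede the start or exceed the length limit
        for j in range(i, min(i + max_answer_len, len(end_scores))):
            score = s + end_scores[j]
            if score > best_score:
                best_score = score
                best = (i, j)
    return best


# Token 2 is the most likely start and token 3 the most likely end
start = [0.1, 0.2, 3.0, 0.5, 0.0]
end = [0.0, 0.1, 0.4, 2.5, 0.3]
print(best_span(start, end))  # (2, 3)
```

Taking the argmax of the start and end scores independently can return an end index that lies before the start index; searching only over valid pairs rules that out.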

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

Here we use [BERT base model (uncased)](https://huggingface.co/bert-base-uncased) and the [BertForQuestionAnswering](https://huggingface.co/transformers/model_doc/bert.html?highlight=bertforquestionanswering#bertforquestionanswering) class from Hugging Face Transformers.

In [1]:
import numpy as np
import os
import json
import dataset_utils as du
import eval_utils as eu
import torch
from transformers import BertTokenizer, BertForQuestionAnswering, AdamW
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tqdm import tqdm

In [2]:
bert_cache = os.path.join(os.getcwd(), 'cache')

In [3]:
# Download the (slow, pure-Python) tokenizer once and save its vocabulary
# locally, so that the fast Rust-based tokenizer can be built from vocab.txt
slow_tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    cache_dir=os.path.join(bert_cache, '_bert-base-uncased-tokenizer')
)
save_path = os.path.join(bert_cache, 'bert-base-uncased-tokenizer')
if not os.path.exists(save_path):
    os.makedirs(save_path)
    slow_tokenizer.save_pretrained(save_path)

# Load the fast tokenizer from the saved vocabulary file
tokenizer = BertWordPieceTokenizer(os.path.join(save_path, 'vocab.txt'),
                                   lowercase=True)

In [4]:
model = BertForQuestionAnswering.from_pretrained(
    "bert-base-uncased",
    cache_dir=os.path.join(bert_cache, 'bert-base-uncased_qa')
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [5]:
train_path = os.path.join(bert_cache, 'data', 'train-v1.1.json')
eval_path = os.path.join(bert_cache, 'data', 'dev-v1.1.json')
with open(train_path) as f:
    raw_train_data = json.load(f)

with open(eval_path) as f:
    raw_eval_data = json.load(f)

In [6]:
batch_size = 8
max_len = 384
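`max_len` caps the length of every model input. The helpers that use it live in the local `dataset_utils` module, which is not shown here. In the Keras example this notebook is based on, each input is laid out as `[CLS] context [SEP] question [SEP]`, with `token_type_ids` of 0 for the context segment and 1 for the question, an attention mask of 1 over real tokens, zero padding up to `max_len`, and examples that do not fit are skipped. The sketch below assumes that same layout; `build_inputs` is a hypothetical name, 101 and 102 are the `[CLS]` and `[SEP]` IDs in the `bert-base-uncased` vocabulary, and the other token IDs are arbitrary placeholders.

```python
from typing import List, Optional, Tuple

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token IDs


def build_inputs(context_ids: List[int], question_ids: List[int], max_len: int
                 ) -> Optional[Tuple[List[int], List[int], List[int]]]:
    """Assemble [CLS] context [SEP] question [SEP], padded to max_len.

    context_ids / question_ids are WordPiece IDs without special tokens.
    Returns None when the pair is too long (the example would be skipped).
    """
    first = [CLS_ID] + context_ids + [SEP_ID]
    second = question_ids + [SEP_ID]
    input_ids = first + second
    token_type_ids = [0] * len(first) + [1] * len(second)
    attention_mask = [1] * len(input_ids)

    padding = max_len - len(input_ids)
    if padding < 0:
        return None  # too long for the model: skip this example
    input_ids += [PAD_ID] * padding
    token_type_ids += [0] * padding
    attention_mask += [0] * padding
    return input_ids, token_type_ids, attention_mask


ids, types, mask = build_inputs([7, 8], [9], max_len=8)
print(ids)    # [101, 7, 8, 102, 9, 102, 0, 0]
print(types)  # [0, 0, 0, 0, 1, 1, 0, 0]
print(mask)   # [1, 1, 1, 1, 1, 1, 0, 0]
```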

In [7]:
train_squad_examples = du.create_squad_examples(raw_train_data, max_len, tokenizer)
x_train, y_train = du.create_inputs_targets(train_squad_examples, shuffle=True, seed=42)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = du.create_squad_examples(raw_eval_data, max_len, tokenizer)
x_eval, y_eval = du.create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

86136 training points created.
10331 evaluation points created.
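SQuAD gives each answer as a character offset into the context, while the model predicts token indices, so preprocessing has to convert one into the other. The fast tokenizer reports an end-exclusive `(start_char, end_char)` offset for every token (special tokens get `(0, 0)`), which makes this possible. A minimal sketch of the conversion, modeled on the Keras example and assuming `dataset_utils` does something equivalent (the name `char_to_token_span` is hypothetical): every token that overlaps the answer characters belongs to the answer, and the first and last such tokens become the start and end targets.

```python
from typing import List, Optional, Tuple


def char_to_token_span(offsets: List[Tuple[int, int]], ans_start: int,
                       ans_end: int) -> Optional[Tuple[int, int]]:
    """Map the character span [ans_start, ans_end) to (first, last) token index.

    offsets[k] is the end-exclusive character range of token k; empty
    ranges such as (0, 0) for [CLS]/[SEP] never match.
    Returns None if no token overlaps the answer (the example is skipped).
    """
    hits = [k for k, (s, e) in enumerate(offsets)
            if e > s and s < ans_end and e > ans_start]
    if not hits:
        return None
    return hits[0], hits[-1]


context = "the cat sat on the mat"
answer = "the mat"
ans_start = context.index(answer)  # 15
# [CLS], the, cat, sat, on, the, mat, [SEP]
offsets = [(0, 0), (0, 3), (4, 7), (8, 11), (12, 14), (15, 18), (19, 22), (0, 0)]
print(char_to_token_span(offsets, ans_start, ans_start + len(answer)))  # (5, 6)
```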


In [8]:
class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        return (torch.tensor(self.x[0][idx]),
                torch.tensor(self.x[1][idx]),
                torch.tensor(self.x[2][idx]),
                torch.tensor(self.y[0][idx]),
                torch.tensor(self.y[1][idx]))

    def __len__(self):
        return len(self.x[0])

In [9]:
train_set = SquadDataset(x_train, y_train)
train_loader = DataLoader(train_set, batch_size=batch_size,
                          shuffle=True)
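With `batch_size = 8`, one pass over the 86,136 training points takes ⌈86136 / 8⌉ = 10,767 steps, which matches the iteration count `tqdm` reports for the training loop. `DataLoader` keeps a final partial batch by default (`drop_last=False`), but 86,136 is a multiple of 8, so every batch here is full. The arithmetic, as a quick check:

```python
import math

num_examples = 86136  # training points reported above
batch_size = 8
steps_per_epoch = math.ceil(num_examples / batch_size)
print(steps_per_epoch)  # 10767

# At the ~3.04 iterations/second tqdm reports, one epoch takes about an hour
print(round(steps_per_epoch / 3.04 / 60))  # 59
```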

In [10]:
device = torch.device('cuda:0')  # train on the first GPU
model.to(device)
model.train()

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [10]:
optim = AdamW(model.parameters(), lr=5e-5)

In [13]:
for epoch in range(1):
    for i, batch in tqdm(enumerate(train_loader)):
        optim.zero_grad()
        # batch = (input_ids, token_type_ids, attention_mask,
        #          start_positions, end_positions), as returned by SquadDataset
        outputs = model(input_ids=batch[0].to(device),
                        token_type_ids=batch[1].to(device),
                        attention_mask=batch[2].to(device),
                        start_positions=batch[3].to(device),
                        end_positions=batch[4].to(device))

        # With target positions supplied, the first output is the training loss
        loss = outputs[0]
        loss.backward()
        optim.step()

10767it [59:03,  3.04it/s]
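When `start_positions` and `end_positions` are passed, `BertForQuestionAnswering` returns the loss as its first output: the average of two cross-entropy losses, one over the start logits and one over the end logits. The pure-Python sketch below reproduces that computation for a single example; the function names are illustrative, and the library itself computes the same quantity in batched form with `torch.nn.CrossEntropyLoss`.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log of the softmax probability assigned to the target index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def qa_loss(start_logits: List[float], end_logits: List[float],
            start_pos: int, end_pos: int) -> float:
    """Average of the start-position and end-position cross-entropy losses."""
    return (cross_entropy(start_logits, start_pos)
            + cross_entropy(end_logits, end_pos)) / 2


# Uniform logits over 4 tokens give a loss of log(4) for each head
print(round(qa_loss([0.0] * 4, [0.0] * 4, 1, 2), 4))  # 1.3863
```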


In [14]:
# torch.save(model.state_dict(), './cache/model_trained_single_node')

In [13]:
# load the model on gpu
# model.load_state_dict(torch.load('./cache/model_trained_single_node'))

In [11]:
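# Load the trained weights for evaluation on the CPU; map_location remaps
# tensors that were saved from the GPU, so no CUDA device is required
model.load_state_dict(
    torch.load('./cache/model_trained_single_node',
               map_location=torch.device('cpu'))
)
model.device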

device(type='cpu')

In [12]:
model.eval()


In [16]:
samples = np.random.choice(len(x_eval[0]), 50, replace=False)

eu.EvalUtility(
    (x_eval[0][samples], x_eval[1][samples], x_eval[2][samples]),
    model,
    eval_squad_examples[samples]
).results()

  - 1290                           | ref: 1290                           | By what year was selling children into slavery common among the Mongols?
  - genghis khan                   | ref: Genghis Khan                   | Who created the code that governed military and civilian conduct in the Mongol Empire?
  -                                | ref: triplet-stacked                | How are peridinin-type chloroplasts' thylakoids arranged?
  - john sutcliffe                 | ref: John Sutcliffe                 | Who reported on the sideline for ESPN Deportes?
  - 1892                           | ref: 1892                           | What year was the first class taught at the University of Chicago?
  - british                        | ref: British                        | What nationality is the band Coldplay?
  - 98 million                     | ref: 9.8 million                    | How man volumes does the The University of Chicago Library system hold?
  - multimember proportional   
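The listing above prints each prediction next to its reference answer. SQuAD v1.1 is usually scored with exact match and token-level F1 after normalizing both strings (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace), as the official evaluation script does. Whether `eval_utils.EvalUtility` applies exactly this normalization is an assumption; if it does, the prediction `98 million` counts as an exact match for `9.8 million`, because removing punctuation maps the reference to the same string. A sketch of both metrics:

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(reference)


def f1_score(prediction: str, reference: str) -> float:
    """Token-overlap F1 between the normalized strings."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    common = Counter(pred_tokens) & Counter(ref_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("98 million", "9.8 million"))          # True
print(exact_match("", "triplet-stacked"))                # False
print(round(f1_score("the british band", "British"), 3))  # 0.667
```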