# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We are going to fine-tune BERT for the text-extraction task on a dataset of questions and answers. The data consist of questions paired with the paragraphs (contexts) that contain their answers. The model is trained to locate the answer in the context by predicting the positions of the tokens where the answer starts and ends.

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

Here we use the [BERT base model (uncased)](https://huggingface.co/bert-base-uncased) and the [BertForQuestionAnswering](https://huggingface.co/transformers/model_doc/bert.html?highlight=bertforquestionanswering#bertforquestionanswering) class from Hugging Face Transformers.
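The model outputs one start score and one end score per token, and an answer span has to be read off from them. The simplest decoding takes the argmax of each independently; the sketch below instead picks the highest-scoring pair with start ≤ end, a common refinement. It is an illustration under assumptions, not necessarily what `eval_utils` does, and `best_span` and its `max_answer_len` cap are hypothetical names.

```python
import numpy as np

def best_span(start_logits, end_logits, max_answer_len=30):
    """Hypothetical decoder: return (start, end) token indices maximizing
    start_logits[s] + end_logits[e] subject to s <= e < s + max_answer_len."""
    start_logits = np.asarray(start_logits, dtype=float)
    end_logits = np.asarray(end_logits, dtype=float)
    n = len(start_logits)
    # scores[s, e] = start score of token s + end score of token e
    scores = start_logits[:, None] + end_logits[None, :]
    s_idx, e_idx = np.indices((n, n))
    # Forbid spans that end before they start or exceed the length cap
    valid = (e_idx >= s_idx) & (e_idx - s_idx < max_answer_len)
    scores = np.where(valid, scores, -np.inf)
    s, e = np.unravel_index(np.argmax(scores), scores.shape)
    return int(s), int(e)
```

Independent argmaxes can return an end before the start; restricting the search to valid pairs rules that out at the cost of an O(n²) score table, which is small for 384 tokens.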

In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi


In [25]:
%%px
import numpy as np
import os
import json
import dataset_utils as du
import eval_utils as eu
import torch
import torch.distributed as dist
from transformers import BertTokenizer, BertForQuestionAnswering, AdamW
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel

In [4]:
%%px
bert_cache = os.path.join(os.getcwd(), 'cache')

In [5]:
%%px
slow_tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    cache_dir=os.path.join(bert_cache, '_bert-base-uncased-tokenizer')
)
save_path = os.path.join(bert_cache, 'bert-base-uncased-tokenizer')
if not os.path.exists(save_path):
    os.makedirs(save_path)
    slow_tokenizer.save_pretrained(save_path)
    
# Load the fast tokenizer from saved file
tokenizer = BertWordPieceTokenizer(os.path.join(save_path, 'vocab.txt'),
                                   lowercase=True)
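The fast `BertWordPieceTokenizer` is used because its encodings carry character offsets (`Encoding.offsets`, one `(start, end)` pair per token), which is what allows a character-level SQuAD `answer_start` to be mapped onto token positions. The helper below is a hypothetical sketch of that mapping in the spirit of the Keras example this notebook follows; `dataset_utils` is not shown here, so its exact logic may differ. It takes a plain list of offsets, so it runs without a vocabulary file.

```python
def char_span_to_token_span(offsets, answer_start, answer_text):
    """Hypothetical helper: map a character-level answer to (start_token, end_token).

    offsets: one (char_start, char_end) pair per token of the encoded context,
    as in a fast tokenizer's Encoding.offsets. Special tokens such as [CLS]
    and [SEP] have empty (0, 0) offsets and are ignored.
    """
    answer_end = answer_start + len(answer_text)
    # Tokens whose character range overlaps [answer_start, answer_end)
    hits = [i for i, (s, e) in enumerate(offsets)
            if e > s and s < answer_end and e > answer_start]
    if not hits:
        return None  # the answer is not covered by any token (e.g. truncated away)
    return hits[0], hits[-1]
```

For the context `"the cat sat"` encoded as `[CLS] the cat sat [SEP]`, the answer `"cat sat"` at character 4 maps to tokens 2 to 3.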

In [20]:
%%px
model = BertForQuestionAnswering.from_pretrained(
    "bert-base-uncased",
    cache_dir=os.path.join(bert_cache, 'bert-base-uncased_qa')
)

[stderr:0] 
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[stderr:0] 
Some weights of BertForQuestionAnswering were not initialized from the model checkpoin

In [7]:
%%px
train_path = os.path.join(bert_cache, 'data', 'train-v1.1.json')
eval_path = os.path.join(bert_cache, 'data', 'dev-v1.1.json')
with open(train_path) as f:
    raw_train_data = json.load(f)

with open(eval_path) as f:
    raw_eval_data = json.load(f)
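Both files use the SQuAD v1.1 JSON layout: a list of articles under `data`, each with `paragraphs`; each paragraph holds a `context` and its `qas`, and every question lists its `answers` as a `text` plus a character offset `answer_start` into the context. A minimal, hypothetical iterator over that structure (it keeps only the first answer, since dev questions can list several references):

```python
def iter_squad(raw):
    """Hypothetical helper: yield (context, question, answer_text, answer_start)."""
    for article in raw["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                first = qa["answers"][0]  # dev questions may have several answers
                yield context, qa["question"], first["text"], first["answer_start"]
```

Because `answer_start` is a character offset, `context[answer_start:answer_start + len(text)]` recovers the answer text.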

In [8]:
%%px
batch_size = 8
max_len = 384

In [9]:
%%px
train_squad_examples = du.create_squad_examples(raw_train_data, max_len, tokenizer)
x_train, y_train = du.create_inputs_targets(train_squad_examples, shuffle=True, seed=42)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = du.create_squad_examples(raw_eval_data, max_len, tokenizer)
x_eval, y_eval = du.create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

[stdout:0] 86136 training points created.
[stdout:1] 86136 training points created.
[stdout:0] 10331 evaluation points created.
[stdout:1] 10331 evaluation points created.
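Each point packs the context and the question into a single sequence of at most `max_len` tokens. Assuming `dataset_utils` follows the Keras example (not verified here), the layout is `[CLS] context [SEP] question [SEP]`, with `token_type_ids` marking the question segment and `attention_mask` marking real tokens, and pairs that do not fit are skipped rather than truncated. The function name and the toy ids below are hypothetical.

```python
def build_inputs(context_ids, question_ids, max_len):
    """Hypothetical helper: assemble padded BERT inputs for one (context, question) pair.

    Both id lists are assumed to already be wrapped as [CLS] ... [SEP], as
    BertWordPieceTokenizer.encode produces; the question's leading [CLS] is dropped.
    Returns None when the pair does not fit in max_len.
    """
    input_ids = context_ids + question_ids[1:]
    token_type_ids = [0] * len(context_ids) + [1] * len(question_ids[1:])
    attention_mask = [1] * len(input_ids)
    padding = max_len - len(input_ids)
    if padding < 0:
        return None  # too long: skip the example
    # Pad with id 0 ([PAD]) and mask the padding out
    input_ids += [0] * padding
    token_type_ids += [0] * padding
    attention_mask += [0] * padding
    return input_ids, token_type_ids, attention_mask
```

Skipping over-long pairs keeps every training answer inside the input, at the price of dropping some examples.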


In [10]:
%%px
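class SquadDataset(torch.utils.data.Dataset):
    """Wraps x = (input_ids, token_type_ids, attention_mask) and
    y = (start_positions, end_positions); item idx is a tuple of five tensors."""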
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        return (torch.tensor(self.x[0][idx]),
                torch.tensor(self.x[1][idx]),
                torch.tensor(self.x[2][idx]),
                torch.tensor(self.y[0][idx]),
                torch.tensor(self.y[1][idx]))

    def __len__(self):
        return len(self.x[0])

In [11]:
%%px
from pt_distr_env import setup_distr_env

setup_distr_env()
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
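`dist.init_process_group(backend="nccl")` without an `init_method` uses the `env://` rendezvous, which reads `MASTER_ADDR`, `MASTER_PORT`, `RANK` and `WORLD_SIZE` from the environment; presumably `setup_distr_env()` sets these. Its source is not shown, so the helper below is only a hypothetical illustration that derives them from the SLURM variables `SLURM_PROCID` and `SLURM_NTASKS`. The real module may take rank information from MPI instead, and here the master address and port are supplied by the caller.

```python
def distr_env_from_slurm(env, master_addr, master_port=29500):
    """Hypothetical helper: build the variables read by torch's env:// rendezvous.

    env: a mapping such as os.environ inside a SLURM job step (assumption).
    master_addr: hostname of the rank-0 node, supplied by the caller.
    """
    return {
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),    # any free port agreed on by all ranks
        "RANK": env["SLURM_PROCID"],        # global rank of this process
        "WORLD_SIZE": env["SLURM_NTASKS"],  # total number of processes
    }
```

Each process would apply the result with `os.environ.update(...)` before calling `dist.init_process_group`.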

In [12]:
%%px
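train_set = SquadDataset(x_train, y_train)
# shuffle=False: x_train/y_train were already shuffled (seed=42) by
# create_inputs_targets, so each rank takes a fixed shard of that order
train_sampler = DistributedSampler(train_set, num_replicas=world_size,
                                   rank=rank, shuffle=False, seed=42)

# shuffle must stay False in the DataLoader when a sampler is given
train_loader = DataLoader(train_set, batch_size=batch_size,
                          shuffle=False, sampler=train_sampler)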
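`DistributedSampler` hands each rank a disjoint shard of the dataset. With `shuffle=False` and the default `drop_last=False`, the sketch below reproduces how recent PyTorch versions assign indices: pad the index list by wrapping around to a multiple of `num_replicas`, then let rank `r` take every `num_replicas`-th index starting at `r`. It illustrates the behaviour and is not PyTorch's own code; `shard_indices` is a hypothetical name.

```python
import math

def shard_indices(n, num_replicas, rank):
    """Hypothetical helper: dataset indices seen by one rank (no shuffle, no drop_last)."""
    total = math.ceil(n / num_replicas) * num_replicas
    # Wrap-around padding so every rank gets the same number of samples
    indices = [i % n for i in range(total)]
    # Strided split: rank r takes r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total:num_replicas]
```

With 86136 training points and 2 ranks, each rank sees 43068 points per epoch, i.e. 5384 batches of 8 (the last one partial).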

In [13]:
%%px
device = 0  # each rank drives a single GPU: its local device 0
model.to(device)
model = DistributedDataParallel(model, device_ids=[device])
model.train()

[stdout:1] 
nid02092:19446:19446 [0] NCCL INFO Bootstrap : Using [0]ipogif0:148.187.40.61<0>
nid02092:19446:19446 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
nid02092:19446:19446 [0] NCCL INFO NET/IB : No device found.
nid02092:19446:19446 [0] NCCL INFO NET/Socket : Using [0]ipogif0:148.187.40.61<0>
nid02092:19446:19446 [0] NCCL INFO Using network Socket
nid02092:19446:19565 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
nid02092:19446:19565 [0] NCCL INFO Trees [0] -1/-1/-1->1->0|0->1->-1/-1/-1 [1] 0/-1/-1->1->-1|-1->1->0/-1/-1
nid02092:19446:19565 [0] NCCL INFO Setting affinity for GPU 0 to ffffff
nid02092:19446:19565 [0] NCCL INFO Channel 00 : 0[2000] -> 1[2000] [receive] via NET/Socket/0
nid02092:19446:19565 [0] NCCL INFO Channel 00 : 1[2000] -> 0[2000] [send] via NET/Socket/0
nid02092:19446:19565 [0] NCCL INFO Channel 01 : 0[2000] -> 1[2000] [receive] via NET/Socket/0
nid02092:19446:19565 [0] NCCL INFO Channel 01 : 1[2000] ->

[stderr:1] 
libibverbs: Could not locate libibgni (/usr/lib64/libibgni.so.1: undefined symbol: verbs_uninit_context)
[stderr:0] 
libibverbs: Could not locate libibgni (/usr/lib64/libibgni.so.1: undefined symbol: verbs_uninit_context)


Out[0:11]:
DistributedDataParallel(
  (module): BertForQuestionAnswering(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=76

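`DistributedDataParallel` broadcasts rank 0's parameters when the wrapper is built and, during `backward()`, all-reduces the gradients so that every rank ends up holding their average. Each rank therefore applies the same optimizer step, and with `batch_size = 8` on 2 ranks the effective batch size is 16. The pure-Python simulation below (`allreduce_mean` is a hypothetical name) shows only the arithmetic of that averaging; the real implementation works on gradient buckets over NCCL and overlaps communication with the backward pass.

```python
def allreduce_mean(grads_per_rank):
    """Simulate DDP gradient averaging.

    grads_per_rank: one flat list of gradient values per rank.
    Returns what each rank holds afterwards: the element-wise mean over ranks.
    """
    world_size = len(grads_per_rank)
    # All-reduce (sum) across ranks, element by element
    summed = [sum(vals) for vals in zip(*grads_per_rank)]
    averaged = [s / world_size for s in summed]
    # Every rank receives an identical copy
    return [list(averaged) for _ in range(world_size)]
```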

In [14]:
%%px
optim = AdamW(model.parameters(), lr=5e-5)
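AdamW is Adam with decoupled weight decay: the decay shrinks the weights directly instead of being added to the gradient. Below is a single-step NumPy sketch of the update. `lr=5e-5` matches the cell above, while the betas, `eps` and `weight_decay` are illustrative values, not necessarily the defaults of `transformers.AdamW` (which newer Transformers releases deprecate in favour of `torch.optim.AdamW`).

```python
import numpy as np

def adamw_step(p, g, m, v, t, lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """One AdamW update at step t (1-based); returns the new (p, m, v)."""
    b1, b2 = betas
    p = p * (1 - lr * weight_decay)       # decoupled weight decay
    m = b1 * m + (1 - b1) * g             # first-moment estimate
    v = b2 * v + (1 - b2) * g * g         # second-moment estimate
    m_hat = m / (1 - b1 ** t)             # bias correction
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is about ±1, so every parameter moves by roughly `lr` regardless of the size of its gradient.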

In [15]:
%%px
for epoch in range(1):
    for i, batch in enumerate(train_loader):
        if i > 2:  # demo run: stop after three batches
            break

        optim.zero_grad()
        # batch = (input_ids, token_type_ids, attention_mask, start_positions, end_positions)
        outputs = model(input_ids=batch[0].to(device),
                        token_type_ids=batch[1].to(device),
                        attention_mask=batch[2].to(device),
                        start_positions=batch[3].to(device),
                        end_positions=batch[4].to(device)
                       )
        loss = outputs[0]  # the span loss comes first when positions are given
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(i, loss)

[stdout:0] 0 tensor(6.0962, device='cuda:0', grad_fn=<DivBackward0>)
[stdout:1] 0 tensor(5.9524, device='cuda:0', grad_fn=<DivBackward0>)
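When `start_positions` and `end_positions` are passed, `BertForQuestionAnswering` returns as `outputs[0]` the mean of two cross-entropy losses, one over start positions and one over end positions (hence the `DivBackward0` in the printed `grad_fn`). The NumPy sketch below computes that quantity for a single example; the function names are hypothetical. With an untrained QA head the logits are close to uniform over the 384 positions, so the initial loss should sit near ln 384 ≈ 5.95, roughly consistent with the 6.10 and 5.95 printed above.

```python
import numpy as np

def cross_entropy(logits, target):
    """Cross-entropy of one example from raw logits (numerically stable log-softmax)."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[target]

def qa_loss(start_logits, end_logits, start_pos, end_pos):
    """Average of the start- and end-position cross-entropies."""
    return (cross_entropy(start_logits, start_pos)
            + cross_entropy(end_logits, end_pos)) / 2
```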


In [18]:
# %%px
# load the model on gpu
# model.load_state_dict(torch.load('./cache/model_trained_8_nodes'))

In [17]:
# %%px --target 0
# torch.save(model.module.state_dict(), './cache/model_trained_8_nodes_state_dict')

In [21]:
%%px --target 0
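# Load the trained weights on the CPU for evaluation. `model` must be the plain
# BertForQuestionAnswering (as created above, not wrapped in DDP): the saved keys have no 'module.' prefix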
model.load_state_dict(
    torch.load('./cache/model_trained_8_nodes_state_dict',
               map_location=torch.device('cpu'))
)
model.device

Out[0:18]: device(type='cpu')
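The save cell stores `model.module.state_dict()`, i.e. the weights of the model inside the DDP wrapper, whose keys carry no `module.` prefix; that is why they load directly into a plain `BertForQuestionAnswering` on the CPU here. A checkpoint saved from the wrapper itself (`model.state_dict()` on the DDP object) would have every key prefixed with `module.`. A minimal, hypothetical helper to strip that prefix before calling `load_state_dict`:

```python
from collections import OrderedDict

def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with the DDP prefix removed from its keys."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
```

Usage would be `model.load_state_dict(strip_ddp_prefix(torch.load(path, map_location="cpu")))`; keys without the prefix pass through unchanged.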

In [22]:
%%px
model.eval()

Out[0:19]:
BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), 


In [24]:
%%px --target 0
samples = np.random.choice(len(x_eval[0]), 50, replace=False)

eu.EvalUtility(
    (x_eval[0][samples], x_eval[1][samples], x_eval[2][samples]),
    model,
    eval_squad_examples[samples]
).results()

[stdout:0] 
  - from 1870 to 1939              | ref: 1870 to 1939                   | How long was the Summer Theatre in operation?
  - extralegal                     | ref: extra-legal                    | Excessive bureaucratic red tape is one of the reasons for what type of ownership?
  - paramount pictures             | ref: Paramount Pictures             | What company did Eisner become president of when he left ABC in 1976?
  - 300 km                         | ref: 300 km long                    | How long is the Upper Rhine Plain?
  - plague was present somewhere i | ref: the plague was present somewhere in Europe in every year between 1346 and 1671. | What did Biraben say about the plague in Europe?
  - 738 days                       | ref: 738 days                       | How long did Julia Butterfly Hill live in a tree?
  - sybilla of normandy            | ref: Sybilla of Normandy            | Who did Alexander I marry?
  - boat                           | ref: boat         
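`eu.EvalUtility` is not shown, so how it scores predictions is not visible here. For reference, here is a sketch of the standard SQuAD v1.1 metrics: exact match and token-level F1, computed after normalizing both strings (lower-casing, removing punctuation and the articles a/an/the, collapsing whitespace). Under this normalization `extralegal` and `extra-legal` count as an exact match, while `300 km` against `300 km long` scores F1 = 0.8.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)

def normalize_answer(s):
    """Lower-case, strip punctuation and English articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(pred, ref):
    return normalize_answer(pred) == normalize_answer(ref)

def f1_score(pred, ref):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(pred).split()
    ref_tokens = normalize_answer(ref).split()
    num_same = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```

For dev questions with several reference answers, the official evaluation takes the maximum score over all references.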

In [26]:
%ipcluster stop