# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We are going to fine-tune BERT for the text-extraction task with a dataset of questions and answers. The data is composed of a set of questions and corresponding paragraphs that contain the answers. The model is trained to locate the answer in the context by predicting the positions where the answer starts and ends.
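To make the start/end formulation concrete, the following plain-Python sketch shows how a predicted pair of token positions is turned back into answer text. It assumes word-level tokens and made-up logits instead of the WordPiece tokens and real model outputs used below. Like the `ExactMatch` evaluation later in this notebook, it picks the start and end tokens with independent argmaxes and then slices the context using the tokens' character offsets.

```python
import re

context = "Super Bowl 50 was won by the Denver Broncos."
# Word-level tokens with (start, end) character offsets; a simplified
# stand-in for the WordPiece offsets produced by the real tokenizer.
offsets = [m.span() for m in re.finditer(r"\w+|[^\w\s]", context)]

# Hypothetical model scores, one per token
start_logits = [0.1, 0.0, 0.2, 0.0, 0.1, 0.0, 0.3, 4.0, 0.5, 0.0]
end_logits = [0.0, 0.1, 0.0, 0.2, 0.0, 0.1, 0.0, 0.9, 3.5, 0.2]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


start, end = argmax(start_logits), argmax(end_logits)
# Character span from the start of the start token to the end of the end token
answer = context[offsets[start][0]:offsets[end][1]]
print(start, end, answer)  # 7 8 Denver Broncos
```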

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

Here we use [BERT base model (uncased)](https://huggingface.co/bert-base-uncased) and the [BertForQuestionAnswering](https://huggingface.co/transformers/model_doc/bert.html?highlight=bertforquestionanswering#bertforquestionanswering) class from HuggingFace.

In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi


In [4]:
%%px
import numpy as np
import os
import json
import dataset_utils as du
import torch
import torch.distributed as dist
from transformers import BertTokenizer, BertForQuestionAnswering, AdamW
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel

In [5]:
%%px
bert_cache = os.path.join(os.getcwd(), 'cache')

In [6]:
%%px
slow_tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    cache_dir=os.path.join(bert_cache, '_bert-base-uncased-tokenizer')
)
save_path = os.path.join(bert_cache, 'bert-base-uncased-tokenizer')
if not os.path.exists(save_path):
    os.makedirs(save_path)
    slow_tokenizer.save_pretrained(save_path)
    
# Load the fast tokenizer from saved file
tokenizer = BertWordPieceTokenizer(os.path.join(save_path, 'vocab.txt'),
                                   lowercase=True)

In [7]:
%%px
model = BertForQuestionAnswering.from_pretrained(
    "bert-base-uncased",
    cache_dir=os.path.join(bert_cache, 'bert-base-uncased_qa')
)

[stderr:0] 
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[stderr:0] 
Some weights of BertForQuestionAnswering were not initialized from the model checkpoin

In [8]:
%%px
train_path = os.path.join(bert_cache, 'data', 'train-v1.1.json')
eval_path = os.path.join(bert_cache, 'data', 'dev-v1.1.json')
with open(train_path) as f:
    raw_train_data = json.load(f)

with open(eval_path) as f:
    raw_eval_data = json.load(f)

In [9]:
%%px
batch_size = 8
max_len = 384

In [10]:
%%px
train_squad_examples = du.create_squad_examples(raw_train_data, max_len, tokenizer)
x_train, y_train = du.create_inputs_targets(train_squad_examples, shuffle=True, seed=42)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = du.create_squad_examples(raw_eval_data, max_len, tokenizer)
x_eval, y_eval = du.create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

[stdout:1] 86136 training points created.
[stdout:0] 86136 training points created.
[stdout:1] 10331 evaluation points created.
[stdout:0] 10331 evaluation points created.
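The `dataset_utils` module (`du`) is not included in this notebook. Assuming it follows the Keras example linked above, each SQuAD record is turned into fixed-length inputs roughly as in the sketch below: context and question are joined as `[CLS] context [SEP] question [SEP]`, padded to `max_len`, and the answer's character span is mapped to start/end token indices through the tokenizer offsets. Examples that are too long, or whose answer cannot be located, are skipped. For readability the sketch uses a hypothetical word-level `toy_tokenize` and keeps token strings instead of WordPiece ids; `build_example` is likewise an illustrative name, not part of `du`.

```python
import re


def toy_tokenize(text):
    # Hypothetical stand-in for BertWordPieceTokenizer: lowercased word-level
    # tokens together with their (start, end) character offsets.
    return [(m.group().lower(), m.span()) for m in re.finditer(r"\w+|[^\w\s]", text)]


def build_example(context, question, answer_start, answer_text, max_len):
    """Return padded inputs plus start/end token targets, or None if skipped."""
    ctx = toy_tokenize(context)
    qst = toy_tokenize(question)
    # Layout: [CLS] context [SEP] question [SEP]
    tokens = ["[CLS]"] + [t for t, _ in ctx] + ["[SEP]"] + [t for t, _ in qst] + ["[SEP]"]
    # Only context tokens carry real offsets; special tokens get (0, 0)
    offsets = [(0, 0)] + [span for _, span in ctx] + [(0, 0)]
    # Segment 0 for "[CLS] context [SEP]", segment 1 for "question [SEP]"
    token_type_ids = [0] * (len(ctx) + 2) + [1] * (len(qst) + 1)
    attention_mask = [1] * len(tokens)

    padding = max_len - len(tokens)
    if padding < 0:
        return None  # longer than max_len: the example is skipped
    tokens += ["[PAD]"] * padding
    token_type_ids += [0] * padding
    attention_mask += [0] * padding

    # Context tokens whose character range overlaps the answer span
    answer_end = answer_start + len(answer_text)
    answer_tokens = [i for i, (s, e) in enumerate(offsets)
                     if e > s and s < answer_end and e > answer_start]
    if not answer_tokens:
        return None
    return {"tokens": tokens, "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "start_position": answer_tokens[0], "end_position": answer_tokens[-1]}


example = build_example("Super Bowl 50 was won by the Denver Broncos.",
                        "Who won Super Bowl 50?", 29, "Denver Broncos", max_len=24)
print(example["start_position"], example["end_position"])  # 8 9
```

Note that the target indices count the leading `[CLS]`, which is why "denver" is at position 8 rather than 7.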


In [11]:
%%px
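class SquadDataset(torch.utils.data.Dataset):
    """Wraps the model inputs and targets built by `du.create_inputs_targets`.

    x = [input_ids, token_type_ids, attention_mask]
    y = [start_positions, end_positions]
    Each item is returned as a tuple of five tensors in that order.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        return (torch.tensor(self.x[0][idx]),
                torch.tensor(self.x[1][idx]),
                torch.tensor(self.x[2][idx]),
                torch.tensor(self.y[0][idx]),
                torch.tensor(self.y[1][idx]))

    def __len__(self):
        return len(self.x[0])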

In [12]:
%%px
from pt_distr_env import setup_distr_env

setup_distr_env()
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
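`pt_distr_env.setup_distr_env` is a site-specific helper that is not shown here. With the default `env://` initialization, `dist.init_process_group` reads `MASTER_ADDR`, `MASTER_PORT`, `RANK` and `WORLD_SIZE` from the environment, so a helper like this typically just fills those in. The sketch below is hypothetical: it assumes a SLURM job (`SLURM_PROCID`, `SLURM_NTASKS`), takes the master host name as an argument instead of parsing the node list, and uses an arbitrary port.

```python
def sketch_distr_env(environ, master_addr, master_port=29500):
    # Hypothetical: map SLURM task variables onto the env:// variables that
    # torch.distributed.init_process_group reads. In a real job the result
    # would be applied with os.environ.update(...).
    return {
        "RANK": environ["SLURM_PROCID"],
        "WORLD_SIZE": environ["SLURM_NTASKS"],
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
    }


# Fake environment for the second of two tasks
fake_env = {"SLURM_PROCID": "1", "SLURM_NTASKS": "2"}
env_vars = sketch_distr_env(fake_env, master_addr="nid03303")
print(env_vars)
```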

In [45]:
%%px
train_set = SquadDataset(x_train, y_train)
# Each rank draws a disjoint, shuffled shard of the training set
train_sampler = DistributedSampler(train_set, num_replicas=world_size,
                                   rank=rank, shuffle=True, seed=42)

# shuffle must stay False here: a sampler and shuffle=True are mutually exclusive
train_loader = DataLoader(train_set, batch_size=batch_size,
                          shuffle=False, sampler=train_sampler)

test_set = SquadDataset(x_eval, y_eval)
test_sampler = DistributedSampler(test_set, num_replicas=world_size,
                                  rank=rank, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size,
                         shuffle=False, sampler=test_sampler)
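How `DistributedSampler` splits the data can be sketched in plain Python. This mirrors its behaviour with `shuffle=False` and the default `drop_last=False`: the index list is padded by wrapping around so that every rank gets the same number of samples, then dealt out round-robin. With `shuffle=True`, the indices are first permuted by a generator seeded with `seed + epoch`. `partition_indices` is an illustrative name, not a PyTorch function.

```python
import math


def partition_indices(dataset_len, num_replicas, rank):
    # Every rank gets ceil(N / replicas) samples; pad by repeating indices
    # from the beginning, then take every num_replicas-th index.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    return indices[rank:total_size:num_replicas]


print(partition_indices(5, 2, 0), partition_indices(5, 2, 1))  # [0, 2, 4] [1, 3, 0]

# With the 86136 training points above, 2 ranks and batches of 8:
n_train = len(partition_indices(86136, 2, 0))
print(n_train, math.ceil(n_train / 8))  # 43068 5384
```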

In [13]:
%%px
for i, batch in enumerate(train_loader):
    # Peek at the first two batches only
    if i >= 2:
        break
    print(batch[0].shape, batch[0][0][10])

[stdout:1] 
torch.Size([8, 384]) tensor(2348)
torch.Size([8, 384]) tensor(2038)
[stdout:0] 
torch.Size([8, 384]) tensor(10232)
torch.Size([8, 384]) tensor(2189)


In [15]:
%%px
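# Assumes one GPU per node, so every rank uses local device 0
device = 0
model.to(device)
# Wrap the model so gradients are averaged across ranks during backward()
model = DistributedDataParallel(model, device_ids=[device])
model.train()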

[stdout:0] 
nid03303:9716:9716 [0] NCCL INFO Bootstrap : Using [0]ipogif0:148.187.45.2<0>
nid03303:9716:9716 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
nid03303:9716:9716 [0] NCCL INFO NET/IB : No device found.
nid03303:9716:9716 [0] NCCL INFO NET/Socket : Using [0]ipogif0:148.187.45.2<0>
nid03303:9716:9716 [0] NCCL INFO Using network Socket
NCCL version 2.7.8+cuda11.1


[stderr:0] 
libibverbs: Could not locate libibgni (/usr/lib64/libibgni.so.1: undefined symbol: verbs_uninit_context)


[stdout:0] 
nid03303:9716:9829 [0] NCCL INFO Channel 00/02 :    0   1
nid03303:9716:9829 [0] NCCL INFO Channel 01/02 :    0   1
nid03303:9716:9829 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
nid03303:9716:9829 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1|-1->0->1/-1/-1 [1] -1/-1/-1->0->1|1->0->-1/-1/-1
nid03303:9716:9829 [0] NCCL INFO Setting affinity for GPU 0 to ffffff
nid03303:9716:9829 [0] NCCL INFO Channel 00 : 1[2000] -> 0[2000] [receive] via NET/Socket/0
nid03303:9716:9829 [0] NCCL INFO Channel 00 : 0[2000] -> 1[2000] [send] via NET/Socket/0
nid03303:9716:9829 [0] NCCL INFO Channel 01 : 1[2000] -> 0[2000] [receive] via NET/Socket/0
nid03303:9716:9829 [0] NCCL INFO Channel 01 : 0[2000] -> 1[2000] [send] via NET/Socket/0
nid03303:9716:9829 [0] NCCL INFO 2 coll channels, 2 p2p channels, 1 p2p channels per peer
nid03303:9716:9829 [0] NCCL INFO comm 0x2aad88002dc0 rank 0 nranks 2 cudaDev 0 busId 2000 - Init COMPLETE
nid03303:9716:9716 [0] NCCL INFO Launch mode Parallel

[stderr:1] 
libibverbs: Could not locate libibgni (/usr/lib64/libibgni.so.1: undefined symbol: verbs_uninit_context)


Out[0:13]:
DistributedDataParallel(
  (module): BertForQuestionAnswering(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=76

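`DistributedDataParallel` keeps one model replica per process and, during `backward()`, all-reduces the gradients so that every replica takes the same optimizer step. The toy function below only simulates the arithmetic of that step on plain lists (sum across ranks, divide by the world size); the real implementation runs NCCL all-reduces on gradient buckets while the backward pass is still in progress. `allreduce_mean` is an illustrative name.

```python
def allreduce_mean(per_rank_grads):
    # Element-wise mean over ranks; afterwards every rank holds the same values.
    world_size = len(per_rank_grads)
    mean = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return [list(mean) for _ in range(world_size)]


# Two ranks with different local gradients for a 2-parameter model
print(allreduce_mean([[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 3.0], [2.0, 3.0]]
```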

In [16]:
%%px
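# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
# weight_decay are set to the transformers.AdamW defaults
optim = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)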
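For reference, one AdamW step for a single scalar parameter is sketched below, following the update used by `torch.optim.AdamW`: decoupled weight decay, exponential moving averages of the gradient and its square, and bias correction. `adamw_step` is an illustrative helper; the defaults mirror the optimizer settings in this cell.

```python
import math


def adamw_step(p, g, m, v, t, lr=5e-5, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
    """One AdamW update for a scalar parameter p with gradient g (t starts at 1)."""
    b1, b2 = betas
    p = p - lr * weight_decay * p          # decoupled weight decay
    m = b1 * m + (1 - b1) * g              # first-moment estimate
    v = b2 * v + (1 - b2) * g * g          # second-moment estimate
    m_hat = m / (1 - b1 ** t)              # bias correction
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# On the first step the update is about lr * sign(g)
p, m, v = adamw_step(1.0, 0.5, 0.0, 0.0, t=1, lr=0.1)
print(round(p, 6))  # 0.9
```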

In [17]:
%%px
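for epoch in range(1):
    # Reseed the sampler each epoch so the shards are reshuffled
    train_sampler.set_epoch(epoch)
    for i, batch in enumerate(train_loader):
        # Short demo run: only the first 101 batches (i = 0..100)
        if i > 100:
            break

        optim.zero_grad()
        # Batch layout: input_ids, token_type_ids, attention_mask,
        # start_positions, end_positions
        outputs = model(input_ids=batch[0].to(device),
                        token_type_ids=batch[1].to(device),
                        attention_mask=batch[2].to(device),
                        start_positions=batch[3].to(device),
                        end_positions=batch[4].to(device))
        loss = outputs[0]  # the loss comes first when positions are given
        loss.backward()    # DDP all-reduces the gradients here
        optim.step()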
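When `start_positions` and `end_positions` are passed, the loss returned by `BertForQuestionAnswering` is the mean of two cross-entropy losses: one over the start logits and one over the end logits, each treating the sequence positions as classes. A plain-Python sketch for a single example (`cross_entropy` and `qa_loss` are illustrative names):

```python
import math


def cross_entropy(logits, target):
    # -log softmax(logits)[target], computed with the log-sum-exp trick
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def qa_loss(start_logits, end_logits, start_pos, end_pos):
    # Average of the start-position and end-position losses
    return (cross_entropy(start_logits, start_pos)
            + cross_entropy(end_logits, end_pos)) / 2


# With uniform logits over 4 positions the loss is log(4) for any target
print(qa_loss([0.0] * 4, [0.0] * 4, 1, 2))  # about 1.386
```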

In [18]:
%%px
model.eval()

Out[0:16]:
DistributedDataParallel(
  (module): BertForQuestionAnswering(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=76


In [21]:
%%px
import string
import re


def normalize_text(text):
    text = text.lower()

    # Remove punctuations
    exclude = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in exclude)

    # Remove articles
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    text = re.sub(regex, " ", text)

    # Remove extra white space
    text = " ".join(text.split())
    return text


class ExactMatch():
    """
    Each `SquadExample` object contains the character-level offsets for each token
    in its input paragraph. We use them to get back the span of text corresponding
    to the tokens between our predicted start and end tokens.
    All the ground-truth answers are also present in each `SquadExample` object.
    We calculate the percentage of data points where the span of text obtained
    from model predictions matches one of the ground-truth answers.
    """

    def __init__(self, x_eval, y_eval, model, squad_examples):
        # y_eval is not used: the reference answers come from squad_examples
        self.model = model
        self.squad_examples = squad_examples
        self.input_ids = x_eval[0].to(device)
        self.token_type_ids = x_eval[1].to(device)
        self.attention_mask = x_eval[2].to(device)
        self.batch_size = self.input_ids.size(0)

    def score(self, logs=None):
        # Inference only: skip building the autograd graph
        with torch.no_grad():
            outputs = self.model(input_ids=self.input_ids,
                                 token_type_ids=self.token_type_ids,
                                 attention_mask=self.attention_mask)
        pred_start = outputs.start_logits.cpu().numpy()
        pred_end = outputs.end_logits.cpu().numpy()
        count = 0
        # Row k of the inputs must correspond to the k-th non-skipped example
        eval_examples_no_skip = [e for e in self.squad_examples if not e.skip]
        for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start)
            end = np.argmax(end)
            # Predicted start lies outside the context: counted as a miss
            if start >= len(offsets):
                continue

            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                pred_ans = squad_eg.context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(ans)
                                   for ans in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1

            print(f'  - {normalized_pred_ans:25.25s} |'
                  f' ref: {squad_eg.answer_text:30s} |'
                  f' {squad_eg.question}')

        acc = count / self.batch_size
        return acc
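As a self-contained illustration of the metric computed by `ExactMatch`, here is the same normalization applied to plain strings, without the model or the offset bookkeeping. `exact_match` is an illustrative helper name; a prediction counts as correct if it equals any of its normalized reference answers.

```python
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_text(text):
    # Same steps as above: lowercase, drop punctuation and articles,
    # collapse whitespace
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(predictions, references):
    hits = sum(normalize_text(p) in {normalize_text(r) for r in refs}
               for p, refs in zip(predictions, references))
    return hits / len(predictions)


print(exact_match(["the Denver Broncos!", "golden"],
                  [["Denver Broncos"], ["gold"]]))  # 0.5
```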

In [56]:
%%px
for i, eval_batch in enumerate(test_loader):
    # Peek at the first three batches only
    if i >= 3:
        break
    print(eval_batch[0].shape, eval_batch[0][0][10])

[stdout:1] 
torch.Size([8, 384]) tensor(4802)
torch.Size([8, 384]) tensor(5250)
torch.Size([8, 384]) tensor(2705)
[stdout:0] 
torch.Size([8, 384]) tensor(20739)
torch.Size([8, 384]) tensor(13345)
torch.Size([8, 384]) tensor(2011)


In [33]:
# %%px
# eval_batch = next(iter(test_loader))

In [57]:
%%px
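# Rows of x_eval follow the order of the non-skipped eval examples, so pick random
# rows and the matching SquadExample objects together (a shuffled batch would not match)
eval_examples_no_skip = [e for e in eval_squad_examples if not e.skip]
samples = np.random.choice(len(eval_examples_no_skip), batch_size, replace=False)

em = ExactMatch((torch.as_tensor(x_eval[0][samples]),
                 torch.as_tensor(x_eval[1][samples]),
                 torch.as_tensor(x_eval[2][samples])),
                (torch.as_tensor(y_eval[0][samples]),
                 torch.as_tensor(y_eval[1][samples])),
                model,
                [eval_examples_no_skip[i] for i in samples])
em.score()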

[stdout:0] 
  -                           | ref: Denver Broncos                 | Which NFL team won Super Bowl 50?
  - american football game to | ref: gold                           | What color was used to emphasize the 50th anniversary of the Super Bowl?
  - nfl for                   | ref: "golden anniversary"           | What was the theme of Super Bowl 50?
  - golden                    | ref: Carolina Panthers              | Which NFL team represented the NFC at Super Bowl 50?
  - super bowl 50 was america | ref: Denver Broncos                 | Which NFL team represented the AFC at Super Bowl 50?
[stdout:1] 
  - champion denver broncos d | ref: Santa Clara, California        | Where did Super Bowl 50 take place?
  - game would have been know | ref: Denver Broncos                 | Which NFL team won Super Bowl 50?
  - game with roman           | ref: gold                           | What color was used to emphasize the 50th anniversary of the Super Bowl?
  - their third super b

Out[0:55]: 0.0

Out[1:55]: 0.0

In [58]:
%%px
torch.cuda.empty_cache()

In [60]:
%ipcluster stop

IPCluster not running.
