# Fine-tuning a BERT model for text extraction with the SQuAD dataset

We are going to fine-tune [BERT implemented by HuggingFace](https://huggingface.co/bert-base-uncased) for the text-extraction task on the [SQuAD (The Stanford Question Answering Dataset)](https://rajpurkar.github.io/SQuAD-explorer/) dataset of questions and answers.
The data is composed of a set of questions, each paired with a context paragraph that contains the answer.
The model will be trained to locate the answer in the context by predicting the token positions where the answer starts and ends.
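As an illustration of that last step, here is a minimal plain-Python sketch of how per-token start and end scores can be turned into an answer span. It assumes the model emits one start score and one end score per token; the `max_answer_len` limit of 30 tokens and the name `best_answer_span` are illustrative assumptions. The Keras example this notebook is based on simply takes the argmax of each score list independently, whereas this sketch also enforces `start <= end`.

```python
from typing import List, Tuple


def best_answer_span(start_logits: List[float], end_logits: List[float],
                     max_answer_len: int = 30) -> Tuple[int, int]:
    """Return the (start, end) token indices with the highest start + end score,
    subject to start <= end < start + max_answer_len."""
    best_score = float("-inf")
    best_span = (0, 0)
    for s, s_score in enumerate(start_logits):
        # Only consider end positions at or after the start, within the length limit.
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best_span = (s, e)
    return best_span


# Toy scores for a 5-token input.
start_scores = [0.1, 2.0, 0.3, 0.2, -1.0]
end_scores = [0.0, 0.1, 0.5, 3.0, 0.2]
print(best_answer_span(start_scores, end_scores))  # (1, 3)
```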

In this notebook, the training is distributed across multiple GPUs with PyTorch's `DistributedDataParallel` (DDP).

This notebook is based on [BERT (from HuggingFace Transformers) for Text Extraction](https://keras.io/examples/nlp/text_extraction_with_bert/).

More info:
- [Glossary - HuggingFace docs](https://huggingface.co/transformers/glossary.html#model-inputs)
- [BERT NLP — How To Build a Question Answering Bot](https://towardsdatascience.com/bert-nlp-how-to-build-a-question-answering-bot-98b1d1594d7b)

In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2

100%|██████████| 2/2 [00:06<00:00,  3.20s/engine]


In [3]:
%pxconfig --progress-after -1

In [4]:
%%px
import os
import utility.data_processing as dpp
import utility.testing as testing
import torch
import torch.distributed as dist
from datasets import load_dataset, load_metric
from datetime import datetime
from transformers import BertTokenizer, BertForQuestionAnswering
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel

In [5]:
%%px
from datasets.utils import disable_progress_bar
from datasets import disable_caching


disable_progress_bar()
disable_caching()

In [6]:
%%px
hf_model = 'bert-base-uncased'
bert_cache = os.path.join(os.getcwd(), 'cache')

In [7]:
%%px
slow_tokenizer = BertTokenizer.from_pretrained(
    hf_model,
    cache_dir=os.path.join(bert_cache, f'_{hf_model}-tokenizer')
)
save_path = os.path.join(bert_cache, f'{hf_model}-tokenizer')
if not os.path.exists(save_path):
    os.makedirs(save_path)
    slow_tokenizer.save_pretrained(save_path)
    
# Load the fast tokenizer from saved file
tokenizer = BertWordPieceTokenizer(os.path.join(save_path, 'vocab.txt'),
                                   lowercase=True)

In [8]:
%%px
model = BertForQuestionAnswering.from_pretrained(
    hf_model,
    cache_dir=os.path.join(bert_cache, f'{hf_model}_qa')
)

[stderr:0] Downloading: 100%|██████████| 570/570 [00:00<00:00, 260kB/s]
Downloading: 100%|██████████| 420M/420M [00:11<00:00, 37.1MB/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSeque

[stderr:1] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-bas

In [9]:
%%px
from pt_distr_env import DistributedEnviron

distr_env = DistributedEnviron()
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

[stderr:0] [W socket.cpp:401] [c10d] The server socket cannot be initialized on [::]:39591 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:558] [c10d] The client socket cannot be initialized to connect to [nid04676]:39591 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:558] [c10d] The client socket cannot be initialized to connect to [nid04676]:39591 (errno: 97 - Address family not supported by protocol).


[stderr:1] [W socket.cpp:558] [c10d] The client socket cannot be initialized to connect to [nid04676]:39591 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:558] [c10d] The client socket cannot be initialized to connect to [nid04676]:39591 (errno: 97 - Address family not supported by protocol).


In [10]:
%%px
hf_dataset = load_dataset('squad')





In [11]:
%%px
max_len = 384
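`max_len` is the fixed length every question/context pair is padded to. The plain-Python sketch below shows one way to assemble such an input, following the layout used in the Keras example this notebook is based on (`[CLS] context [SEP] question [SEP]`). `build_bert_qa_features` is a hypothetical name, and the local helper `dpp.process_squad_item_batched` may do this differently. The token ids in the usage example are toy values, apart from 101, 102 and 0, which are `[CLS]`, `[SEP]` and `[PAD]` in the `bert-base-uncased` vocabulary.

```python
from typing import Dict, List, Optional


def build_bert_qa_features(context_ids: List[int], question_ids: List[int],
                           max_len: int) -> Optional[Dict[str, List[int]]]:
    """Assemble padded BERT inputs for one (context, question) pair, or None if too long."""
    # Both id lists are assumed to be wrapped as [CLS] ... [SEP]; drop the
    # question's leading [CLS] so the pair reads [CLS] context [SEP] question [SEP].
    input_ids = context_ids + question_ids[1:]
    # Segment A (context) gets type 0, segment B (question) gets type 1.
    token_type_ids = [0] * len(context_ids) + [1] * len(question_ids[1:])
    # Real tokens are attended to (1); padding will be masked out (0).
    attention_mask = [1] * len(input_ids)

    padding_length = max_len - len(input_ids)
    if padding_length < 0:
        # Too long for the model; the Keras example skips such samples.
        return None
    # Pad with id 0 ([PAD] in the bert-base-uncased vocabulary).
    input_ids += [0] * padding_length
    token_type_ids += [0] * padding_length
    attention_mask += [0] * padding_length
    return {"input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask}


features = build_bert_qa_features([101, 7, 8, 102], [101, 9, 102], max_len=8)
print(features["input_ids"])       # [101, 7, 8, 102, 9, 102, 0, 0]
print(features["token_type_ids"])  # [0, 0, 0, 0, 1, 1, 0, 0]
print(features["attention_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
```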

In [12]:
%%px
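# flatten() turns the nested `answers` field into the `answers.text` and
# `answers.answer_start` columns; compute it once and reuse the result
flat_dataset = hf_dataset.flatten()
processed_dataset = flat_dataset.map(
    lambda example: dpp.process_squad_item_batched(example, max_len, tokenizer),
    remove_columns=flat_dataset['train'].column_names,
    batched=True,
    num_proc=12
)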
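SQuAD gives each answer as text plus a character offset into the context (the `answers.text` and `answers.answer_start` columns after `flatten()`), whereas the model is trained on token indices. A minimal sketch of that mapping, assuming per-token `(start, end)` character offsets such as those in the fast tokenizer's `Encoding.offsets`, where special tokens get `(0, 0)`. `char_to_token_span` is a hypothetical helper, not part of `dpp`, and the offsets in the example are written by hand.

```python
from typing import List, Optional, Tuple


def char_to_token_span(offsets: List[Tuple[int, int]], answer_start: int,
                       answer_text: str) -> Optional[Tuple[int, int]]:
    """Map a character-level answer span to (first, last) token indices."""
    answer_end = answer_start + len(answer_text)  # exclusive end character
    # A token belongs to the answer if its [start, end) character range overlaps
    # the answer's range; `e > s` skips special tokens whose offsets are (0, 0).
    hits = [i for i, (s, e) in enumerate(offsets)
            if e > s and s < answer_end and e > answer_start]
    if not hits:
        return None
    return hits[0], hits[-1]


context = "the cat sat on the mat"
# Hand-written offsets for: [CLS] the cat sat on the mat [SEP]
offsets = [(0, 0), (0, 3), (4, 7), (8, 11), (12, 14), (15, 18), (19, 22), (0, 0)]
answer = "the mat"
print(char_to_token_span(offsets, context.index(answer), answer))  # (5, 6)
```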

In [13]:
%%px
train_dataset = processed_dataset["train"]
train_dataset.set_format(type='torch')

eval_dataset = processed_dataset["validation"]
eval_dataset.set_format(type='torch')

In [14]:
%%px
per_device_train_batch_size = 16
per_device_eval_batch_size = 1

train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
                                   rank=rank, shuffle=False, seed=42)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,  # sampler option is mutually exclusive with shuffle
    batch_size=per_device_train_batch_size,
    sampler=train_sampler
)

eval_dataloader = DataLoader(
    eval_dataset,
    shuffle=False,
    batch_size=per_device_eval_batch_size
)
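With `shuffle=False`, `DistributedSampler` pads the index list so it divides evenly among the replicas and then gives each rank every `world_size`-th index, starting at its own rank. The sketch below re-creates that partitioning in plain Python to illustrate it; it is not PyTorch's code, and `shard_indices` is a hypothetical name. With `per_device_train_batch_size = 16` and two engines, each optimizer step therefore sees 32 examples in total. If the sampler is switched to `shuffle=True`, `train_sampler.set_epoch(epoch)` should be called at the start of every epoch so that the shuffle order changes between epochs.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices one rank would see, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every rank gets the same count.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


print(shard_indices(5, num_replicas=2, rank=0))  # [0, 2, 4]
print(shard_indices(5, num_replicas=2, rank=1))  # [1, 3, 0]
```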

In [15]:
%%px
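# each engine is assumed to see a single GPU (one engine per node),
# so its local CUDA device index is 0
device = 0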
model.to(device)
model = DistributedDataParallel(model, device_ids=[device])
model.train()

model.training

[stderr:0] libibverbs: Could not locate libibgni (/usr/lib64/libibgni.so.1: undefined symbol: verbs_uninit_context)


[0;31mOut[0:12]: [0mTrue

[stderr:1] libibverbs: Could not locate libibgni (/usr/lib64/libibgni.so.1: undefined symbol: verbs_uninit_context)


[0;31mOut[1:12]: [0mTrue

In [16]:
%%px
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)
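`torch.optim.AdamW` applies weight decay directly to the parameters instead of folding it into the gradient as L2 regularization. The scalar sketch below follows the update rule given in the PyTorch documentation, using its defaults (`betas=(0.9, 0.999)`, `eps=1e-8`, `weight_decay=0.01`) and the learning rate chosen above; `adamw_step` is a hypothetical name, and the sketch ignores the `amsgrad` option.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=5e-5, betas=(0.9, 0.999),
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter; t is the 1-based step count."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the parameter itself, not the gradient.
    theta = theta - lr * weight_decay * theta
    # Exponential moving averages of the gradient and of its square.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialised moment estimates.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


theta, m, v = adamw_step(1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(round(theta, 7))  # 0.9999495
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is about 1, so the parameter moves by roughly `lr` regardless of the gradient's scale.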

In [17]:
%%px
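for epoch in range(1):
    # only has an effect when the sampler shuffles; harmless with shuffle=False
    train_sampler.set_epoch(epoch)
    for i, batch in enumerate(train_dataloader):
        optim.zero_grad()
        # with start/end positions given, the model returns the loss as its first output
        outputs = model(input_ids=batch['input_ids'].to(device),
                        token_type_ids=batch['token_type_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        start_positions=batch['start_token_idx'].to(device),
                        end_positions=batch['end_token_idx'].to(device))
        loss = outputs[0]
        loss.backward()  # DDP averages the gradients across ranks during backward
        optim.step()

        # stop after 102 batches to keep the demo short; remove to train on the full epoch
        if i > 100:
            break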
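When `start_positions` and `end_positions` are passed, `BertForQuestionAnswering` returns a loss as its first output: the cross-entropy of the start logits against the true start index and of the end logits against the true end index, averaged. A single-example plain-Python sketch of that loss (in the model, each cross-entropy is also averaged over the batch; `qa_loss` is a hypothetical name):

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log-softmax probability of `target`."""
    # Subtract the max before exponentiating for numerical stability.
    mx = max(logits)
    log_sum_exp = mx + math.log(sum(math.exp(x - mx) for x in logits))
    return log_sum_exp - logits[target]


def qa_loss(start_logits: List[float], end_logits: List[float],
            start_pos: int, end_pos: int) -> float:
    """Average of the start-position and end-position cross-entropy terms."""
    return (cross_entropy(start_logits, start_pos)
            + cross_entropy(end_logits, end_pos)) / 2


# With uniform logits over 4 tokens, each term is log(4).
print(qa_loss([0.0] * 4, [0.0] * 4, 1, 2) == math.log(4))
```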

## Saving the model

In [18]:
%%px --target 0
model_hash = datetime.now().strftime("%Y-%m-%d-%H%M%S")
model_path_name = f'./cache/model_trained_pytorch_{model_hash}'

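# Save the model's state_dict. `model` is now wrapped in DDP, so save the
# underlying module's state_dict (`model.module.state_dict()`): its keys carry
# no `module.` prefix, so it can later be loaded into a plain, non-DDP model
# on any number of nodes.
torch.save(model.module.state_dict(), model_path_name)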
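The keys of a DDP model's `state_dict()` carry a `module.` prefix because DDP stores the wrapped model as its `module` attribute, which is why `model.module.state_dict()` is saved here. If a checkpoint was saved from the wrapper instead, a helper like the hypothetical `strip_ddp_prefix` below can rename the keys before calling `load_state_dict`; the toy dictionary stands in for real tensors.

```python
from collections import OrderedDict


def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with a leading DDP `prefix` removed from each key."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


ddp_style = OrderedDict([
    ("module.bert.embeddings.word_embeddings.weight", 0.0),
    ("module.qa_outputs.weight", 1.0),
])
print(list(strip_ddp_prefix(ddp_style)))
# ['bert.embeddings.word_embeddings.weight', 'qa_outputs.weight']
```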

In [19]:
%%px --target 0
# create a fresh model instance to load the saved weights on the CPU,
# since `model` (wrapped in DDP) lives on the GPU
model_cpu = BertForQuestionAnswering.from_pretrained(
    hf_model,
    cache_dir=os.path.join(bert_cache, f'{hf_model}_qa')
)

# load the model on cpu
model_cpu.load_state_dict(
    torch.load(model_path_name,
               map_location=torch.device('cpu'))
)

# load the model on gpu
# model.load_state_dict(torch.load(model_path_name))

[stderr:0] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-bas

[0;31mOut[0:16]: [0m<All keys matched successfully>

In [20]:
%%px --target 0
model.eval()

model.training

[0;31mOut[0:17]: [0mFalse

In [21]:
%%px --target 0
squad_example_objects = []
for item in hf_dataset['validation'].flatten():
    squad_examples = dpp.squad_examples_from_dataset(item, max_len, tokenizer)
    try:
        squad_example_objects.extend(squad_examples)
    except TypeError:
        squad_example_objects.append(squad_examples)
        
assert len(eval_dataset) == len(squad_example_objects)

In [22]:
%%px --target 0

start_sample = 0
num_test_samples = 10
# evaluate exactly `num_test_samples` examples, starting at index `start_sample`
for i, eval_batch in enumerate(eval_dataloader):
    if i < start_sample:
        continue
    if i >= start_sample + num_test_samples:
        break
    testing.EvalUtility(eval_batch, [squad_example_objects[i]], model).results()
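SQuAD v1.1 scores a prediction by normalizing both strings (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace) and then computing exact match and token-level F1, keeping the best score over all reference answers. A plain-Python sketch of that scoring, modeled on the normalization in the official v1.1 evaluation script; `testing.EvalUtility` may report something different, and the `squad` metric available through the imported `load_metric` is documented as wrapping the official script.

```python
import re
import string
from collections import Counter
from typing import Callable, List


def normalize_answer(s: str) -> str:
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction: str, ground_truth: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection counts tokens shared by prediction and reference.
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_references(metric: Callable, prediction: str, references: List[str]) -> float:
    """SQuAD keeps the best score across all annotated answers."""
    return max(float(metric(prediction, ref)) for ref in references)


print(exact_match("The Eiffel Tower!", "eiffel tower"))      # True
print(f1_score("Denver Broncos", "the Denver Broncos team"))  # 0.8
```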


In [23]:
%ipcluster stop

IPCluster stopped.
