# Text Extraction with BERT

**Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>
**Date created:** 2020/05/23<br>
**Last modified:** 2020/05/23<br>
**Description:** Fine tune pretrained BERT from HuggingFace Transformers on SQuAD.

## Introduction

This demonstration uses SQuAD (Stanford Question-Answering Dataset).
In SQuAD, an input consists of a question, and a paragraph for context.
The goal is to find the span of text in the paragraph that answers the question.
We evaluate our performance on this data with the "Exact Match" metric,
which measures the percentage of predictions that exactly match any one of the
ground-truth answers.

We fine-tune a BERT model to perform this task as follows:

1. Feed the context and the question as inputs to BERT.
2. Take two vectors S and T with dimensions equal to that of
   hidden states in BERT.
3. Compute the probability of each token being the start and end of
   the answer span. The probability of a token being the start of
   the answer is given by a dot product between S and the representation
   of the token in the last layer of BERT, followed by a softmax over all tokens.
   The probability of a token being the end of the answer is computed
   similarly with the vector T.
4. Fine-tune BERT and learn S and T along the way.

**References:**

- [BERT](https://arxiv.org/abs/1810.04805)
- [SQuAD](https://arxiv.org/abs/1606.05250)

## Setup

In [26]:
# ---------- 0) Env & Imports ----------
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # ensure tf-keras

import re, json, string
import numpy as np
import tensorflow as tf
import tf_keras as keras
from tf_keras import layers
from transformers import BertTokenizer, TFBertModel, BertConfig
from tokenizers import BertWordPieceTokenizer

# ---------- 1) Config ----------
max_len = 128

## Set-up BERT tokenizer

In [27]:
# ---------- 2) Tokenizer (save slow -> load fast) ----------
slow_tok = BertTokenizer.from_pretrained("bert-base-uncased")
save_path = "bert_base_uncased"
os.makedirs(save_path, exist_ok=True)
slow_tok.save_pretrained(save_path)
tokenizer = BertWordPieceTokenizer(f"{save_path}/vocab.txt", lowercase=True)

## Load the data

In [28]:
# ---------- 3) Load SQuAD v1.1 ----------
train_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
eval_url  = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
train_path = keras.utils.get_file("train.json", train_url)
eval_path  = keras.utils.get_file("eval.json",  eval_url)

with open(train_path) as f: raw_train = json.load(f)
with open(eval_path)  as f: raw_eval  = json.load(f)

## Preprocess the data

1. Go through the JSON file and store every record as a `SquadExample` object.
2. Go through each `SquadExample` and create `x_train, y_train, x_eval, y_eval`.

In [29]:
# ---------- 4) Example holder ----------
class SquadExample:
    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        self.question = question
        self.context = context
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        self.all_answers = all_answers
        self.skip = False

    def preprocess(self):
        context  = " ".join(str(self.context).split())
        question = " ".join(str(self.question).split())
        answer   = " ".join(str(self.answer_text).split())
        end_char_idx = self.start_char_idx + len(answer)
        if end_char_idx >= len(context):
            self.skip = True; return

        # mark answer chars
        is_in_ans = [0]*len(context)
        for i in range(self.start_char_idx, end_char_idx):
            is_in_ans[i] = 1

        # tokenize context & find token span
        tok_ctx = tokenizer.encode(context)
        ans_tok_idx = [i for i,(s,e) in enumerate(tok_ctx.offsets) if sum(is_in_ans[s:e])>0]
        if not ans_tok_idx:
            self.skip = True; return

        start_token_idx = ans_tok_idx[0]
        end_token_idx   = ans_tok_idx[-1]

        # tokenize question
        tok_q = tokenizer.encode(question)

        # build [context tokens] + [question tokens w/o first special]
        input_ids = tok_ctx.ids + tok_q.ids[1:]
        token_type_ids = [0]*len(tok_ctx.ids) + [1]*len(tok_q.ids[1:])
        attention_mask = [1]*len(input_ids)

        # pad / skip on overflow
        pad = max_len - len(input_ids)
        if pad < 0:
            self.skip = True; return
        if pad > 0:
            input_ids      += [0]*pad
            attention_mask += [0]*pad
            token_type_ids += [0]*pad

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx   = end_token_idx
        self.context_token_to_char = tok_ctx.offsets
        self.context = context
        self.all_answers = self.all_answers  # unchanged

def build_examples(raw):
    out = []
    for item in raw["data"]:
        for para in item["paragraphs"]:
            ctx = para["context"]
            for qa in para["qas"]:
                ans_text = qa["answers"][0]["text"]
                ans_all  = [a["text"] for a in qa["answers"]]
                start    = qa["answers"][0]["answer_start"]
                ex = SquadExample(qa["question"], ctx, start, ans_text, ans_all)
                ex.preprocess()
                out.append(ex)
    return out

def to_arrays(examples):
    buf = {k: [] for k in ["input_ids","token_type_ids","attention_mask","start_token_idx","end_token_idx"]}
    for e in examples:
        if not e.skip:
            buf["input_ids"].append(e.input_ids)
            buf["token_type_ids"].append(e.token_type_ids)
            buf["attention_mask"].append(e.attention_mask)
            buf["start_token_idx"].append(e.start_token_idx)
            buf["end_token_idx"].append(e.end_token_idx)
    for k in buf: buf[k] = np.array(buf[k])
    x = [buf["input_ids"], buf["token_type_ids"], buf["attention_mask"]]
    y = [buf["start_token_idx"], buf["end_token_idx"]]
    return x, y

train_examples = build_examples(raw_train)
eval_examples  = build_examples(raw_eval)
print(f"{len(train_examples)} train examples, {len(eval_examples)} eval examples")

(x_ids_tr, x_seg_tr, x_mask_tr), (y_start_tr, y_end_tr) = to_arrays(train_examples)
(x_ids_ev, x_seg_ev, x_mask_ev), (y_start_ev, y_end_ev) = to_arrays(eval_examples)

# cast to int32 for TF
x_ids_tr  = x_ids_tr.astype("int32");  x_seg_tr  = x_seg_tr.astype("int32");  x_mask_tr  = x_mask_tr.astype("int32")
x_ids_ev  = x_ids_ev.astype("int32");  x_seg_ev  = x_seg_ev.astype("int32");  x_mask_ev  = x_mask_ev.astype("int32")
y_start_tr = y_start_tr.astype("int32"); y_end_tr = y_end_tr.astype("int32")
y_start_ev = y_start_ev.astype("int32"); y_end_ev = y_end_ev.astype("int32")

x_train = {"input_ids": x_ids_tr, "attention_mask": x_mask_tr, "token_type_ids": x_seg_tr}
y_train = [y_start_tr, y_end_tr]
x_eval  = {"input_ids": x_ids_ev, "attention_mask": x_mask_ev, "token_type_ids": x_seg_ev}
y_eval  = [y_start_ev, y_end_ev]

# ---------- 5) Strategy ----------
def get_strategy():
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
        try: tf.config.experimental_connect_to_cluster(resolver)
        except Exception: pass
        tf.tpu.experimental.initialize_tpu_system(resolver)
        print("✅ Using TPU")
        return tf.distribute.TPUStrategy(resolver)
    except Exception as e:
        print(f"TPU not found/usable: {e}")
    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) > 1:
        strat = tf.distribute.MirroredStrategy()
        print(f"✅ Using {strat.num_replicas_in_sync} GPUs (MirroredStrategy)")
        return strat
    elif len(gpus) == 1:
        print("✅ Using single GPU")
        return tf.distribute.OneDeviceStrategy("/GPU:0")
    else:
        print("✅ Using CPU")
        return tf.distribute.OneDeviceStrategy("/CPU:0")

strategy = get_strategy()

87599 train examples, 10570 eval examples
TPU not found/usable: TPUs not found in the cluster. Failed in initialization: No OpKernel was registered to support Op 'ConfigureDistributedTPU' used by {{node ConfigureDistributedTPU}} with these attrs: [tpu_cancellation_closes_chips=2, embedding_config="", tpu_embedding_config="", enable_whole_mesh_compilations=false, is_global_init=false, compilation_failure_closes_chips=false]
Registered devices: [CPU, GPU]
Registered kernels:
  <no registered kernels>

	 [[ConfigureDistributedTPU]] [Op:__inference__tpu_init_fn_68324]
✅ Using single GPU


Create the Question-Answering Model using BERT and Functional API

In [30]:
# ---------- 6) Model ----------
def create_model():
    cfg = BertConfig.from_pretrained("bert-base-uncased", add_pooling_layer=False)
    encoder = TFBertModel.from_pretrained("bert-base-uncased", config=cfg, from_pt=True)

    input_ids      = layers.Input((max_len,), dtype=tf.int32, name="input_ids")
    attention_mask = layers.Input((max_len,), dtype=tf.int32, name="attention_mask")
    token_type_ids = layers.Input((max_len,), dtype=tf.int32, name="token_type_ids")

    enc = encoder(input_ids=input_ids,
                  attention_mask=attention_mask,
                  token_type_ids=token_type_ids)  # let Keras set training/inference
    seq = enc.last_hidden_state  # (B, L, H)

    start_logits = layers.Dense(1, use_bias=False, name="start_logit")(seq)
    end_logits   = layers.Dense(1, use_bias=False, name="end_logit")(seq)
    start_logits = layers.Flatten()(start_logits)
    end_logits   = layers.Flatten()(end_logits)

    start_probs = layers.Activation("softmax", name="start_probs")(start_logits)
    end_probs   = layers.Activation("softmax",  name="end_probs")(end_logits)

    model = keras.Model(
        inputs=[input_ids, attention_mask, token_type_ids],
        outputs=[start_probs, end_probs],
        name="bert_qa"
    )
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    opt  = keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=opt, loss=[loss, loss], run_eagerly=False)
    return model

with strategy.scope():
    model = create_model()

# warm-up to create variables before tf.function graph tracing
_ = model(
    {"input_ids": tf.zeros((1, max_len), tf.int32),
     "attention_mask": tf.ones((1, max_len), tf.int32),
     "token_type_ids": tf.zeros((1, max_len), tf.int32)},
    training=False
)
model.summary()

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already

Model: "bert_qa"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_ids (InputLayer)      [(None, 128)]                0         []                            
                                                                                                  
 attention_mask (InputLayer  [(None, 128)]                0         []                            
 )                                                                                                
                                                                                                  
 token_type_ids (InputLayer  [(None, 128)]                0         []                            
 )                                                                                                
                                                                                            

This code should preferably be run on Google Colab TPU runtime.
With Colab TPUs, each epoch will take 5-6 minutes.

## Create evaluation Callback

This callback will compute the exact match score using the validation data
after every epoch.

In [31]:
# ---------- 7) Exact Match callback ----------
def normalize_text(text):
    text = text.lower()
    text = "".join(ch for ch in text if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

class ExactMatch(keras.callbacks.Callback):
    def __init__(self, x_eval, y_eval, eval_examples):
        self.x_eval = x_eval
        self.y_eval = y_eval
        self.eval_examples = [e for e in eval_examples if not e.skip]

    def on_epoch_end(self, epoch, logs=None):
        pred_start, pred_end = self.model.predict(self.x_eval, verbose=0)
        count = 0
        for i, (ps, pe) in enumerate(zip(pred_start, pred_end)):
            e = self.eval_examples[i]
            offsets = e.context_token_to_char
            s = int(np.argmax(ps))
            t = int(np.argmax(pe))
            if s >= len(offsets): continue
            cs = offsets[s][0]
            if t < len(offsets):
                ce = offsets[t][1]
                pred = e.context[cs:ce]
            else:
                pred = e.context[cs:]
            if normalize_text(pred) in [normalize_text(a) for a in e.all_answers]:
                count += 1
        acc = count / len(self.y_eval[0])
        print(f"\nEpoch {epoch+1}: Exact Match = {acc:.4f}")

exact_match_cb = ExactMatch(x_eval, y_eval, eval_examples)

## Train and Evaluate

In [32]:
# ---------- 8) Train ----------
model.fit(
    x_train,
    y_train,
    epochs=1,              # increase to 2-3+ for better demo results
    batch_size=16,
    verbose=2,
    callbacks=[exact_match_cb],
)




Epoch 1: Exact Match = 0.7165
2623/2623 - 706s - loss: 2.7384 - start_probs_loss: 1.4397 - end_probs_loss: 1.2987 - 706s/epoch - 269ms/step


<tf_keras.src.callbacks.History at 0x7c8d95d424e0>