# Fine-tuning BERT for question answering (finding the answer to a question within a given context)
### The pretrained model is taken from TensorFlow Hub and is then fine-tuned with our own parameters.

The input sequence is: special start token + context string + separator token + question + separator token.

The output is a start score and an end score for every token position; the positions with the highest scores mark the first and last token of the predicted answer.

In [None]:
!pip install -r requirements.txt

In [2]:
#importing the required libraries
import json
import os
import re
import string
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer

======================================== PREPARING DATASET ===================================================

In [3]:
class Sample:
    """One question/context pair and the model inputs derived from it.

    Attributes set by ``__init__``:
        question, context: the raw strings from the dataset.
        start_char_idx: character index in ``context`` where the answer starts
            (``None`` for inference-only samples).
        answer_text: the answer string (``None`` for inference-only samples).
        all_answers: every reference answer text, used for evaluation.
        skip: set to ``True`` by ``preprocess()`` when the sample is unusable.
        start_token_idx, end_token_idx: token positions of the answer,
            ``-1`` until ``preprocess()`` finds them.

    Attributes set by a successful ``preprocess()``:
        input_word_ids, input_type_ids, input_mask: lists of length
            ``max_seq_length``.
        context_token_to_char: per-token ``(start, end)`` character offsets
            into ``context``.
    """

    def __init__(self, question, context, start_char_idx=None, answer_text=None, all_answers=None):
        self.question = question
        self.context = context
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        self.all_answers = all_answers
        self.skip = False
        self.start_token_idx = -1
        self.end_token_idx = -1

    def _answer_char_span(self, context):
        """Return the answer's ``(start, end)`` character span, or ``None`` if it does not fit in ``context``."""
        if self.start_char_idx is None:
            return None
        start = self.start_char_idx
        end = start + len(str(self.answer_text))
        # ``end`` is exclusive, so an answer that finishes on the last
        # character of the context has end == len(context) and is valid.
        if start < 0 or end > len(context):
            return None
        return start, end

    def preprocess(self):
        """Tokenize the sample and build the padded model inputs.

        Reads the module-level ``tokenizer`` and ``max_seq_length``. Sets
        ``self.skip = True`` and returns early when the answer span does not
        fit in the context, when no token overlaps the answer, or when the
        combined sequence is longer than ``max_seq_length``.
        """
        # The context is tokenized exactly as given. ``start_char_idx`` counts
        # characters in the original context and the token offsets are later
        # used to slice ``self.context``; collapsing whitespace first would
        # shift every index after the first collapsed run.
        context = str(self.context)
        question = " ".join(str(self.question).split())

        # Validate the answer span before doing any tokenization work.
        answer_span = None
        if self.answer_text is not None:
            answer_span = self._answer_char_span(context)
            if answer_span is None:
                self.skip = True
                return

        # encode() adds the start token at the beginning and the separator at the end.
        tokenized_context = tokenizer.encode(context)
        tokenized_question = tokenizer.encode(question)

        if answer_span is not None:
            ans_start, ans_end = answer_span
            # A token belongs to the answer when its character range overlaps
            # the answer's character range.
            ans_token_idx = [
                idx
                for idx, (start, end) in enumerate(tokenized_context.offsets)
                if max(start, ans_start) < min(end, ans_end)
            ]
            if not ans_token_idx:
                self.skip = True
                return
            self.start_token_idx = ans_token_idx[0]
            self.end_token_idx = ans_token_idx[-1]

        # Drop the question's leading start token so it is not added twice.
        question_ids = tokenized_question.ids[1:]
        input_ids = tokenized_context.ids + question_ids
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(question_ids)
        attention_mask = [1] * len(input_ids)

        padding_length = max_seq_length - len(input_ids)
        if padding_length < 0:
            # Too long for the model input; the sample is dropped, not truncated.
            self.skip = True
            return
        padding = [0] * padding_length

        self.input_word_ids = input_ids + padding
        self.input_type_ids = token_type_ids + padding
        self.input_mask = attention_mask + padding
        self.context_token_to_char = tokenized_context.offsets

============================================= CREATING INPUT SEQUENCE ==========================================

In [4]:
def create_squad_examples(raw_data):
    """Flatten SQuAD-style JSON into a list of preprocessed ``Sample`` objects.

    ``raw_data["data"]`` is a list of articles; each article has
    ``"paragraphs"``, each paragraph a ``"context"`` and a list of ``"qas"``.
    A question with an ``"answers"`` entry is labelled from its first answer
    (all answer texts are kept for evaluation); a question without one becomes
    an inference-only sample. ``preprocess()`` is called on every sample, and
    skipped samples are still included in the returned list.
    """
    examples = []
    all_paragraphs = (
        paragraph
        for article in raw_data["data"]
        for paragraph in article["paragraphs"]
    )
    for paragraph in all_paragraphs:
        passage = paragraph["context"]
        for qa in paragraph["qas"]:
            query = qa["question"]
            if "answers" not in qa:
                example = Sample(query, passage)
            else:
                first_text = qa["answers"][0]["text"]
                every_text = [answer["text"] for answer in qa["answers"]]
                first_start = qa["answers"][0]["answer_start"]
                example = Sample(query, passage, first_start, first_text, every_text)
            example.preprocess()
            examples.append(example)
    return examples


def create_inputs_targets(squad_examples):
    """Stack the non-skipped samples into model inputs and targets.

    Args:
        squad_examples: iterable of preprocessed samples. Samples whose
            ``skip`` flag is set are left out.

    Returns:
        ``(x, y)`` where ``x`` is ``[input_word_ids, input_mask,
        input_type_ids]`` and ``y`` is ``[start_token_idx, end_token_idx]``,
        each entry a NumPy array with one row/element per kept sample.
    """
    keys = (
        "input_word_ids",
        "input_type_ids",
        "input_mask",
        "start_token_idx",
        "end_token_idx",
    )
    kept = [example for example in squad_examples if not example.skip]
    arrays = {
        key: np.array([getattr(example, key) for example in kept])
        for key in keys
    }
    # The order of x matches the order of the model's Input layers.
    x = [arrays["input_word_ids"], arrays["input_mask"], arrays["input_type_ids"]]
    y = [arrays["start_token_idx"], arrays["end_token_idx"]]
    return x, y

============================================= TRAINING ======================================================

In [5]:
class ValidationCallback(keras.callbacks.Callback):
    """Print the exact-match score on the evaluation set after every epoch.

    Args:
        x_eval: model inputs for the evaluation set.
        y_eval: evaluation targets; ``len(y_eval[0])`` is used as the number
            of evaluated samples.
        eval_examples: the samples that ``x_eval`` was built from. When
            omitted, the module-level ``eval_squad_examples`` is used.
    """

    # Built once: translation table that deletes ASCII punctuation, and the
    # pattern matching the English articles as whole words.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)
    _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)

    def normalize_text(self, text):
        """Lowercase ``text``, strip punctuation and articles, and collapse whitespace."""
        text = text.lower()
        # remove punctuation
        text = text.translate(self._PUNCTUATION_TABLE)
        # remove articles
        text = self._ARTICLES.sub(" ", text)
        # remove redundant whitespace
        return " ".join(text.split())

    def __init__(self, x_eval, y_eval, eval_examples=None):
        super().__init__()
        self.x_eval = x_eval
        self.y_eval = y_eval
        self.eval_examples = eval_examples

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the evaluation inputs and print the exact-match score."""
        pred_start, pred_end = self.model.predict(self.x_eval)
        examples = self.eval_examples if self.eval_examples is not None else eval_squad_examples
        # Skipped samples are not part of x_eval, so drop them here as well
        # to keep predictions and samples aligned.
        kept_examples = [example for example in examples if not example.skip]

        count = 0
        for squad_eg, start_scores, end_scores in zip(kept_examples, pred_start, pred_end):
            # Token offsets map a predicted token position back to a character
            # position in the context.
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start_scores)
            end = np.argmax(end_scores)
            if start >= len(offsets):
                # Predicted start lies outside the context tokens; counts as a miss.
                continue
            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                pred_ans = squad_eg.context[pred_char_start:]
            normalized_pred_ans = self.normalize_text(pred_ans)
            normalized_true_ans = [self.normalize_text(answer) for answer in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1

        total = len(self.y_eval[0])
        # Guard against an empty evaluation set instead of dividing by zero.
        acc = count / total if total else 0.0
        print(f"\nepoch={epoch + 1}, exact match score={acc:.2f}")

============================================= FETCHING DATASET ================================================

In [6]:
# Fetch the SQuAD v1.1 train and dev sets; get_file returns the local path of each file.
train_path = keras.utils.get_file("train.json", "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json")
eval_path = keras.utils.get_file("eval.json", "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json")
# Read the files explicitly as UTF-8 so parsing does not depend on the
# platform's default text encoding.
with open(train_path, encoding="utf-8") as f:
    raw_train_data = json.load(f)
with open(eval_path, encoding="utf-8") as f:
    raw_eval_data = json.load(f)

Downloading data from https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
Downloading data from https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json


In [None]:
# Maximum number of tokens in one model input (context + question, including
# the special tokens). Samples that tokenize to more than this are skipped
# during preprocessing; no sliding window / doc stride is applied.
max_seq_length = 384

# The three inputs expected by the BERT layer, each of length max_seq_length.
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
input_type_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
# trainable=True so the pretrained BERT weights are updated during fine-tuning.
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2", trainable=True)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, input_type_ids])
# Take the casing rule from the pretrained model itself so the tokenizer
# cannot drift out of sync with it.
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
tokenizer = BertWordPieceTokenizer('vocab.txt', lowercase=bool(do_lower_case))

============================================ SETTING TRAINING PARAMS ============================================

In [None]:
# Build the training and evaluation sets. The printed counts include samples
# that preprocessing marked as skipped.
train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")
eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

# Question-answering head: one score per token position for the answer start
# and one for the answer end, each turned into a distribution over positions.
start_logits = layers.Dense(1, name="start_logit", use_bias=False)(sequence_output)
start_logits = layers.Flatten()(start_logits)
end_logits = layers.Dense(1, name="end_logit", use_bias=False)(sequence_output)
end_logits = layers.Flatten()(end_logits)
start_probs = layers.Activation(keras.activations.softmax)(start_logits)
end_probs = layers.Activation(keras.activations.softmax)(end_logits)
model = keras.Model(inputs=[input_word_ids, input_mask, input_type_ids], outputs=[start_probs, end_probs])

# The outputs are already softmax probabilities, hence from_logits=False.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
# `learning_rate` is the supported argument name; the old `lr` alias is
# deprecated and rejected by newer Keras releases.
optimizer = keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
model.compile(optimizer=optimizer, loss=[loss, loss])
model.summary()
model.fit(x_train, y_train, epochs=2, batch_size=8, callbacks=[ValidationCallback(x_eval, y_eval)])

=========================================== SAVING MODEL AND ITS STATE =============================================

In [None]:
# Save the model and its architecture to a single file.
model.save("model.h5")
# To load it again later (hub.KerasLayer has to be passed as a custom object):
#   from tensorflow.keras.models import load_model
#   model = load_model('model.h5', custom_objects={'KerasLayer': hub.KerasLayer})

# Also save the weights on their own ...
model.save_weights("./weights.h5")
# ... and the architecture serialized to JSON.
model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)

# To rebuild the model later from the JSON architecture plus the weights file:
#   from tensorflow.keras.models import model_from_json
#
#   # load json and create model
#   json_file = open('model.json', 'r')
#   loaded_model_json = json_file.read()
#   json_file.close()
#   loaded_model = model_from_json(loaded_model_json, custom_objects={'KerasLayer': hub.KerasLayer})
#   # load weights into new model
#   loaded_model.load_weights("weights.h5")

==================================================== TESTING =====================================================

In [None]:
# Hand-written test input in the same nested layout that create_squad_examples()
# reads ("data" -> "paragraphs" -> "context" / "qas"). The questions carry no
# "answers" key, so they become inference-only samples.
data = {"data":
    [
        {"title": "Project Apollo",
         "paragraphs": [
             {
                 "context": "The Apollo program, also known as Project Apollo, was the third United States human "
                            "spaceflight program carried out by the National Aeronautics and Space Administration ("
                            "NASA), which accomplished landing the first humans on the Moon from 1969 to 1972. First "
                            "conceived during Dwight D. Eisenhower's administration as a three-man spacecraft to "
                            "follow the one-man Project Mercury which put the first Americans in space, Apollo was "
                            "later dedicated to President John F. Kennedy's national goal of landing a man on the "
                            "Moon and returning him safely to the Earth by the end of the 1960s, which he proposed in "
                            "a May 25, 1961, address to Congress. Project Mercury was followed by the two-man Project "
                            "Gemini. The first manned flight of Apollo was in 1968. Apollo ran from 1961 to 1972, "
                            "and was supported by the two man Gemini program which ran concurrently with it from 1962 "
                            "to 1966. Gemini missions developed some of the space travel techniques that were "
                            "necessary for the success of the Apollo missions. Apollo used Saturn family rockets as "
                            "launch vehicles. Apollo/Saturn vehicles were also used for an Apollo Applications "
                            "Program, which consisted of Skylab, a space station that supported three manned missions "
                            "in 1973-74, and the Apollo-Soyuz Test Project, a joint Earth orbit mission with the "
                            "Soviet Union in 1975.",
                 "qas": [
                     {"question": "What project put the first Americans into space?",
                      "id": "Q1"
                      },
                     {"question": "What program was created to carry out these projects and missions?",
                      "id": "Q2"
                      },
                     {"question": "What year did the first manned Apollo flight occur?",
                      "id": "Q3"
                      },
                     {"question": "Who did the U.S. collaborate with on an Earth orbit mission in 1975?",
                      "id": "Q4"
                      },
                     {"question": "How long did Project Apollo run?",
                      "id": "Q5"
                      },
                     {"question": "What program helped develop space travel techniques that Project Apollo used?",
                      "id": "Q6"
                      },
                     {"question": "What space station supported three manned missions in 1973-1974?",
                      "id": "Q7"
                      }
                 ]}]}]}

# Run the fine-tuned model on the hand-written questions and print each answer.
test_samples = create_squad_examples(data)
x_test, _ = create_inputs_targets(test_samples)
pred_start, pred_end = model.predict(x_test)
# create_inputs_targets() leaves out skipped samples, so the predictions line
# up with the non-skipped samples only.
predicted_samples = [sample for sample in test_samples if not sample.skip]
for test_sample, start_scores, end_scores in zip(predicted_samples, pred_start, pred_end):
    offsets = test_sample.context_token_to_char
    start = np.argmax(start_scores)
    end = np.argmax(end_scores)
    pred_ans = None
    # A start position beyond the context tokens (question or padding) means
    # no answer could be located; pred_ans stays None and the fallback is printed.
    if start < len(offsets):
        pred_char_start = offsets[start][0]
        if end < len(offsets):
            pred_ans = test_sample.context[pred_char_start:offsets[end][1]]
        else:
            pred_ans = test_sample.context[pred_char_start:]
    print("Q: " + test_sample.question)
    print("A: " + (pred_ans if pred_ans is not None else 'Cannot be determined from given context'))