# Finetuning BERT with Keras and tf.Module

In this experiment we convert a pre-trained BERT model checkpoint into a trainable Keras layer, which we then use to solve a sentence-pair classification task.

We achieve this using a tf.Module (here, a TensorFlow Hub module loaded with ***hub.Module***), a neat abstraction designed to handle pre-trained TensorFlow models.
Exported modules can easily be integrated into other models, which facilitates experiments with powerful neural network architectures.

The plan for this experiment is:

1.   getting a pre-trained BERT model checkpoint
2.   defining the specification of the tf.Module
3.   exporting the module
4.   building the text preprocessing pipeline
5.   implementing a custom Keras layer
6.   training a Keras model to solve a sentence-pair classification task


# What is in this guide?
This guide is about integrating pre-trained TensorFlow models into Keras pipelines. It contains implementations of two things: a BERT tf.Module and a Keras layer built on top of it.

# What does it take?
For a reader familiar with TensorFlow, it should take around 30 minutes to finish this guide.



In [0]:
!test -d bert_repo || git clone https://github.com/google-research/bert bert_repo

import re
import os
import sys
import json

import logging
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub

from tensorflow import keras
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint

from sklearn.model_selection import train_test_split
from google.colab import auth, drive

if 'bert_repo' not in sys.path:
    sys.path.insert(0, 'bert_repo')

from modeling import BertModel, BertConfig
from tokenization import FullTokenizer, convert_to_unicode
from extract_features import InputExample, convert_examples_to_features


# Get the TensorFlow logger and drop its handlers (works around duplicated log lines in Colab)
log = logging.getLogger('tensorflow')
log.handlers = []

## Step 1: getting the pre-trained model

In [0]:
!wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
!unzip multi_cased_L-12_H-768_A-12.zip

## Step 2: building a tf.Module

**tf.Modules** are designed to provide a simple way to manipulate reusable parts of pre-trained machine learning models in TensorFlow. Google maintains a curated library of such modules on TensorFlow Hub. In this guide, however, we build one ourselves.

To that end, we need to implement a ***module_fn*** containing the full specification of the module's inner workings.
We begin by defining the input placeholders. Next, the BERT graph is created from the configuration file passed through ***config_path***. Then the model outputs are defined: the final encoder layer becomes seq_output, and the pooled *'**CLS**'* token representation becomes pool_output.
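For reference, ***get_pooled_output()*** in the BERT repository's modeling.py returns the final hidden state of the first (*'**CLS**'*) token passed through a dense layer with a tanh activation. Below is a minimal pure-Python sketch of that computation for a single example; the tiny hidden size and the weights are made up for illustration and are not taken from the checkpoint.

```python
import math


def pool_cls(sequence_output, kernel, bias):
    """Sketch of BERT's pooler for one example: tanh(h_cls @ kernel + bias).

    sequence_output: [seq_len][hidden] final-layer token vectors
    kernel: [hidden][hidden] dense weights (input dim x output dim)
    bias: [hidden] dense bias
    """
    h_cls = sequence_output[0]  # only the first ([CLS]) token is used
    return [
        math.tanh(sum(h_cls[i] * kernel[i][j] for i in range(len(h_cls))) + bias[j])
        for j in range(len(bias))
    ]


# Toy example: 3 tokens, hidden size 2 (BERT-Base uses 768), made-up weights.
seq = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
print(pool_cls(seq, identity, [0.0, 0.0]))
```

Changing any token other than the first leaves the pooled vector unchanged, which is why the sequence output is needed for mean pooling.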

Additionally, extra assets may be bundled with the module. In this example, we add the BERT config file and a ***vocab_file*** containing the WordPiece vocabulary to the module assets. As a result, both files are exported with the module, which makes it self-contained.
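The WordPiece vocabulary is a plain-text file with one token per line, and a token's id is its zero-based line number; ***load_vocab*** in the BERT repository's tokenization.py builds this mapping. A stdlib-only sketch of the same idea, using a hypothetical in-memory vocabulary in place of vocab.txt:

```python
import io
from collections import OrderedDict


def load_vocab(fileobj):
    """Map each token (one per line) to its zero-based line index."""
    vocab = OrderedDict()
    for index, line in enumerate(fileobj):
        vocab[line.strip()] = index
    return vocab


# Hypothetical toy vocabulary; the real file is far larger.
toy_vocab_file = io.StringIO("[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\n##s\n")
vocab = load_vocab(toy_vocab_file)
print(vocab["[CLS]"], vocab["##s"])
```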

In [0]:
def build_module_fn(config_path, vocab_path, do_lower_case=True):

    def bert_module_fn(is_training):
        """Spec function for a BERT token embedding module."""

        # Inputs: [batch_size, seq_len] int32 matrices produced by the preprocessor
        input_ids = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_ids")
        input_mask = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_mask")
        token_type = tf.placeholder(shape=[None, None], dtype=tf.int32, name="segment_ids")

        # Build the BERT graph; is_training toggles dropout
        config = BertConfig.from_json_file(config_path)
        model = BertModel(config=config, is_training=is_training,
                          input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type)

        # Final encoder layer [batch, seq_len, hidden] and pooled CLS vector [batch, hidden]
        seq_output = model.all_encoder_layers[-1]
        pool_output = model.get_pooled_output()

        # Register the config and vocabulary files as assets so they are exported with the module
        config_file = tf.constant(value=config_path, dtype=tf.string, name="config_file")
        vocab_file = tf.constant(value=vocab_path, dtype=tf.string, name="vocab_file")
        lower_case = tf.constant(do_lower_case)

        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, config_file)
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file)

        input_map = {"input_ids": input_ids,
                     "input_mask": input_mask,
                     "segment_ids": token_type}

        output_map = {"pooled_output": pool_output,
                      "sequence_output": seq_output}

        output_info_map = {"vocab_file": vocab_file,
                           "do_lower_case": lower_case}

        # "tokens": token matrices -> BERT outputs; "tokenization_info": vocab path and casing flag
        hub.add_signature(name="tokens", inputs=input_map, outputs=output_map)
        hub.add_signature(name="tokenization_info", inputs={}, outputs=output_info_map)

    return bert_module_fn

Finally, we define signatures, which are particular transformations of inputs to outputs exposed to module consumers. One can think of them as the module's interface to the outside world.

Here we add two signatures to the module. The first, ***tokens***, takes the token matrices (input_ids, input_mask and segment_ids) as input and returns the pooled and sequence representations as output. The second, ***tokenization_info***, takes no inputs and returns the path to the vocabulary file and the lowercase flag.

## Step 3: exporting the module

Now that the module_fn is defined, we can use it to build and export the module. Passing the tags_and_args argument to create_module_spec adds two graph variants to the module: one for training, with the tags ***{"train"}***, and one for inference, with an empty set of tags. This lets us control dropout, which is enabled during training and disabled at inference time.
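Written out explicitly, the tags_and_args list that the export cell below builds in a loop has just two entries, one per graph variant. Each entry pairs a set of tags with the keyword arguments that ***hub.create_module_spec*** passes to bert_module_fn when building that variant:

```python
# (tags, kwargs for bert_module_fn) for each graph variant in the module
tags_and_args = [
    ({"train"}, {"is_training": True}),   # training graph: dropout enabled
    (set(), {"is_training": False}),      # default (inference) graph: dropout disabled
]
print(tags_and_args)
```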

In [0]:
MODEL_DIR = "multi_cased_L-12_H-768_A-12" #@param {type:"string"} ['uncased_L-12_H-768_A-12','multi_cased_L-12_H-768_A-12']

config_path = "/content/{}/bert_config.json".format(MODEL_DIR)
vocab_path = "/content/{}/vocab.txt".format(MODEL_DIR)

tags_and_args = []
for is_training in (True, False):
  tags = set()
  if is_training:
    tags.add("train")
  tags_and_args.append((tags, dict(is_training=is_training)))

module_fn = build_module_fn(config_path, vocab_path)
spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args)
spec.export("bert-module", 
            checkpoint_path="/content/{}/bert_model.ckpt".format(MODEL_DIR))

## Step 4: building the text preprocessing pipeline

The BERT model requires text to be represented as three matrices: ***input_ids***, ***input_mask***, and ***segment_ids***. In this step we build a pipeline that takes a list of strings and outputs these three matrices.
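To make the three matrices concrete, here is a stdlib-only sketch of how one sentence pair is laid out, following the layout used by ***convert_examples_to_features***: the tokens become [CLS] A [SEP] B [SEP], the mask is 1 over real tokens and 0 over padding, and segment ids are 0 for the first sentence (including [CLS] and the first [SEP]) and 1 for the second. The toy vocabulary and whitespace tokenization are assumptions for illustration; the real pipeline uses WordPiece via FullTokenizer.

```python
def truncate_pair(tokens_a, tokens_b, max_len):
    """Pop tokens from the longer list until the pair fits in max_len (in place)."""
    while len(tokens_a) + len(tokens_b) > max_len:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def pack_pair(tokens_a, tokens_b, vocab, seq_len):
    """Return (input_ids, input_mask, segment_ids), each of length seq_len."""
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)
    truncate_pair(tokens_a, tokens_b, seq_len - 3)  # leave room for [CLS] and two [SEP]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    input_ids = [vocab[t] for t in tokens]
    input_mask = [1] * len(input_ids)
    padding = [0] * (seq_len - len(input_ids))  # pad ids, mask and segments with 0
    return input_ids + padding, input_mask + padding, segment_ids + padding


# Hypothetical toy vocabulary and whitespace "tokenization"
vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "how": 3, "are": 4, "you": 5, "fine": 6}
ids, mask, segments = pack_pair("how are you".split(), ["fine"], vocab, seq_len=8)
print(ids)       # [1, 3, 4, 5, 2, 6, 2, 0]
print(mask)      # [1, 1, 1, 1, 1, 1, 1, 0]
print(segments)  # [0, 0, 0, 0, 0, 1, 1, 0]
```

Stacking such rows for a batch gives the three [batch_size, seq_len] matrices.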

First, raw input text is converted into ***InputExamples***. If the input text is a sentence pair separated by the special ' ||| ' sequence (with a space on each side), it is split into its two sentences.

In [0]:
def read_examples(str_list):
    """Read a list of `InputExample`s from a list of strings."""
    unique_id = 0
    for s in str_list:
        line = convert_to_unicode(s)
        if not line:
            continue
        line = line.strip()
        text_a = None
        text_b = None
        m = re.match(r"^(.*) \|\|\| (.*)$", line)
        if m is None:
            text_a = line
        else:
            text_a = m.group(1)
            text_b = m.group(2)
        yield InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)
        unique_id += 1

***InputExamples*** are then tokenized and converted to ***InputFeatures*** using the ***convert_examples_to_features*** function from the original repository. Keras, however, expects NumPy arrays, so we convert these features into np.arrays.

In [0]:
def features_to_arrays(features):

    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []

    for feature in features:
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_segment_ids.append(feature.input_type_ids)

    return (np.array(all_input_ids, dtype='int32'), 
            np.array(all_input_mask, dtype='int32'), 
            np.array(all_segment_ids, dtype='int32'))

Finally, let us put it all together in a single pipeline.

In [0]:
def build_preprocessor(voc_path, seq_len, lower=True):
    tokenizer = FullTokenizer(vocab_file=voc_path, do_lower_case=lower)

    def strings_to_arrays(sents):
        # Accept a scalar, a list, or an (n, 1) array of strings
        sents = np.atleast_1d(sents).reshape((-1,))

        examples = list(read_examples(sents))
        features = convert_examples_to_features(examples, seq_len, tokenizer)
        return features_to_arrays(features)

    return strings_to_arrays

All done!

## Step 5: implementing a BERT Keras layer
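The layer's build() method chooses which BERT variables to fine-tune by substring-matching variable names: the embeddings (optional), the pooler (for *'cls'* pooling) and the last n_tune_layers encoder blocks, whose count is inferred from the 16 variables per block. A stdlib-only sketch of that selection rule, run on hypothetical variable names shaped like BERT-Base's (select_trainable is an illustrative helper, not part of the layer):

```python
def select_trainable(var_names, n_tune_layers, tune_embeddings, pooling,
                     var_per_encoder=16):
    """Sketch of the name-matching rule in BertLayer.build."""
    patterns = []
    if tune_embeddings:
        patterns.append("embeddings")
    if pooling == "cls":
        patterns.append("pooler")
    if n_tune_layers > 0:
        # Infer the number of encoder blocks from the variable count
        n_encoder_layers = len([n for n in var_names if "encoder" in n]) // var_per_encoder
        for i in range(n_tune_layers):
            # Trailing "/" keeps "layer_1/" from also matching "layer_10/" or "layer_11/"
            patterns.append(f"encoder/layer_{n_encoder_layers - 1 - i}/")
    return [n for n in var_names if any(p in n for p in patterns)]


# Hypothetical names: 12 encoder blocks x 16 variables each, plus embeddings and pooler
names = ["bert/embeddings/word_embeddings:0", "bert/pooler/dense/kernel:0"]
names += [f"bert/encoder/layer_{k}/var_{j}:0" for k in range(12) for j in range(16)]

chosen = select_trainable(names, n_tune_layers=3, tune_embeddings=False, pooling="cls")
print(len(chosen))  # 49: the pooler kernel plus 3 x 16 encoder variables
```

Everything not matched is registered as a non-trainable weight, so the optimizer only updates the top of the network.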

In [0]:
class BertLayer(tf.keras.layers.Layer):
    def __init__(self, bert_path, seq_len=64, n_tune_layers=3,
                 pooling="cls", do_preprocessing=True, verbose=False,
                 tune_embeddings=False, trainable=True, **kwargs):

        # Forward `trainable` to the base class; otherwise Layer.__init__
        # resets it to its default (True) and the argument is silently ignored.
        super(BertLayer, self).__init__(trainable=trainable, **kwargs)

        self.n_tune_layers = n_tune_layers
        self.tune_embeddings = tune_embeddings
        self.do_preprocessing = do_preprocessing

        self.verbose = verbose
        self.seq_len = seq_len
        self.pooling = pooling
        self.bert_path = bert_path

        # Each BERT encoder block holds 16 variables (Q/K/V, attention output,
        # intermediate and output dense layers, plus two LayerNorms)
        self.var_per_encoder = 16
        if self.pooling not in ["cls", "mean", None]:
            raise ValueError(
                f"Undefined pooling type (must be either 'cls', 'mean', or None), but is {self.pooling}"
            )

    def build(self, input_shape):

        self.bert = hub.Module(self.build_abspath(self.bert_path), 
                               trainable=self.trainable, name=f"{self.name}_module")

        trainable_layers = []
        if self.tune_embeddings:
            trainable_layers.append("embeddings")

        if self.pooling == "cls":
            trainable_layers.append("pooler")

        if self.n_tune_layers > 0:
            encoder_var_names = [var.name for var in self.bert.variables if 'encoder' in var.name]
            n_encoder_layers = int(len(encoder_var_names) / self.var_per_encoder)
            for i in range(self.n_tune_layers):
                trainable_layers.append(f"encoder/layer_{str(n_encoder_layers - 1 - i)}/")
        
        # Add module variables to layer's trainable weights
        for var in self.bert.variables:
            if any([l in var.name for l in trainable_layers]):
                self._trainable_weights.append(var)
            else:
                self._non_trainable_weights.append(var)

        if self.verbose:
            print("*** TRAINABLE VARS *** ")
            for var in self._trainable_weights:
                print(var)

        self.build_preprocessor()
        self.initialize_module()

        super(BertLayer, self).build(input_shape)

    def build_abspath(self, path):
        if path.startswith("https://") or path.startswith("gs://"):
          return path
        else:
          return os.path.abspath(path)

    def build_preprocessor(self):
        sess = tf.keras.backend.get_session()
        tokenization_info = self.bert(signature="tokenization_info", as_dict=True)
        vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                              tokenization_info["do_lower_case"]])
        self.preprocessor = build_preprocessor(vocab_file, self.seq_len, do_lower_case)

    def initialize_module(self):
        sess = tf.keras.backend.get_session()
        
        vars_initialized = sess.run([tf.is_variable_initialized(var) 
                                     for var in self.bert.variables])

        uninitialized = []
        for var, is_initialized in zip(self.bert.variables, vars_initialized):
            if not is_initialized:
                uninitialized.append(var)

        if len(uninitialized):
            sess.run(tf.variables_initializer(uninitialized))

    def call(self, inputs):

        if self.do_preprocessing:
            # Run the Python tokenizer inside the graph: strings -> 3 int32 matrices
            inputs = tf.numpy_function(self.preprocessor,
                                       [inputs], [tf.int32, tf.int32, tf.int32],
                                       name='preprocessor')
            for feature in inputs:
                feature.set_shape((None, self.seq_len))

        input_ids, input_mask, segment_ids = inputs

        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        output = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)

        if self.pooling == "cls":
            pooled = output["pooled_output"]
        else:
            result = output["sequence_output"]

            input_mask = tf.cast(input_mask, tf.float32)
            mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
            masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                    tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)

            if self.pooling == "mean":
                pooled = masked_reduce_mean(result, input_mask)
            else:
                # pooling=None: return the full sequence with padded positions zeroed
                pooled = mul_mask(result, input_mask)

        return pooled

    def get_config(self):
        config_dict = {
            "bert_path": self.bert_path, 
            "seq_len": self.seq_len,
            "pooling": self.pooling,
            "n_tune_layers": self.n_tune_layers,
            "tune_embeddings": self.tune_embeddings,
            "do_preprocessing": self.do_preprocessing,
            "verbose": self.verbose
        }
        super(BertLayer, self).get_config()
        return config_dict
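With pooling='mean', call() averages the token vectors over real tokens only: each vector is multiplied by its mask value, the products are summed over the sequence axis, and the sum is divided by the number of unmasked tokens plus 1e-10. The same arithmetic for a single example, sketched in plain Python with made-up values:

```python
def masked_mean(sequence_output, mask, eps=1e-10):
    """Average token vectors where mask == 1; padded positions contribute nothing."""
    hidden = len(sequence_output[0])
    total = [0.0] * hidden
    for vector, m in zip(sequence_output, mask):
        for j in range(hidden):
            total[j] += vector[j] * m
    count = sum(mask) + eps
    return [t / count for t in total]


seq = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row is a padding position
mask = [1, 1, 0]
print(masked_mean(seq, mask))  # approximately [2.0, 3.0]
```

The epsilon only matters for an all-padding row, where it turns 0/0 into 0 instead of NaN.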

## Step 6: sentence pair classification

Now let us try the layer on a real-world dataset. For this part we will use the [Quora Question Pairs](https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs) dataset, which consists of over 400,000 question pairs, each labeled according to whether the two questions are semantically equivalent (duplicates).

We join each question pair with the " ||| " delimiter and split the data into training and test sets.
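Because read_examples splits on the regular expression ^(.*) \|\|\| (.*)$, the delimiter must be exactly " ||| ", with a space on each side. A quick stdlib check that a joined pair round-trips through that pattern; the sample questions are made up:

```python
import re

PAIR_RE = re.compile(r"^(.*) \|\|\| (.*)$")  # same pattern as read_examples
delimiter = " ||| "

joined = delimiter.join(("How do I learn Python?", "What is the best way to learn Python?"))
m = PAIR_RE.match(joined)
text_a, text_b = m.group(1), m.group(2)
print(text_a)
print(text_b)
```

Without the surrounding spaces (e.g. "a|||b") the pattern does not match and the whole string is treated as a single sentence. Since .* is greedy, a text containing the delimiter twice is split at its last occurrence.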

In [0]:
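# Assumption: the Quora release file (tab-separated, with columns id, qid1, qid2,
# question1, question2, is_duplicate) has been downloaded to the working directory.
df = pd.read_csv("quora_duplicate_questions.tsv", sep="\t")
labels = df.is_duplicate.values

texts = []
delimiter = " ||| "
for q1, q2 in zip(df.question1.tolist(), df.question2.tolist()):
    texts.append(delimiter.join((str(q1), str(q2))))

texts = np.array(texts)

trX, tsX, trY, tsY = train_test_split(texts, labels, shuffle=True, test_size=0.2)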

In [0]:
len(trX),len(tsX)

Building and training a sentence-pair classification model is straightforward:

In [0]:
inp = tf.keras.Input(shape=(1,), dtype=tf.string)
encoder = BertLayer(bert_path="./bert-module/", seq_len=48, tune_embeddings=False,
                    pooling='cls', n_tune_layers=3, verbose=False)

pred = tf.keras.layers.Dense(1, activation='sigmoid')(encoder(inp))

model = tf.keras.models.Model(inputs=[inp], outputs=[pred])



In [0]:
model.summary()

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="binary_crossentropy",
    metrics=["accuracy"])
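The model above ends in a single sigmoid unit trained with binary cross-entropy, and predictions are later turned into labels with a 0.5 threshold. The per-example arithmetic, sketched in plain Python with made-up probabilities; the epsilon clip is an assumption modelled on the kind of clipping Keras applies to keep the logarithms finite:

```python
import math


def binary_crossentropy(y_true, p, eps=1e-7):
    """-(y*log(p) + (1-y)*log(1-p)), with p clipped away from 0 and 1."""
    p = min(max(p, eps), 1 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))


def to_label(p, threshold=0.5):
    """Same rule as np.where(preds > 0.5, 1, 0): strictly greater than the threshold."""
    return 1 if p > threshold else 0


print(binary_crossentropy(1, 0.9))  # small loss: confident and correct
print(binary_crossentropy(0, 0.9))  # large loss: confident and wrong
print(to_label(0.9), to_label(0.5))
```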

### Training

In [0]:
import logging
logging.getLogger("tensorflow").setLevel(logging.WARNING)

In [0]:
saver = keras.callbacks.ModelCheckpoint("bert_tuned.hdf5")

model.fit(trX, trY, validation_data=[tsX, tsY], batch_size=128, epochs=5, callbacks=[saver])

In [0]:
model.fit(trX, trY, validation_data=[tsX, tsY], batch_size=128, epochs=5, callbacks=[saver])

## Step 7: saving and restoring

In [0]:
testDS=pd.read_csv("test_dataset.csv")
testDS.columns=["id","s1","s2","label"]
testDS=testDS.sample(frac=1)
labels = testDS.label.values

texts = []
delimiter = " ||| "
for q1, q2 in zip(testDS.s1.tolist(), testDS.s2.tolist()):
  texts.append(delimiter.join((str(q1), str(q2))))

texts = np.array(texts)


In [0]:
tsX=texts
tsY=labels

In [0]:
preds=model.predict(tsX[:])

In [0]:
predss=np.where(preds > 0.5, 1, 0)

In [0]:
evalDF = pd.DataFrame({"input": tsX})
# Split each "s1 ||| s2" string back into its first and last parts
evalDF["product title"] = evalDF["input"].str.split("|").str[0]
evalDF["potential Key Phrase"] = evalDF["input"].str.split("|").str[-1]
evalDF["gold"] = tsY
evalDF["preds"] = predss.ravel()  # (n, 1) predictions -> 1-D column
for index, row in evalDF.iterrows():
    print(row['potential Key Phrase'], "***************", row['preds'])
# evalDF[["product title", "potential Key Phrase", "gold", "preds"]].head(100)

In [0]:
import json
with open("model.json", "w") as f:
    json.dump(model.to_json(), f)

In [0]:
with open("model.json") as f:
    model = tf.keras.models.model_from_json(json.load(f),
                                            custom_objects={"BertLayer": BertLayer})

model.load_weights("bert_tuned.hdf5")

In [0]:
model.predict(trX[:10])

In some cases (e.g. when serving), one might want to optimize the trained model for maximum inference throughput. In TensorFlow this can be achieved by "freezing" the model.

During freezing, the model variables are replaced by constants, and the nodes needed only for training are pruned from the computational graph. The resulting graph is more lightweight, needs less RAM, and typically runs inference faster.

In [0]:
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

def freeze_keras_model(model, export_path=None, clear_devices=True):
    """
    Freezes a Keras model into a pruned computation graph.

    @param model The Keras model to be frozen.
    @param export_path Optional path to write the serialized GraphDef to.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    
    sess = tf.keras.backend.get_session()
    graph = sess.graph
    
    with graph.as_default():

        input_tensors = model.inputs
        output_tensors = model.outputs
        dtypes = [t.dtype.as_datatype_enum for t in input_tensors]
        input_ops = [t.name.rsplit(":", maxsplit=1)[0] for t in input_tensors]
        output_ops = [t.name.rsplit(":", maxsplit=1)[0] for t in output_tensors]
        
        tmp_g = graph.as_graph_def()
        if clear_devices:
            for node in tmp_g.node:
                node.device = ""
        
        tmp_g = optimize_for_inference(
            tmp_g, input_ops, output_ops, dtypes, False)
        
        tmp_g = convert_variables_to_constants(sess, tmp_g, output_ops)
        
        if export_path is not None:
            with tf.gfile.GFile(export_path, "wb") as f:
                f.write(tmp_g.SerializeToString())
        
        return tmp_g

We freeze our trained model and write the serialized graph to file.

In [0]:
frozen_graph = freeze_keras_model(model, export_path="frozen_graph.pb")



Now let's restore the frozen graph and do some inference.

In [0]:
!git clone https://github.com/gaphex/bert_experimental/

import tensorflow as tf
import numpy as np
import sys

sys.path.insert(0, "/content/bert_experimental")

from bert_experimental.finetuning.text_preprocessing import build_preprocessor
from bert_experimental.finetuning.graph_ops import load_graph

In [0]:
restored_graph = load_graph("frozen_graph.pb")





To run inference, we need handles to the graph's input and output tensors. This part is a little tricky: we retrieve the list of all operations in the restored graph and then pick out the names of the relevant ops. The operations are listed in the order they were added to the graph, so in this case it is enough to take the first operation (the input) and the last one (the output sigmoid).

To get a tensor name, we append **":0"** (the op's first output) to the op name.
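A tensor name is its producing op's name followed by ":" and the output index, so ":0" selects the first output; freeze_keras_model above goes the other way with rsplit. A tiny stdlib illustration using the op names printed by the next cell (to_tensor_name and to_op_name are illustrative helpers):

```python
def to_tensor_name(op_name, output_index=0):
    """'scope/op' -> 'scope/op:0' for the op's first output."""
    return f"{op_name}:{output_index}"


def to_op_name(tensor_name):
    """'scope/op:0' -> 'scope/op' (same rsplit as in freeze_keras_model)."""
    return tensor_name.rsplit(":", maxsplit=1)[0]


print(to_tensor_name("import/input_1_1"))
print(to_op_name("import/dense_1/Sigmoid:0"))
```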

In [0]:
graph_ops = restored_graph.get_operations()
input_op, output_op = graph_ops[0].name, graph_ops[-1].name
print(input_op, output_op)

import/input_1_1 import/dense_1/Sigmoid


In [0]:
x = restored_graph.get_tensor_by_name(input_op + ':0')
y = restored_graph.get_tensor_by_name(output_op + ':0')

The preprocessing function we injected into the Keras layer is not serializable, so it was not restored with the new graph. No worries, though: we can simply define it again with the same name.

In [0]:
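# Must match the vocabulary, seq_len (48) and default lowercasing used by the trained BertLayer
preprocessor = build_preprocessor("./multi_cased_L-12_H-768_A-12/vocab.txt", 48)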
py_func = tf.numpy_function(preprocessor, [x], [tf.int32, tf.int32, tf.int32], name='preprocessor')

In [0]:
py_func = tf.numpy_function(preprocessor, [x], [tf.int32, tf.int32, tf.int32])

Finally, we can get the predictions.

In [0]:
sess = tf.Session(graph=restored_graph)

In [0]:
trX[:10]

array(["What is the world's view on the design of India's new ₹500 and ₹2000 notes? ||| What do you think about the new design of 500 and 2000 rupee notes?",
       "Where does the Venus Express go since it's dead now? ||| Where is Venus right now (June 2016)?",
       'How can I learn communication skills? ||| How we improve our communication skills?',
       'How is energy stored in gasoline? ||| How is energy stored?',
       'How many goals did Messi score in his career? ||| Will Lionel Messi cancel his retirement?',
       'What were the major effects of the cambodia earthquake, and how do these effects compare to the Banda Sea earthquake in 1938? ||| What were the major effects of the cambodia earthquake, and how do these effects compare to the Concepcion earthquake in 1751?',
       'How would you know if a shy guy likes you? ||| How do I know if a guy I like is shy?',
       'What is it like to work for AT&T? ||| What was it like to work at AT&T?',
       "I have a Master's in 

In [0]:
y_out = sess.run(y, feed_dict={
        x: trX[:10].reshape((-1,1))
    })

y_out

array([[9.66316462e-01],
       [1.98765397e-02],
       [9.18453693e-01],
       [2.75392830e-02],
       [1.78414583e-03],
       [9.21200871e-01],
       [4.85856265e-01],
       [7.82228947e-01],
       [8.56636558e-04],
       [1.22578345e-01]], dtype=float32)