# MobileBERT QAT Tutorial

This notebook provides basic example code to build, run, and fine-tune [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf) with the quantization-aware training (QAT) toolkit.

The pretrained models are downloaded from [TensorFlow Hub](https://tfhub.dev/google/qat/nlp/mobilebert_xs_qat) and the [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official/projects/qat/nlp). Both are trained on the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) dataset for the question-answering (Q&A) task. You will run inference on the models with dummy inputs.

## Setup

In [None]:
# Install packages

# tf-models-official is the stable Model Garden package
# tf-models-nightly includes latest changes
!pip install -q tf-models-nightly

In [None]:
# Run imports
import os

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

## Launch QAT Training

Follow the [training guideline](https://github.com/tensorflow/models/tree/master/official/projects/qat/nlp#training) to start QAT training using the pretrained checkpoint.

## Running model from TFHub

This section runs the QAT-trained MobileBERT model from TF Hub. Note that the model contains fake-quant ops and that all ops run in float32. They become actual int8 ops only when you convert the model to TFLite with the TFLite converter.

In [None]:
loaded_obj = hub.load("https://tfhub.dev/google/qat/nlp/mobilebert_xs_qat/1")
serving_model = loaded_obj.signatures['serving_default']

# Dummy inputs
input_type_ids = tf.zeros(shape=[1, 384], dtype=tf.int32)
input_word_ids = tf.zeros(shape=[1, 384], dtype=tf.int32)
input_mask = tf.zeros(shape=[1, 384], dtype=tf.int32)

bert_inputs = dict(
    input_type_ids=input_type_ids, input_word_ids=input_word_ids, input_mask=input_mask)

bert_outputs = serving_model(**bert_inputs)

start_logits = bert_outputs["start_logits"]
end_logits = bert_outputs["end_logits"]

print(start_logits.shape)
print(end_logits.shape)

(1, 384)
(1, 384)
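
The fake-quant ops mentioned above can be illustrated with a small NumPy sketch. It is not the toolkit's implementation: it only assumes a per-tensor affine int8 scheme, and `fake_quantize`, `scale`, and `zero_point` are illustrative names. Each value is rounded onto the int8 grid and immediately mapped back to float32, so the tensor stays float32 but carries the rounding and clipping error that int8 inference would introduce.

```python
import numpy as np


def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
  """Simulates int8 quantization while keeping the result in float32."""
  x = np.asarray(x, dtype=np.float32)
  # Quantize: map real values onto the integer grid and clip to the int8 range.
  q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
  # Dequantize: map the grid values straight back to float32.
  return ((q - zero_point) * scale).astype(np.float32)


# Made-up values: 0.013 is rounded away and 100.0 is clipped to qmax * scale.
values = np.array([-1.0, 0.013, 0.5, 100.0], dtype=np.float32)
simulated = fake_quantize(values, scale=0.05, zero_point=0)
print(simulated.dtype)
print(simulated)
```

Small values collapse to the nearest grid point and large values saturate at the edge of the int8 range, which is the error the model learns to tolerate during QAT.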


## Running TFLite Model Inference

This section runs inference with the trained, quantized TFLite model on dummy data. We assume that the input string has already been converted to integer token IDs using the vocabulary.
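
As a rough illustration of that assumption, the sketch below builds the three integer inputs from a question and a context. The tiny vocabulary, the whitespace tokenizer, and the `encode_qa` helper are hypothetical stand-ins; the real model expects IDs from its own vocabulary, which this notebook does not load.

```python
import numpy as np

# Hypothetical toy vocabulary; a real model ships its own vocab file.
TOY_VOCAB = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3,
             'who': 4, 'wrote': 5, 'it': 6, 'ada': 7}


def encode_qa(question, context, vocab, max_seq_length=384):
  """Encodes a question/context pair as [CLS] question [SEP] context [SEP]."""
  q_tokens = ['[CLS]'] + question.lower().split() + ['[SEP]']
  c_tokens = context.lower().split() + ['[SEP]']
  # Naive truncation: anything past max_seq_length is simply dropped.
  tokens = (q_tokens + c_tokens)[:max_seq_length]
  num_pad = max_seq_length - len(tokens)

  word_ids = [vocab.get(t, vocab['[UNK]']) for t in tokens]
  # Segment 0 marks the question, segment 1 marks the context.
  type_ids = ([0] * len(q_tokens) + [1] * len(c_tokens))[:max_seq_length]
  # The mask is 1 for real tokens and 0 for padding.
  mask = [1] * len(tokens)

  def to_batch(values):
    # Pad with zeros and add a leading batch dimension of size 1.
    return np.array([values + [0] * num_pad], dtype=np.int32)

  return to_batch(word_ids), to_batch(mask), to_batch(type_ids)


input_word_ids, input_mask, input_type_ids = encode_qa(
    'Who wrote it', 'Ada wrote it', TOY_VOCAB)
print(input_word_ids.shape, input_word_ids[0, :9])
```

The resulting arrays have the same `[1, 384]` shape and `int32` dtype as the dummy inputs used in this notebook.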

In [None]:
# First download the TFLite model.
! curl https://storage.googleapis.com/tf_model_garden/nlp/qat/mobilebert/model_qat.tflite --output model_qat.tflite

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 33.8M  100 33.8M    0     0   102M      0 --:--:-- --:--:-- --:--:--  101M


In [None]:
def get_dequantized_tensor(interpreter, output_detail):
  """Returns an output tensor as float32, dequantizing it if needed."""
  if ('quantization' not in output_detail or
      np.dtype(output_detail['dtype']) == np.dtype(np.float32)):
    return interpreter.get_tensor(output_detail['index'])
  # Affine dequantization: real = (quantized - zero_point) * scale.
  output_scale, output_zero_point = output_detail['quantization']
  quantized = np.array(
      interpreter.get_tensor(output_detail['index']), dtype=np.float32)
  return (quantized - output_zero_point) * output_scale

def run_tflite(interpreter, input_word_ids, input_mask, input_type_ids):
  # Assumes get_input_details() lists the inputs in this order:
  # word IDs, mask, type IDs.
  input_word_ids_index, input_mask_index, input_type_ids_index = [
      detail['index'] for detail in interpreter.get_input_details()]
  interpreter.set_tensor(input_word_ids_index, input_word_ids)
  interpreter.set_tensor(input_mask_index, input_mask)
  interpreter.set_tensor(input_type_ids_index, input_type_ids)
  interpreter.invoke()

  # Assumes the outputs are listed as start logits, then end logits.
  start_logits_detail, end_logits_detail = interpreter.get_output_details()

  return (get_dequantized_tensor(interpreter, start_logits_detail),
          get_dequantized_tensor(interpreter, end_logits_detail))
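
`run_tflite` relies on the positional order of `get_input_details()`. A more defensive variant, sketched below, looks each input up by name instead. It assumes every tensor name contains the corresponding key (`input_word_ids`, `input_mask`, `input_type_ids`) as a substring, which should be verified against the actual model. The `find_input_indices` helper and the sample details are hypothetical.

```python
def find_input_indices(input_details, keys):
  """Maps each key to the index of the one input whose name contains it."""
  indices = {}
  for key in keys:
    matches = [d['index'] for d in input_details if key in d['name']]
    if len(matches) != 1:
      raise ValueError(f'Expected one input matching {key!r}, got {matches}')
    indices[key] = matches[0]
  return indices


# Stand-in for interpreter.get_input_details(); these names are made up.
sample_details = [
    {'name': 'serving_default_input_mask:0', 'index': 1},
    {'name': 'serving_default_input_type_ids:0', 'index': 2},
    {'name': 'serving_default_input_word_ids:0', 'index': 0},
]
indices = find_input_indices(
    sample_details, ['input_word_ids', 'input_mask', 'input_type_ids'])
print(indices)
```

With a real interpreter, the returned indices could be passed to `interpreter.set_tensor` in place of the positionally unpacked ones.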

In [None]:
tflite_file = 'model_qat.tflite'
with open(tflite_file, 'rb') as fp:
  tflite_model = fp.read()

interpreter = tf.lite.Interpreter(
    model_content=tflite_model,
    experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

In [None]:
# Dummy inputs
input_type_ids = np.zeros(shape=[1, 384], dtype=np.int32)
input_word_ids = np.zeros(shape=[1, 384], dtype=np.int32)
input_mask = np.zeros(shape=[1, 384], dtype=np.int32)

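# Pass the inputs by keyword so they match the run_tflite signature.
start_logits, end_logits = run_tflite(
    interpreter,
    input_word_ids=input_word_ids,
    input_mask=input_mask,
    input_type_ids=input_type_ids)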

print(start_logits.shape)
print(end_logits.shape)

(1, 384)
(1, 384)
