## Using the URLs Below as References
https://www.tensorflow.org/lite/tutorials/model_maker_question_answer 

https://arxiv.org/abs/1810.04805

# Import Libraries

In [1]:
!pip install -q tflite-model-maker

import numpy as np
import os

import tensorflow as tf

from tensorflow import keras
from tflite_model_maker import model_spec
from tflite_model_maker import question_answer
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.question_answer import DataLoader
from tflite_model_maker.config import QuantizationConfig

[K     |████████████████████████████████| 593kB 8.4MB/s 
[K     |████████████████████████████████| 686kB 19.2MB/s 
[K     |████████████████████████████████| 122kB 24.3MB/s 
[K     |████████████████████████████████| 174kB 25.5MB/s 
[K     |████████████████████████████████| 645kB 45.3MB/s 
[K     |████████████████████████████████| 849kB 50.8MB/s 
[K     |████████████████████████████████| 6.3MB 52.8MB/s 
[K     |████████████████████████████████| 112kB 55.5MB/s 
[K     |████████████████████████████████| 1.1MB 49.5MB/s 
[K     |████████████████████████████████| 71kB 10.7MB/s 
[K     |████████████████████████████████| 92kB 13.2MB/s 
[K     |████████████████████████████████| 1.2MB 45.6MB/s 
[K     |████████████████████████████████| 102kB 13.7MB/s 
[K     |████████████████████████████████| 38.2MB 76kB/s 
[K     |████████████████████████████████| 358kB 43.8MB/s 
[K     |████████████████████████████████| 194kB 52.7MB/s 
[?25h  Building wheel for fire (setup.py) ... [?25l[?25hd

# Load Model and Data

In [2]:
# API reference:
# https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/question_answer/BertQaSpec
# Fine-tuning hyperparameters follow the recommendations of the BERT authors
# (https://arxiv.org/abs/1810.04805). Per the original notes:
#   - learning_rate was changed from 4e-05.
#     NOTE(review): the original note said "into 2e-5", but the value passed
#     below is 3e-5 - confirm which one is intended.
#   - predict_batch_size was changed from 8 to 16.
spec = question_answer.MobileBertQaSpec(
    uri='https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
    model_dir=None,
    seq_len=512,
    query_len=64,
    doc_stride=128,
    dropout_rate=0.1,
    initializer_range=0.02,
    learning_rate=3e-5,
    distribution_strategy='off',
    num_gpus=-1,
    tpu='',
    trainable=True,
    predict_batch_size=16,
    do_lower_case=True,
    is_tf2=False,
    tflite_input_name=None,
    tflite_output_name=None,
    init_from_squad_model=False,
    default_batch_size=32,
    name='MobileBert',
)

In [3]:
# Download the same TriviaQA files that the TF Lite BERT question-answer
# tutorial uses; get_file returns the local path of each downloaded file.
train_data_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json',
    fname='triviaqa-web-train-8000.json',
)
validation_data_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json',
    fname='triviaqa-verified-web-dev.json',
)

# Alternative QuAC dataset, kept for reference only. Reverted to the default
# tutorial dataset above because training on the nested QuAC data took too long.
#train_data_path = tf.keras.utils.get_file(
#    fname='QuAC-train.json',
#    origin='https://s3.amazonaws.com/my89public/quac/train_v0.2.json')
#validation_data_path = tf.keras.utils.get_file(
#    fname='QuAC-val.json',
#    origin='https://s3.amazonaws.com/my89public/quac/val_v0.2.json')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json
Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json


In [4]:
# Build the training and validation datasets from the downloaded JSON files
# with the SQuAD loader, using `spec` for preprocessing. Only the training
# set is loaded with is_training=True.
train_data = DataLoader.from_squad(
    train_data_path,
    spec,
    is_training=True,
)
validation_data = DataLoader.from_squad(
    validation_data_path,
    spec,
    is_training=False,
)

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: '<' not supported between instances of 'str' and 'Literal'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: '<' not supported between instances of 'str' and 'Literal'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: '<' not supported between instances of 'str' and 'Literal'


# Transfer Learning

In [5]:
# Fine-tune the model described by `spec` on the training data for 4 epochs.
# shuffle=True is passed so the training examples are randomly distributed.
model = question_answer.create(
    train_data,
    model_spec=spec,
    epochs=4,
    shuffle=True,
)

INFO:tensorflow:Retraining the models...


INFO:tensorflow:Retraining the models...


Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


In [6]:
# Show the layer-by-layer summary (output shapes and parameter counts) of the fine-tuned model.
model.summary()

Model: "bert_span_labeler"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_word_ids (InputLayer)     [(None, 512)]        0                                            
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, 512)]        0                                            
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, 512)]        0                                            
__________________________________________________________________________________________________
core_model (Functional)         [(None, 512, 512), ( 24581888    input_word_ids[0][0]             
                                                                 input_mask[0][0] 

# Model Evaluation

In [7]:
# Evaluate the fine-tuned model on the validation set; the displayed result
# is a dict with 'exact_match' and 'final_f1' scores.
model.evaluate(validation_data)

INFO:tensorflow:Made predictions for 400 records.


INFO:tensorflow:Made predictions for 400 records.


INFO:tensorflow:Made predictions for 800 records.


INFO:tensorflow:Made predictions for 800 records.


{'exact_match': 0.5476190476190477, 'final_f1': 0.6288730571593717}

# Export Model 

Reduce the model size and inference latency by applying float16 quantization through `QuantizationConfig` when exporting.

In [8]:
# Quantization settings for the export: float16.
config = QuantizationConfig.for_float16()

print('Saving model...')
# Export in TFLite format into the current working directory.
model.export(
    export_dir='.',
    export_format=ExportFormat.TFLITE,
    quantization_config=config,
)

# Evaluate the exported 'model.tflite' file on the validation set so its
# scores can be compared with those of the non-exported model.
model.evaluate_tflite('model.tflite', validation_data)

Saving model...
INFO:tensorflow:Assets written to: /tmp/tmpp1s0eibp/saved_model/assets


INFO:tensorflow:Assets written to: /tmp/tmpp1s0eibp/saved_model/assets


INFO:tensorflow:Vocab file is inside the TFLite model with metadata.


INFO:tensorflow:Vocab file is inside the TFLite model with metadata.


INFO:tensorflow:Saved vocabulary in /tmp/tmpftay63c2/vocab.txt.


INFO:tensorflow:Saved vocabulary in /tmp/tmpftay63c2/vocab.txt.


INFO:tensorflow:Finished populating metadata and associated file to the model:


INFO:tensorflow:Finished populating metadata and associated file to the model:


INFO:tensorflow:./model.tflite


INFO:tensorflow:./model.tflite


INFO:tensorflow:The associated file that has been been packed to the model is:


INFO:tensorflow:The associated file that has been been packed to the model is:


INFO:tensorflow:['vocab.txt']


INFO:tensorflow:['vocab.txt']


INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite


INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite


INFO:tensorflow:Made predictions for 100 records.


INFO:tensorflow:Made predictions for 100 records.


INFO:tensorflow:Made predictions for 200 records.


INFO:tensorflow:Made predictions for 200 records.


INFO:tensorflow:Made predictions for 300 records.


INFO:tensorflow:Made predictions for 300 records.


INFO:tensorflow:Made predictions for 400 records.


INFO:tensorflow:Made predictions for 400 records.


INFO:tensorflow:Made predictions for 500 records.


INFO:tensorflow:Made predictions for 500 records.


INFO:tensorflow:Made predictions for 600 records.


INFO:tensorflow:Made predictions for 600 records.


INFO:tensorflow:Made predictions for 700 records.


INFO:tensorflow:Made predictions for 700 records.


INFO:tensorflow:Made predictions for 800 records.


INFO:tensorflow:Made predictions for 800 records.


INFO:tensorflow:Made predictions for 900 records.


INFO:tensorflow:Made predictions for 900 records.


{'exact_match': 0.5476190476190477, 'final_f1': 0.6288730571593717}