In [4]:
from tensorflow import keras
from keras import layers
import tensorflow as tf
import numpy as np

import tensorflow_hub as hub

BASE_MODEL = "outputs/Mobile V1"
NUM_CLASSES = 5

Define the basic model:

In [5]:
model = tf.keras.Sequential(
    layers = [
        layers.Rescaling(scale=(1./127.5), offset=-1, name="Preprocessing"),  
        hub.KerasLayer(
            handle=f"{BASE_MODEL}/features",
            trainable=True,
            arguments=dict(batch_norm_momentum=0.997),
            name="Backbone"
        ),
        layers.Dense(NUM_CLASSES, activation=None, name="Output")
    ],
    name=f'{BASE_MODEL.replace(" ", "_")}'
)

model.build([None, 112, 112, 3])
model.summary()

Model: "outputs/Mobile_V1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Preprocessing (Rescaling)   (None, 112, 112, 3)       0         
                                                                 
 Backbone (KerasLayer)       (None, 256)               218544    
                                                                 
 Output (Dense)              (None, 5)                 1285      
                                                                 
Total params: 219829 (858.71 KB)
Trainable params: 214357 (837.33 KB)
Non-trainable params: 5472 (21.38 KB)
_________________________________________________________________


Check the final sizes:

In [6]:
# Full size model ~2MB | ~3.9MB
keras.saving.save_model(model, f'{BASE_MODEL}/base_model')

# TFLite conversion ~800KB | ~1.6MB
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open(f'{BASE_MODEL}/model-lite/model.tflite', 'wb') as f:
  f.write(tflite_model)

# TFLITE quant version ~292KB | ~574KB
converter.optimizations = [ tf.lite.Optimize.DEFAULT ]
tflite_quant_model = converter.convert()
with open(f'{BASE_MODEL}/model-quant/model.tflite', 'wb') as f:
  f.write(tflite_quant_model)





INFO:tensorflow:Assets written to: outputs/Mobile V1/base_model/assets


INFO:tensorflow:Assets written to: outputs/Mobile V1/base_model/assets


INFO:tensorflow:Assets written to: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmp0tn2mk89/assets


INFO:tensorflow:Assets written to: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmp0tn2mk89/assets
2024-01-13 20:19:55.901749: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2024-01-13 20:19:55.901776: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2024-01-13 20:19:55.906201: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmp0tn2mk89
2024-01-13 20:19:55.921274: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-01-13 20:19:55.921305: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmp0tn2mk89
2024-01-13 20:19:55.955503: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
2024-01-13 20:19:55.972329: I tensorflow/cc/saved_model/load

INFO:tensorflow:Assets written to: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpus9h5whl/assets


INFO:tensorflow:Assets written to: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpus9h5whl/assets
2024-01-13 20:20:03.412529: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2024-01-13 20:20:03.412554: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2024-01-13 20:20:03.412891: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpus9h5whl
2024-01-13 20:20:03.426576: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-01-13 20:20:03.426601: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpus9h5whl
2024-01-13 20:20:03.476319: I tensorflow/cc/saved_model/loader.cc:233] Restoring SavedModel bundle.
2024-01-13 20:20:03.901308: I tensorflow/cc/saved_model/loader.cc:217] Running initialization

In [7]:
# TFLITE quant version ~322KB | ~649KB
def representative_dataset():
    for _ in range(100):
        data = np.random.rand(1, 112, 112, 3)
        yield [ data.astype(np.float32) ]

converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS_INT8 ]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_quant_full_model = converter.convert()
with open(f'{BASE_MODEL}/model-quant-full/model.tflite', 'wb') as f:
  f.write(tflite_quant_full_model)

INFO:tensorflow:Assets written to: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpghq78xbw/assets


INFO:tensorflow:Assets written to: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpghq78xbw/assets
2024-01-13 20:20:19.286462: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2024-01-13 20:20:19.286487: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2024-01-13 20:20:19.286775: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpghq78xbw
2024-01-13 20:20:19.301780: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-01-13 20:20:19.301805: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/5j/vfb1vn5d7mxd7fmy30glls2c0000gn/T/tmpghq78xbw
2024-01-13 20:20:19.352430: I tensorflow/cc/saved_model/loader.cc:233] Restoring SavedModel bundle.
2024-01-13 20:20:19.842403: I tensorflow/cc/saved_model/loader.cc:217] Running initialization

## Testing functionalities:

In [13]:
interpreter = tf.lite.Interpreter(model_path="./Mobile V1/model-quant-full/model.tflite")
interpreter.allocate_tensors()

In [14]:
test_image = np.expand_dims(np.random.rand(112, 112, 3), axis=0).astype(np.int8)

input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

interpreter.set_tensor(input_index, test_image)
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)

In [15]:
interpreter.get_output_details()

[{'name': 'StatefulPartitionedCall:0',
  'index': 91,
  'shape': array([1, 6], dtype=int32),
  'shape_signature': array([-1,  6], dtype=int32),
  'dtype': numpy.int8,
  'quantization': (0.020725928246974945, -73),
  'quantization_parameters': {'scales': array([0.02072593], dtype=float32),
   'zero_points': array([-73], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]