In [1]:
import os

import numpy as np
import tensorflow as tf


2023-09-26 21:24:25.122944: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Allocate only as much GPU memory as is actually needed.

In [None]:
# Ask TensorFlow to grow GPU memory on demand instead of reserving all of it
# up front. This replaces the TF1 idiom `tf.ConfigProto()` / `tf.Session()`,
# which is not available at the top level of the `tf` namespace in TF2 and
# would raise AttributeError here. Must run before any GPU has been
# initialized; it is a no-op when no GPU is present.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
try:
    # Hide every GPU from TensorFlow so the rest of the notebook runs on CPU.
    tf.config.set_visible_devices([], 'GPU')
except (ValueError, RuntimeError) as err:
    # ValueError: invalid device. RuntimeError: the runtime is already
    # initialized, so the visible devices can no longer be modified.
    # Best effort only: report the problem instead of silently ignoring it.
    print("Could not disable GPUs: %s" % err)
else:
    # Sanity check, only meaningful when the call above succeeded.
    visible_devices = tf.config.get_visible_devices()
    for device in visible_devices:
        assert device.device_type != 'GPU'

In [3]:
def create_model(units=20):
  """Build and compile a Keras LSTM classifier for MNIST digits.

  The network takes inputs of shape (28, 28), runs them through a single
  LSTM layer that returns the full sequence, flattens the result and ends
  in a 10-way softmax layer. The model summary is printed as a side effect.

  Args:
      units (int, optional): dimensionality of the LSTM output space.
        Defaults to 20.

  Returns:
      tf.keras.Model: the compiled, untrained Keras LSTM model
  """
  layers = tf.keras.layers
  layer_stack = [
      layers.Input(shape=(28, 28), name="input"),
      layers.LSTM(units, return_sequences=True),
      layers.Flatten(),
      layers.Dense(10, activation=tf.nn.softmax, name="output"),
  ]

  lstm_classifier = tf.keras.models.Sequential(layer_stack)
  lstm_classifier.compile(
      loss="sparse_categorical_crossentropy",
      optimizer="adam",
      metrics=["accuracy"],
  )
  lstm_classifier.summary()
  return lstm_classifier


In [4]:
def get_train_data():
  """Load the MNIST training split as float32 images scaled by 1/255.

  The test split returned by the Keras loader is discarded.

  Returns:
      tuple: (x_train, y_train), where x_train holds the training images
        divided by 255 and cast to float32, and y_train holds the labels
        exactly as loaded.
  """
  train_split, _ = tf.keras.datasets.mnist.load_data()
  images, labels = train_split
  scaled = images / 255.  # normalize pixel values to 0-1
  return (scaled.astype(np.float32), labels)



In [5]:
def train_lstm_model(epochs, x_train, y_train):
  """Create a fresh LSTM model and fit it on the given training data.

  20% of the data is held out for validation (validation_split=0.2) and
  training stops early once the validation loss has not improved for
  3 epochs.

  Args:
      epochs (int): maximum number of epochs to train the model
      x_train (numpy.array): the training data
      y_train (numpy.array): the labels corresponding to x_train

  Returns:
      tf.keras.Model: A trained keras LSTM model
  """
  lstm_model = create_model()

  # early stop if validation loss does not drop anymore
  early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss",
                                                    patience=3)
  fit_options = dict(
      epochs=epochs,
      validation_split=0.2,
      batch_size=32,
      callbacks=[early_stopping],
  )
  lstm_model.fit(x_train, y_train, **fit_options)
  return lstm_model



In [6]:
def convert_quantized_tflite_model(model, x_train):
  """Convert a Keras model to an int8-quantized TFLite flatbuffer.

  The converter is restricted to the int8 builtin op set with int8 input
  and output types. Samples from x_train are supplied to the converter as
  its representative dataset. This function only converts; it does not
  write anything to disk.

  See
  https://www.tensorflow.org/lite/performance/post_training_integer_quant#convert_using_integer-only_quantization

  Args:
      model (tf.keras.Model): the trained LSTM Model
      x_train (numpy.array): the training data; each sample must be
        reshapeable to (1, 28, 28)

  Returns:
      The converted model in serialized format.
  """

  def calibration_samples(num_samples=100):
    # Yield the first `num_samples` samples one by one, each as a
    # single-element list holding a batch of size one.
    for sample in x_train[:num_samples]:
      yield [sample.reshape(1, 28, 28)]

  lite = tf.lite
  quantizer = lite.TFLiteConverter.from_keras_model(model)
  quantizer.representative_dataset = calibration_samples
  quantizer.optimizations = [lite.Optimize.DEFAULT]
  quantizer.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
  quantizer.inference_input_type = tf.int8
  quantizer.inference_output_type = tf.int8
  return quantizer.convert()


In [7]:
def convert_tflite_model(model):
  """Convert a Keras model to a TFLite flatbuffer with default settings.

  This function only converts; it does not write anything to disk.

  Args:
      model (tf.keras.Model): the trained LSTM Model

  Returns:
      The converted model in serialized format.
  """
  return tf.lite.TFLiteConverter.from_keras_model(model).convert()



In [15]:
def save_tflite_model(tflite_model, save_dir, model_name):
  """Write a serialized TFLite model to <save_dir>/<model_name>.

  The directory (including missing parents) is created when needed. An
  empty save_dir means the current working directory. An existing file
  with the same name is overwritten.

  Args:
      tflite_model (bytes): the converted model in serialized format.
      save_dir (str): the save directory
      model_name (str): file name of the model to be saved
  """
  if save_dir:
    # exist_ok avoids the race between checking for and creating the
    # directory; the guard skips os.makedirs(""), which would raise.
    os.makedirs(save_dir, exist_ok=True)
  save_path = os.path.join(save_dir, model_name)
  with open(save_path, "wb") as f:
    f.write(tflite_model)
  # print() does not do logging-style lazy formatting, so interpolate
  # explicitly instead of passing the argument separately.
  print("Tflite model saved to %s" % (save_dir,))


In [8]:
def prepare_trained_model(trained_model):
  """Wrap the trained model behind an input with a fixed batch size of 1.

  A new input of shape [28, 28] with batch_size=1 and the dtype of the
  trained model's first input is created, the trained model is called on
  it, and both are combined into a new Keras model.

  Args:
      trained_model (tf.keras.Model): the trained LSTM model

  Returns:
      run_model (tf.keras.Model): the trained model with fixed input tensor size for inference
  """
  # TFLite converter requires fixed shape input to work, alternative: b/225231544
  input_dtype = trained_model.inputs[0].dtype
  single_sample_input = tf.keras.layers.Input(
      shape=[28, 28],
      batch_size=1,
      dtype=input_dtype,
      name="fixed_input",
  )
  return tf.keras.models.Model(single_sample_input,
                               trained_model(single_sample_input))



In [9]:
# Load the MNIST training images (scaled by 1/255, float32) and their labels.
x_train, y_train = get_train_data()

In [10]:
# Train for at most 20 epochs; early stopping inside train_lstm_model may end
# the run sooner.
trained_model = train_lstm_model(epochs=20, x_train=x_train, y_train=y_train)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm (LSTM)                 (None, 28, 20)            3920      
                                                                 
 flatten (Flatten)           (None, 560)               0         
                                                                 
 output (Dense)              (None, 10)                5610      
                                                                 
Total params: 9530 (37.23 KB)
Trainable params: 9530 (37.23 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Epoch 1/20


2023-09-26 21:24:33.767796: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fbf8c00c510 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2023-09-26 21:24:33.767840: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2023-09-26 21:24:34.165849: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.


  13/1500 [..............................] - ETA: 13s - loss: 2.2115 - accuracy: 0.2308   

2023-09-26 21:24:34.849979: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [12]:
# Wrap the trained model behind a fixed batch-size-1 input, as needed for the
# TFLite conversion below.
run_model = prepare_trained_model(trained_model)

In [13]:
# Save the fixed-input model in TensorFlow SavedModel format.
saved_model_dir = "models"
run_model.save(saved_model_dir, save_format="tf")
# print() does not do logging-style lazy formatting, so interpolate explicitly
# (the original printed a literal "%s").
print("TF model saved to %s" % saved_model_dir)

INFO:tensorflow:Assets written to: models/assets


INFO:tensorflow:Assets written to: models/assets


TF model saved to %s models


In [16]:
# Convert the fixed-input model to a TFLite flatbuffer and write it to
# models/mnist_lstm.tflite.
tflite_model = convert_tflite_model(run_model)
save_tflite_model(
    tflite_model,
    save_dir="models",
    model_name="mnist_lstm.tflite",
)

INFO:tensorflow:Assets written to: /tmp/tmpgflo2vbn/assets


INFO:tensorflow:Assets written to: /tmp/tmpgflo2vbn/assets


Tflite model saved to %s models


2023-09-26 21:36:21.716872: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-09-26 21:36:21.716907: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2023-09-26 21:36:21.717181: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmpgflo2vbn
2023-09-26 21:36:21.724170: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2023-09-26 21:36:21.724197: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: /tmp/tmpgflo2vbn
2023-09-26 21:36:21.749604: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2023-09-26 21:36:21.831716: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /tmp/tmpgflo2vbn
2023-09-26 21:36:21.887099: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 169919 

In [17]:
# Convert the fixed-input model to an int8-quantized TFLite flatbuffer, using
# training samples as the representative dataset, and write it to
# models/mnist_lstm_quant.tflite.
quantized_tflite_model = convert_quantized_tflite_model(run_model, x_train)
save_tflite_model(
    quantized_tflite_model,
    save_dir="models",
    model_name="mnist_lstm_quant.tflite",
)

INFO:tensorflow:Assets written to: /tmp/tmprfsopqmf/assets


INFO:tensorflow:Assets written to: /tmp/tmprfsopqmf/assets


Tflite model saved to %s models


2023-09-26 21:37:21.006647: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-09-26 21:37:21.006691: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2023-09-26 21:37:21.007216: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmprfsopqmf
2023-09-26 21:37:21.014373: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2023-09-26 21:37:21.014410: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: /tmp/tmprfsopqmf
2023-09-26 21:37:21.039720: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2023-09-26 21:37:21.125011: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /tmp/tmprfsopqmf
2023-09-26 21:37:21.189005: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 181802 