In [None]:
import os
import pathlib
import datetime

import numpy as np
import matplotlib.pyplot as plt

# WARNING: Training on GPU is currently non-deterministic!
# Uncomment to train on CPU.
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

import warnings

warnings.filterwarnings(action="ignore")
import tensorflow as tf
from tensorflow import keras

In [None]:
# random seed setter to make training reproducible
import random

SEED = 123


def set_all_seeds(seed=SEED):
    """Seed every random number generator this notebook draws from.

    Args:
        seed: value handed to TensorFlow's global seed, NumPy's legacy global
            RNG and Python's ``random`` module, in that order. Defaults to
            the notebook-level ``SEED``.
    """
    for apply_seed in (tf.random.set_seed, np.random.seed, random.seed):
        apply_seed(seed)

In [None]:
# enable training in notebook
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Build a TF1-compat session configuration and install it as the default
# interactive session for this kernel.
config = ConfigProto()
# allow_growth: request on-demand GPU memory allocation instead of reserving it
# up front. NOTE(review): this is a TF1-style option -- confirm it still takes
# effect for the TF2 eager-mode training done below.
config.gpu_options.allow_growth = True
# The session object is kept in a notebook-level name so it stays alive.
session = InteractiveSession(config=config)

In [None]:
# Report the runtime environment: GPU visibility first, then eager mode.
for label, probe in (
    ("GPU Available: ", tf.test.is_gpu_available),
    ("Eager execution enabled: ", tf.executing_eagerly),
):
    print(label, probe())

# Load and rescale data

In [None]:
# Fetch MNIST as (images, labels) pairs for the train and test splits.
train_split, test_split = tf.keras.datasets.mnist.load_data()

# Append a trailing singleton axis to every array: the images gain the channel
# axis expected by the (28, 28, 1) model input, the labels become column vectors.
train_images, train_labels = (arr[..., np.newaxis] for arr in train_split)
test_images, test_labels = (arr[..., np.newaxis] for arr in test_split)

# Rescale pixels by 1/255 and shift by -0.5 so the inputs are centred on zero.
# NOTE(review): dividing a NumPy array by a tf.constant presumably yields a
# float32 tf.Tensor rather than an ndarray -- confirm downstream code is fine
# with that.
scale = tf.constant(255, dtype=tf.dtypes.float32)
x_train = train_images / scale - 0.5
x_test = test_images / scale - 0.5
y_train = train_labels
y_test = test_labels

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

set_all_seeds()

# Every ImageDataGenerator option is spelled out so the augmentation policy is
# visible in one place. Only small shifts and horizontal flips are enabled.
augmentation_options = dict(
    # --- normalisation (all disabled) ---
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    zca_epsilon=1e-06,  # epsilon for ZCA whitening
    # --- geometric transforms ---
    rotation_range=0,  # random rotation range in degrees (0 to 180)
    width_shift_range=0.1,  # random horizontal shift (fraction of total width)
    height_shift_range=0.1,  # random vertical shift (fraction of total height)
    shear_range=0.0,  # range for random shear
    zoom_range=0.0,  # range for random zoom
    channel_shift_range=0.0,  # range for random channel shifts
    fill_mode="nearest",  # how points outside the input boundaries are filled
    cval=0.0,  # value used for fill_mode = "constant"
    # NOTE(review): mirrored digits are generally not valid digits -- confirm
    # that horizontal flipping is really wanted for MNIST.
    horizontal_flip=True,  # randomly flip images horizontally
    vertical_flip=False,  # randomly flip images vertically
    # --- misc ---
    rescale=None,  # rescaling factor (applied before any other transformation)
    preprocessing_function=None,  # function that will be applied on each input
    data_format=None,  # image data format, "channels_first" or "channels_last"
    validation_split=0.0,  # fraction of images reserved for validation
)
datagen = ImageDataGenerator(**augmentation_options)
datagen.fit(x_train)

# Define and compile

In [None]:
from tensorflow.keras.layers import (
    Conv2D,
    Dense,
    MaxPool2D,
    Flatten,
    Input,
    BatchNormalization,
    ReLU,
)

# Small CNN: one 5x5 convolution block (Conv -> BatchNorm -> ReLU -> 2x2 max-pool)
# followed by a flattened 10-way softmax classifier.
#
# Seed *before* constructing the model: the initial weights are drawn when the
# layers are built (here presumably at construction time, since input_shape is
# given), so seeding only before compile() would not pin them when this cell is
# re-run on its own.
set_all_seeds()
model = keras.Sequential(
    [
        Conv2D(filters=32, kernel_size=5, padding="same", input_shape=(28, 28, 1)),
        BatchNormalization(),
        ReLU(),
        MaxPool2D(pool_size=2, strides=2),
        Flatten(),
        Dense(10, activation="softmax"),
    ]
)

# Labels are integer class ids (not one-hot), hence the sparse loss.
training_params = {
    "optimizer": "adam",
    "loss": "sparse_categorical_crossentropy",
    "metrics": ["accuracy"],
}

model.compile(**training_params)

model.summary()

In [None]:
batch_size = 32
epochs = 35
USE_TENSORBOARD = False

# Run the training on augmented batches produced by `datagen`.
# `Model.fit` accepts generators / Sequences directly; `fit_generator` is
# deprecated and no longer exists in recent Keras releases.
# NOTE: the test split is used as validation data here, so the reported
# validation metrics are not an independent hold-out estimate.
if USE_TENSORBOARD:
    log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, histogram_freq=0
    )
    fit_kwargs = {
        "validation_data": (x_test, y_test),
        "callbacks": [tensorboard_callback],
    }

    set_all_seeds()
    # Training is split into a first epoch and the remaining ones, as before.
    model.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=1,
        **fit_kwargs,
    )
    # Continue at epoch index 1 so the TensorBoard logs of the two phases form
    # one contiguous curve instead of both starting at epoch 0. This still
    # trains `epochs - 1` additional epochs.
    model.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        initial_epoch=1,
        epochs=epochs,
        **fit_kwargs,
    )
else:
    set_all_seeds()
    model.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
    )

# Save keras model to disk

In [None]:
# Directory that collects every model artefact written by this notebook.
models_dir = pathlib.Path("./models/")
models_dir.mkdir(parents=True, exist_ok=True)

# Persist the trained Keras model in HDF5 format.
model.save(models_dir.joinpath("model.h5"))

# Convert to TFLite and save to disk

In [None]:
# Reload the model from the file written above, so the conversion steps below
# start from the on-disk artefact rather than the in-memory training result.
model = keras.models.load_model(models_dir.joinpath("model.h5"))

### Float TFLite model

In [None]:
# Convert the Keras model to a TFLite flatbuffer without setting any
# optimization options on the converter.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# The converted model is kept in memory; it is written to disk in the next cell.
model_float_lite = converter.convert()

In [None]:
# Write the float model to disk; write_bytes returns the number of bytes written.
model_float_file = models_dir.joinpath("model_float.tflite")
size_float = model_float_file.write_bytes(model_float_lite)
size_float_kb = size_float / 1024
print("Float model size: {:.0f} KB".format(size_float_kb))

### Quantized TFLite model

In [None]:
# Fresh converter for the quantized variant, using the default optimization set.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]  # this doesn't seem to do anything

# Training images served one at a time (batch size 1); the converter uses them
# as a representative dataset to estimate activation distributions.
x_train_ds = tf.data.Dataset.from_tensor_slices(x_train).batch(1)


def representative_data_gen():
    """Yield the first 100 single-image batches, each wrapped in a list."""
    yield from ([sample] for sample in x_train_ds.take(100))


converter.representative_dataset = representative_data_gen

model_quant_lite = converter.convert()

In [None]:
# Write the quantized model to disk; write_bytes returns the number of bytes written.
model_quant_file = models_dir.joinpath("model_quant.tflite")
size_quant = model_quant_file.write_bytes(model_quant_lite)
size_quant_kb = size_quant / 1024
print("Quantized model size: {:.0f} KB".format(size_quant_kb))

# Build interpreters and run inference on test set

In [None]:
# Build one interpreter per in-memory TFLite model and allocate its tensors,
# which has to happen before inputs can be set.
interpreter_float, interpreter_quant = (
    tf.lite.Interpreter(model_content=tflite_model)
    for tflite_model in (model_float_lite, model_quant_lite)
)
interpreter_float.allocate_tensors()
interpreter_quant.allocate_tensors()

In [None]:
import sys


def _predict_class(interpreter, img):
    """Run one image through a TFLite interpreter and return the argmax class index."""
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # set_tensor takes a NumPy array; add the leading batch axis the model expects.
    batch = np.expand_dims(np.asarray(img), 0)
    interpreter.set_tensor(input_index, batch)
    interpreter.invoke()
    probability = interpreter.get_tensor(output_index)
    return np.argmax(probability)


def eval_float(j, img):
    """Classify one image with the float TFLite model.

    Args:
        j: index of the sample; unused, kept so the signature matches
            ``eval_quant``.
        img: a single image without batch axis.

    Returns:
        Index of the highest-scoring output.
    """
    return _predict_class(interpreter_float, img)


def eval_quant(j, img, total=10000):
    """Classify one image with the quantized TFLite model, printing progress.

    Args:
        j: zero-based index of the sample; drives the progress display.
        img: a single image without batch axis.
        total: number of samples shown in the progress line. The default
            keeps the original output unchanged.

    Returns:
        Index of the highest-scoring output.
    """
    # Refresh the single-line progress indicator every 10th sample.
    if (j + 1) % 10 == 0:
        print("quant: {:6d}/{:d}".format(j + 1, total), end="\r")
        sys.stdout.flush()
    return _predict_class(interpreter_quant, img)


# One slot per test sample, pre-filled with NaN so any sample that was not
# evaluated is visible. `np.nan` is used because the `np.NaN` alias was removed
# in NumPy 2.0.
predictions_float = np.full((y_test.shape[0],), np.nan)
predictions_quant = np.full((y_test.shape[0],), np.nan)

# Evaluate the whole test set with the float model first ...
for j, img in enumerate(x_test):
    predictions_float[j] = eval_float(j, img)

# ... then with the quantized model (prints a progress indicator).
for j, img in enumerate(x_test):
    predictions_quant[j] = eval_quant(j, img)

In [None]:
def _tflite_accuracy(predictions):
    """Return the fraction of test samples whose predicted class equals its label.

    Each call is computed independently. Reusing one stateful
    ``tf.metrics.Accuracy`` object would accumulate across calls, so the second
    figure would cover both models' predictions rather than one model's.
    """
    labels = np.asarray(test_labels).reshape(-1)
    return np.mean(np.asarray(predictions).reshape(-1) == labels)


print("Accuracy of models:")
print(
    "# Float TFLite model:     {:.2%}".format(_tflite_accuracy(predictions_float))
)
print(
    "# Quantized TFLite model: {:.2%}".format(_tflite_accuracy(predictions_quant))
)

# Convert tflite model graph

In [None]:
from tflite_utils import load_tflite_as_json, save_json_as_tflite
from tflite2xcore_utils import (
    clean_unused_opcodes,
    clean_unused_tensors,
    clean_unused_buffers,
)
from tflite2xcore_graph_conv import remove_float_inputs_outputs

# Output path for the post-processed quantized model (plain string, relative to
# the working directory).
model_quant_stripped_file = "models/model_quant_stripped.tflite"

# Load the quantized .tflite file written earlier into a JSON-like structure.
json_model = load_tflite_as_json(model_quant_file)
# The helpers below are called for their effect on `json_model`; their return
# values are not used. NOTE(review): judging by the names, the float
# input/output handling is removed first and the now-unreferenced opcodes,
# tensors and buffers are pruned afterwards -- confirm against the helper
# implementations, and keep this order unless they are known to be independent.
remove_float_inputs_outputs(json_model)
clean_unused_opcodes(json_model)
clean_unused_tensors(json_model)
clean_unused_buffers(json_model)
# Serialize the modified structure back to a .tflite file.
save_json_as_tflite(json_model, model_quant_stripped_file)