# Segmentation-style classification model demo for the MNIST dataset on the i.MX 8M Plus

## Library imports and drive mount

In [64]:
! pip install -q tensorflow
! pip install -q tensorflow-model-optimization


In [65]:
import tempfile
import os

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras


In [75]:
print("tensorflow version ",tf.__version__)

# Mount Google Drive so that the exported models and sample images persist
# outside the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Output folder; later cells build file names with `write_path + name`, so the
# trailing slash must be kept.
write_path = "/content/drive/MyDrive/mnist_demo/" # change or set as per drive for saving models and sample images
# Create the folder up front: the cells that save the pickle and the .tflite
# file open their targets with open(..., 'wb'), which fails if it is missing.
os.makedirs(write_path, exist_ok=True)

tensorflow version  2.8.2
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Base model training

**model definition**

In [71]:
# Input geometry of an MNIST digit and the number of digit classes.
HEIGHT = 28
WIDTH = 28
N_CLASSES = 10


def seg_model_function(interpolation):
  """Build the (uncompiled) encoder-decoder style MNIST classifier.

  The encoder reduces the HEIGHT x WIDTH x 1 input with two conv + max-pool
  stages and one strided conv, followed by an 8-filter bottleneck conv. The
  decoder applies three UpSampling2D + Conv2D stages and ends in a
  single-channel sigmoid map, which is flattened and fed to a softmax Dense
  layer with N_CLASSES units.

  Args:
    interpolation: value forwarded to every `UpSampling2D` layer's
      `interpolation` argument (the notebook uses "nearest" or "bilinear").

  Returns:
    A `keras.Sequential` model whose first layer is named "input" and whose
    last layer is named "output".
  """
  cnn_filters = 32
  kernel = (2, 2)
  # Keyword arguments shared by every ReLU convolution that keeps spatial size.
  relu_same = dict(activation='relu', kernel_initializer='he_uniform', padding='same')

  layers = [
    # --- encoder ---
    keras.layers.Conv2D(cnn_filters, kernel, input_shape=(HEIGHT, WIDTH,1), name = "input", **relu_same),
    keras.layers.MaxPooling2D((2, 2), padding='same'),
    keras.layers.Conv2D(cnn_filters, kernel, **relu_same),
    keras.layers.MaxPooling2D((2, 2), padding='same'),
    keras.layers.Conv2D(cnn_filters, kernel, strides=(2,2), **relu_same),
    keras.layers.Conv2D(8, kernel, **relu_same),
    # --- decoder ---
    keras.layers.UpSampling2D((2, 2), interpolation=interpolation),
    # No padding / initializer arguments here: Keras defaults apply. With the
    # default 'valid' padding the 2x2 kernel shrinks each spatial dim by one.
    keras.layers.Conv2D(cnn_filters, kernel, activation='relu'),
    keras.layers.UpSampling2D((2, 2), interpolation=interpolation),
    keras.layers.Conv2D(cnn_filters, kernel, **relu_same),
    keras.layers.UpSampling2D((2, 2), interpolation=interpolation),
    keras.layers.Conv2D(1, kernel, activation='sigmoid', padding='same'),
    # --- classification head ---
    keras.layers.Flatten(),
    keras.layers.Dense(units = N_CLASSES, activation ='softmax', name = "output"),
  ]

  return keras.Sequential(layers)


**model training**

In [None]:
# Seed the Python, NumPy and TensorFlow RNGs so that weight initialisation and
# batch shuffling are repeatable across "Restart & Run All" (GPU kernels may
# still introduce small run-to-run differences).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
interpolation_used = "nearest" # check and compare "bilinear" vs "nearest" upsampling2d implementions on IMX8MP
base_model = seg_model_function(interpolation_used)
base_model.summary()

# Compile and train the model.
# The output layer already applies softmax, hence from_logits=False; labels are
# integer class ids, hence the sparse variant of the loss.
base_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
base_model.fit(
  train_images,
  train_labels,
  epochs = 5
)

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (Conv2D)              (None, 28, 28, 32)        160       
                                                                 
 max_pooling2d_16 (MaxPoolin  (None, 14, 14, 32)       0         
 g2D)                                                            
                                                                 
 conv2d_51 (Conv2D)          (None, 14, 14, 32)        4128      
                                                                 
 max_pooling2d_17 (MaxPoolin  (None, 7, 7, 32)         0         
 g2D)                                                            
                                                                 
 conv2d_52 (Conv2D)          (None, 4, 4, 32)          4128      
                                                                 
 conv2d_53 (Conv2D)          (None, 4, 4, 8)          

# Sample inference to verify the trained model

In [None]:
import pickle

import matplotlib.pyplot as plt
import numpy as np

# Run the trained classifier on the first few test images
num_samples = 5
samples = test_images[:num_samples]
true_labels = test_labels[:num_samples]
predictions = base_model.predict(samples)

# Plot each sample together with its true and predicted digit
for i in range(num_samples):
  sample = samples[i]
  # predictions[i] holds one score per class; the arg-max is the predicted digit
  prediction = np.argmax(predictions[i])
  true_label = true_labels[i]
  fig, axes = plt.subplots(1, 1)
  axes.imshow(sample)
  fig.suptitle(f'true = {true_label}, predicted = {prediction}')
  plt.show()
  plt.close(fig)  # release the figure so repeated runs do not accumulate them

# save samples and labels as pkl file for running inference on imx8mp
# NOTE: two objects are dumped back-to-back into the same file, so the reader
# must call pickle.load() twice, in this order: samples first, then labels.
labels = true_labels
pkl_file = write_path + "mnist_samples_labels.pkl"
with open(pkl_file, 'wb') as fptr:
  pickle.dump(samples,fptr)
  pickle.dump(labels,fptr)


# Post-training quantization to INT8 format

In [None]:
# Defined before the generator that uses it so the cell reads top to bottom.
BATCH_SIZE = 1
# Number of training images fed to the converter for calibration.
NUM_CALIBRATION_SAMPLES = 1000


def ref_data_gen_v2():
    """Yield representative inputs for post-training quantization.

    Each item is a one-element list holding a float32 array of shape
    (BATCH_SIZE, HEIGHT, WIDTH, 1) built from one of the first
    NUM_CALIBRATION_SAMPLES images of `train_images`.
    """
    for image in train_images[:NUM_CALIBRATION_SAMPLES]:
        sample = image.reshape(BATCH_SIZE, HEIGHT, WIDTH, 1).astype(np.float32)
        yield [sample]


input_name = base_model.input_names[0]
index = base_model.input_names.index(input_name)
base_model.inputs[index].set_shape([BATCH_SIZE, HEIGHT, WIDTH, 1]) # to avoid dynamic tensors in tflite model, use 1 as batch size

converter = tf.lite.TFLiteConverter.from_keras_model(base_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = ref_data_gen_v2
# Restrict the converter to INT8 builtin kernels.
# NOTE(review): inference_input_type / inference_output_type are not set, so the
# model's input/output tensors keep the converter's default type - confirm this
# matches what the inference script on the board expects.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.experimental_new_converter = True
quantized_model = converter.convert()

# save model in .tflite format for running inference on imx8mp
model_name =  "mnist_" + interpolation_used + "_demo_ptq.tflite"
# Context manager guarantees the file is flushed and closed before it is
# copied off the drive.
with open(write_path + model_name, "wb") as model_file:
    model_file.write(quantized_model)

# Visualize the quantized model

In [None]:
# Run the TFLite model analyzer on the in-memory quantized model (the bytes
# returned by converter.convert()) to verify that the operations are in INT8 format
tf.lite.experimental.Analyzer.analyze(model_content=quantized_model)