In [35]:
import numpy as np
import tensorflow as tf
import time
import pickle
import architectures

In [36]:
#one time training of the model can skip if model is already trained
def create_model():
    # Load MNIST dataset
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Normalize the input data
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = architectures.simple()

    # Compile the model
    model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

    # Train the model
    model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

    # Save the trained model
    model.save("models/mnist_dnn.keras")
    print('saved model')
    return model

create_model()

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
saved model


<keras.src.engine.sequential.Sequential at 0x7effb0422790>

In [12]:
#load the pretrained model and split it into two
def split_model():
    # Load the saved model
    model = tf.keras.models.load_model("models/mnist_dnn.keras")
    
    # Split the model into two
    model1 = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
    ])
    # print(model1.predict(tf.random.normal([1, 28, 28])))
    
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(64,input_shape=(128,), activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    # print(model.summary())
    # print(model1.summary())
    # print(model2.summary())
    return model1, model2

#convert one model to tflite
def convert_to_tflite(model):
    def representative_dataset():
        for _ in range(100):
            yield [tf.random.normal([1, 28, 28]),]

    quantizer = tf.lite.TFLiteConverter.from_keras_model(model)
    quantizer.optimizations = [tf.lite.Optimize.DEFAULT]
    quantizer.representative_dataset = representative_dataset
    quantizer.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    quantizer.inference_input_type = tf.uint8  # or tf.uint8
    quantizer.inference_output_type = tf.uint8  # or tf.uint8
    
    tflite_quant_model = quantizer.convert()

    # Save the TFLite model to a file
    with open("mnist_model_batched.tflite", "wb") as f:
        f.write(tflite_quant_model)
    print("saved tf lite model")
    return tflite_quant_model


model_keras, model_tflite = split_model()
model_keras.save("models/mnist_dnn1a.keras")
model_tflite = convert_to_tflite(model_tflite)

INFO:tensorflow:Assets written to: /tmp/tmps2p2wjab/assets


INFO:tensorflow:Assets written to: /tmp/tmps2p2wjab/assets
2024-04-22 18:23:31.003579: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2024-04-22 18:23:31.003612: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2024-04-22 18:23:31.003917: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmps2p2wjab
2024-04-22 18:23:31.005032: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-04-22 18:23:31.005051: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /tmp/tmps2p2wjab
2024-04-22 18:23:31.007164: I tensorflow/cc/saved_model/loader.cc:233] Restoring SavedModel bundle.
2024-04-22 18:23:31.024989: I tensorflow/cc/saved_model/loader.cc:217] Running initialization op on SavedModel bundle at path: /tmp/tmps2p2wjab
2024-04-22 18:23:31.031465: I tensorflow/cc/saved_model/loader.cc:316] SavedModel

saved tf lite model


fully_quantize: 0, inference_type: 6, input_inference_type: UINT8, output_inference_type: UINT8


In [25]:
n_samples = 10000
def generate_input(n):
    return np.random.rand(n, 28, 28).astype(np.float32)
#load model 1

model_1a = tf.keras.models.load_model("models/mnist_dnn1a.keras")
start = time.perf_counter()
outputs = model_1a.predict(generate_input(n_samples), verbose=0)
print(outputs.shape)
end = time.perf_counter()
print("average inference time: ", (end - start)*1000/n_samples, " ms")


(10000, 128)
average inference time:  0.06477290317416191  ms


In [28]:

def generate_1b_input():
    return np.random.rand(1,128).astype(np.float32)

def tflite_inference(num_trials):
    
    #load model 1
    model_1a = tf.keras.models.load_model("models/mnist_dnn1a.keras")
    
    
    # Load the TFLite model
    interpreter = tf.lite.Interpreter(model_path="mnist_model_batched.tflite")
    interpreter.allocate_tensors()

    # Get input and output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    output_scale, output_zero_point = output_details[0]['quantization']
    input_scale, input_zero_point = input_details[0]["quantization"]
    
    # Prepare input data
    input_shape = input_details[0]['shape']
    inference_times = []
    inputs = [generate_1b_input() for _ in range(num_trials)]
    for input_data in inputs:
        start = time.perf_counter()
        input_data = (input_data / input_scale) + input_zero_point
        interpreter.set_tensor(input_details[0]['index'], input_data.astype(np.uint8))
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        output_data = output_scale * (output_data - output_zero_point)
        end = time.perf_counter()
        inference_times.append((end - start)*1000)

    avg_inference_time = np.mean(inference_times)
    print("average inference time: ", avg_inference_time, "ms")

tflite_inference(10)

average inference time:  0.029266998171806335 ms
