In [72]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
def get_gzipped_model_size(file):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

In [84]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=6, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)

Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape_11 (Reshape)        (None, 28, 28, 1)         0         
                                                                 
 conv2d_11 (Conv2D)          (None, 26, 26, 6)         60        
                                                                 
 max_pooling2d_11 (MaxPoolin  (None, 13, 13, 6)        0         
 g2D)                                                            
                                                                 
 flatten_11 (Flatten)        (None, 1014)              0         
                                                                 
 dense_11 (Dense)            (None, 10)                10150     
                                                                 
Total params: 10,210
Trainable params: 10,210
Non-trainable params: 0
_________________________________________________

<keras.callbacks.History at 0x1e6e9069f10>

In [85]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

open("modelFortest.tflite", "wb").write(tflite_model)
import os
model_size = os.path.getsize("./modelFortest.tflite")
print("Basic model is %d bytes" % model_size)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size('modelFortest.tflite')))



INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpw6tzmef0\assets


INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpw6tzmef0\assets


Basic model is 44140 bytes
Size of gzipped baseline Keras model: 39649.00 bytes


In [86]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
ps = tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0,
                                                               final_sparsity=0.5,
                                                               begin_step=100,
                                                               end_step=end_step)


model_for_pruning = prune_low_magnitude(model, pruning_schedule=ps)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model_for_pruning.summary()

Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshape  (None, 28, 28, 1)        1         
 _11 (PruneLowMagnitude)                                         
                                                                 
 prune_low_magnitude_conv2d_  (None, 26, 26, 6)        116       
 11 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_max_poo  (None, 13, 13, 6)        1         
 ling2d_11 (PruneLowMagnitud                                     
 e)                                                              
                                                                 
 prune_low_magnitude_flatten  (None, 1014)             1         
 _11 (PruneLowMagnitude)                                         
                                                     

In [87]:
import tempfile

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x1e6e93aa6a0>

In [88]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
tflite_model = converter.convert()

open("pruningModelFortest.tflite", "wb").write(tflite_model)
import os
model_size = os.path.getsize("./pruningModelFortest.tflite")
print("Basic model is %d bytes" % model_size)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size('pruningModelFortest.tflite')))



INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpczpyh49b\assets


INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpczpyh49b\assets


Basic model is 44140 bytes
Size of gzipped baseline Keras model: 24568.00 bytes


In [89]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
open("ModelFortest_q.tflite", "wb").write(tflite_model)
import os
model_size = os.path.getsize("./ModelFortest.tflite")
print("Basic model is %d bytes" % model_size)
model_size = os.path.getsize("./ModelFortest_q.tflite")
print("Basic model is %d bytes" % model_size)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size('ModelFortest_q.tflite')))



INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpc0b9poio\assets


INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpc0b9poio\assets


Basic model is 44140 bytes
Basic model is 13808 bytes
Size of gzipped baseline Keras model: 7718.00 bytes


In [90]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()
open("pruningModelFortest_q.tflite", "wb").write(quantized_and_pruned_tflite_model)
import os
model_size = os.path.getsize("./pruningModelFortest.tflite")
print("Basic model is %d bytes" % model_size)
model_size = os.path.getsize("./pruningModelFortest_q.tflite")
print("Basic model is %d bytes" % model_size)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size('pruningModelFortest_q.tflite')))



INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpq8vlr6st\assets


INFO:tensorflow:Assets written to: C:\Users\zs\AppData\Local\Temp\tmpq8vlr6st\assets


Basic model is 44140 bytes
Basic model is 13808 bytes
Size of gzipped baseline Keras model: 7732.00 bytes
