# AI Model Optimization Techniques

    AI model optimization techniques refer to methods used to improve model performance, efficiency, and robustness during training and inference. These techniques are essential for dealing with issues such as overfitting, underfitting, slow convergence, and inefficient use of computational resources. 


## What are overfitting, underfitting, and slow convergence in Artificial Intelligence?

    In Artificial Intelligence (AI) and Machine Learning (ML), the performance of a model is influenced by how well it generalizes to new, unseen data. Key concepts related to model performance include overfitting, underfitting, and slow convergence:

## 1. Overfitting

    Overfitting occurs when a model learns the details and noise in the training data to such an extent that it negatively impacts its performance on new data (test data or real-world data). 
    
    The model becomes too complex and overly sensitive to the training data, capturing both the signal (relevant patterns) and the noise (irrelevant details).

    Symptoms:
    High accuracy on training data but poor performance on test data.

    The model has memorized the training set rather than learning general patterns.

    Example: If a neural network is trained on a small dataset with many parameters, it might learn irrelevant details specific to that dataset, leading to poor generalization.

    Solution:
    Regularization techniques (e.g., L1/L2 regularization, dropout in neural networks).

    Cross-validation to monitor performance on validation data.

    Simplifying the model by reducing the number of parameters.
    
    More data to help the model learn more generalizable patterns.
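As a small illustration of the L2 regularization mentioned above, the sketch below fits a linear model with and without an L2 penalty. The data is synthetic and the helper name `fit_linear` and the penalty strength `l2=5.0` are assumptions made for this example, not part of any library.

```python
import numpy as np

def fit_linear(X, y, l2=0.0):
    """Closed-form least squares with an optional L2 penalty (ridge regression)."""
    n_features = X.shape[1]
    # Solve (X^T X + l2 * I) w = X^T y; l2 = 0 gives ordinary least squares.
    return np.linalg.solve(X.T @ X + l2 * np.eye(n_features), X.T @ y)

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 10))          # few samples, many features: easy to overfit
true_w = np.zeros(10)
true_w[:2] = [3.0, -2.0]               # only two features actually matter
y = X @ true_w + rng.normal(scale=0.5, size=20)

w_plain = fit_linear(X, y)             # no regularization
w_ridge = fit_linear(X, y, l2=5.0)     # the L2 penalty shrinks the weights

print("weight norm without L2:", np.linalg.norm(w_plain))
print("weight norm with L2:   ", np.linalg.norm(w_ridge))
```

The penalty pulls the weights toward zero, which limits how much the model can bend itself around noise in a small training set. How strong the penalty should be is normally chosen on validation data.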
 
## 2. Underfitting

    Underfitting occurs when a model is too simple to capture the underlying patterns in the data. 
    
    This can happen when the model has too few parameters or when it's not given enough training time. 
    As a result, the model performs poorly both on the training data and the test data.


    Symptoms:
    Low accuracy on both the training data and the test data.
    The model is unable to capture the important patterns in the data.
    
    Example: A linear model being used to fit data with a nonlinear relationship will result in underfitting because the model is too simplistic to represent the data's complexity.

    Solution:
    Increase the model complexity (e.g., using deeper neural networks or more features).

    Train the model longer to allow it to capture patterns in the data.
    
    Feature engineering to add more informative features.
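The linear-versus-nonlinear example above can be reproduced in a few lines. This is a minimal sketch on synthetic, noise-free data; the helper `fit_and_score` is a name chosen for this illustration.

```python
import numpy as np

x = np.linspace(-3.0, 3.0, 50)
y = x ** 2                                   # nonlinear (quadratic) target

def fit_and_score(degree):
    """Fit a polynomial of the given degree and return its training MSE."""
    coeffs = np.polyfit(x, y, deg=degree)
    predictions = np.polyval(coeffs, x)
    return float(np.mean((predictions - y) ** 2))

mse_linear = fit_and_score(1)      # too simple for this data: underfits
mse_quadratic = fit_and_score(2)   # enough capacity to represent the pattern

print("training MSE, linear model:   ", mse_linear)
print("training MSE, quadratic model:", mse_quadratic)
```

The straight line has a large error even on the data it was trained on, which is the signature of underfitting. Adding the squared term, a simple form of feature engineering, removes that error.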
 

## 3. Slow Convergence

    Slow convergence refers to the phenomenon where the training of a machine learning model progresses very slowly.
    
    It takes a long time for the model to reach an optimal solution or a satisfactory level of accuracy. 
    
    Convergence is typically related to how quickly the optimization algorithm (e.g., gradient descent) minimizes the loss function.


    Symptoms:
    The model’s performance improves very slowly during training.

    It takes many epochs (full passes over the training data) for the model to reach an acceptable level of accuracy or loss reduction.

    Example: When using a neural network, if the learning rate is too low or the initialization of parameters is poor, the model may take a long time to converge to a minimum.

    Solution:
    Increase the learning rate, but keep it low enough that training does not diverge.

    Use adaptive optimization algorithms (e.g., Adam, RMSprop) that adjust learning rates dynamically.

    Normalize the data to ensure all features are on a similar scale, which helps the model converge faster.

    Use batch normalization or momentum in gradient descent to speed up convergence.
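The effect of the learning rate can be seen on a toy problem. The sketch below runs plain gradient descent on f(w) = w², a function chosen only because its behaviour is easy to follow; the function name, tolerance, and divergence cutoff are assumptions of this example.

```python
def steps_to_converge(learning_rate, w0=1.0, tol=1e-6, max_steps=10_000):
    """Run gradient descent on f(w) = w**2 and report how it went.

    Returns (status, steps), where status is "converged", "diverged",
    or "max_steps".
    """
    w = w0
    for step in range(1, max_steps + 1):
        gradient = 2.0 * w              # derivative of w**2
        w -= learning_rate * gradient
        if abs(w) < tol:
            return "converged", step
        if abs(w) > 1e6:
            return "diverged", step
    return "max_steps", max_steps

for lr in (0.01, 0.1, 1.1):
    print(lr, steps_to_converge(lr))
```

A small learning rate reaches the minimum but needs many more steps than a moderate one, and a learning rate that is too large overshoots further on every step until the value blows up.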

# Techniques for Quantization, Pruning, and Compression

    Quantization, pruning, and compression techniques are essential strategies for optimizing AI models, particularly when deploying them in resource-constrained environments (such as mobile devices, edge computing, or IoT). 

    These techniques reduce the model’s size, computational requirements, and energy consumption, while maintaining (or minimally affecting) accuracy.

## 1. Quantization
    Quantization reduces the precision of the numbers representing model parameters (such as weights and activations) from 32-bit floating-point numbers to lower-bit representations (like 16-bit, 8-bit, or even lower). This leads to smaller model sizes and faster inference times.

    Techniques:
    Post-Training Quantization:

    Convert the model after it has been fully trained. It’s the most common approach because it doesn’t require changes to the training process.

    Types include:

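    Dynamic Range Quantization: Only the weights are stored as 8-bit integers (converted from 32-bit floating point). Activations are kept in floating point and quantized on the fly at inference time.

    Full Integer Quantization: Both weights and activations are quantized to 8-bit integers. This requires a representative dataset so that the ranges of the activations can be calibrated.
    
    Float16 Quantization: Weights are stored in float16 format (half precision), which roughly halves the model size and usually costs less accuracy than 8-bit quantization.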
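To make the idea of lower-precision weights concrete, here is a minimal NumPy sketch of symmetric 8-bit quantization of a single weight tensor. It illustrates the principle only and is not how TensorFlow Lite implements it internally; the function names are hypothetical, and the sketch assumes the tensor contains at least one non-zero value.

```python
import numpy as np

def quantize_int8(weights):
    """Symmetric per-tensor quantization of float32 weights to int8."""
    scale = float(np.max(np.abs(weights))) / 127.0   # float value of one int8 step
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map int8 codes back to approximate float32 values."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
weights = rng.normal(scale=0.1, size=1000).astype(np.float32)

q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

print("bytes as float32:", weights.nbytes)   # 4 bytes per weight
print("bytes as int8:   ", q.nbytes)         # 1 byte per weight
print("max abs error:   ", float(np.max(np.abs(weights - restored))))
```

Storage drops by a factor of four, and the rounding error per weight is bounded by half a quantization step (`scale / 2`).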


## Quantization, pruning, and compression

Quantization, pruning, and compression are popular techniques to optimize machine learning models for deployment in resource-constrained environments (like mobile or embedded devices). These techniques help reduce the model's size and improve inference speed while maintaining acceptable performance.

Here’s a step-by-step guide to applying these techniques using TensorFlow with Python:

# Step 1: Install TensorFlow and Other Necessary Libraries

    Make sure you have TensorFlow installed. If not, install it using:

    pip install tensorflow

# Step 2: Define and Train a Simple Model

For this example, we’ll use the MNIST dataset and build a basic neural network to classify handwritten digits.

In [2]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0  # Normalize pixel values


# Reshape data to add a channel dimension
x_train = x_train[..., np.newaxis]

x_test = x_test[..., np.newaxis]

# Build a simple model
model = models.Sequential([
    layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(x_train, y_train, epochs=3, validation_split=0.1)

Epoch 1/3
1688/1688 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.9060 - loss: 0.3153 - val_accuracy: 0.9828 - val_loss: 0.0605
Epoch 2/3
1688/1688 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - accuracy: 0.9824 - loss: 0.0552 - val_accuracy: 0.9832 - val_loss: 0.0611
Epoch 3/3
1688/1688 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.9894 - loss: 0.0323 - val_accuracy: 0.9865 - val_loss: 0.0552


<keras.src.callbacks.history.History at 0x17ae955d0>

# Step 3: Quantization

    Quantization reduces the precision of the model’s weights and activations, commonly from 32-bit floating point (float32) to 8-bit integer (int8), which results in a smaller model size and faster inference.
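A rough back-of-the-envelope calculation shows where the size reduction comes from. The sketch below counts the parameters of the Step 2 model by hand and compares 32-bit and 8-bit storage. It ignores file-format overhead and the extra scale values a quantized model has to store, so treat the numbers as estimates.

```python
# Parameter counts for the Step 2 model, layer by layer.
conv_params = (3 * 3 * 1) * 32 + 32          # 3x3 kernels, 1 input channel, 32 filters, plus biases
flat_units = 13 * 13 * 32                    # 28x28 input -> 26x26 after conv -> 13x13 after pooling
dense1_params = flat_units * 128 + 128       # Dense(128) weights plus biases
dense2_params = 128 * 10 + 10                # Dense(10) weights plus biases
total_params = conv_params + dense1_params + dense2_params

float32_bytes = total_params * 4             # 32 bits per value
int8_bytes = total_params * 1                # 8 bits per value

print("parameters:", total_params)
print("float32 size (MB):", round(float32_bytes / 1e6, 2))
print("int8 size (MB):   ", round(int8_bytes / 1e6, 2))
```

Almost all of the parameters sit in the first dense layer, so that is where quantization saves the most space in this model.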


## 3.1 Post-Training Quantization

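    TensorFlow provides several quantization options. Here, we’ll use post-training quantization with the converter's default optimization. Because no representative dataset is supplied, this performs dynamic range quantization: the weights are converted to 8-bit integers, while activations remain in floating point.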


In [16]:
#help(model)
#print(model.variables)
model.export("/Users/surendra/ai_embed/Machinelearnex/savemodel",format='tf_saved_model',verbose=True)

INFO:tensorflow:Assets written to: /Users/surendra/ai_embed/Machinelearnex/savemodel/assets



Saved artifact at '/Users/surendra/ai_embed/Machinelearnex/savemodel'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  6262029584: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6262025552: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6262025744: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6264948304: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6264947344: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6264947920: TensorSpec(shape=(), dtype=tf.resource, name=None)


In [17]:
# Example: TensorFlow Lite post-training quantization

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model('/Users/surendra/ai_embed/Machinelearnex/savemodel')

converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()



In [19]:
# Convert the model to TensorFlow Lite format with quantization

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]  # Apply default quantization

tflite_model_quantized = converter.convert()

# Save the quantized model
with open("model_quantized.tflite", "wb") as f:
    f.write(tflite_model_quantized)


# 2. Pruning

    Pruning involves removing unnecessary neurons or parameters from a model, reducing its size and computation complexity without significantly impacting its performance. 

    The idea is that not all parameters contribute equally to the model’s predictions, so redundant or less impactful ones can be removed.

    Techniques:

    Magnitude-based Pruning:
    
    Remove weights whose absolute value falls below a predefined threshold. This is the simplest and most common type of pruning.
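Magnitude-based pruning can be sketched directly in NumPy. In this illustration the threshold is derived from a target sparsity rather than fixed in advance, and the function name `prune_by_magnitude` is hypothetical. The toolkit used later in this guide applies the same idea gradually during training.

```python
import numpy as np

def prune_by_magnitude(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    k = int(round(sparsity * weights.size))   # number of weights to remove
    if k == 0:
        return weights.copy()
    # The k-th smallest absolute value becomes the pruning threshold.
    threshold = np.sort(np.abs(weights), axis=None)[k - 1]
    mask = np.abs(weights) > threshold
    return weights * mask

rng = np.random.default_rng(0)
weights = rng.normal(size=(10, 10))
pruned = prune_by_magnitude(weights, sparsity=0.8)

print("zeros after pruning:", int((pruned == 0).sum()), "of", pruned.size)
```

The pruned matrix has the same shape as the original, so it takes the same amount of memory until it is compressed or stored in a sparse format.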

In [26]:
pip install tf_keras

Collecting tf_keras
  Downloading tf_keras-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Downloading tf_keras-2.18.0-py3-none-any.whl (1.7 MB)
Installing collected packages: tf_keras
Successfully installed tf_keras-2.18.0
Note: you may need to restart the kernel to use updated packages.


In [24]:
 pip install tensorflow_model_optimization

Note: you may need to restart the kernel to use updated packages.


# Pruning in Keras example

    Train a Keras model for MNIST from scratch.

    Fine tune the model by applying the pruning API and see the accuracy.

    Create 3x smaller TF and TFLite models from pruning.

    Create a 10x smaller TFLite model from combining pruning and post-training quantization.
    
    See the persistence of accuracy from TF to TFLite.

In [3]:
! pip install -q tensorflow-model-optimization

In [4]:
import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow_model_optimization.python.core.keras.compat import keras

%load_ext tensorboard

# Train a model for MNIST without pruning

In [5]:
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<tf_keras.src.callbacks.History at 0x14a63d650>

# Evaluate baseline test accuracy and save the model for later usage.

In [6]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)

Baseline test accuracy: 0.9781000018119812
Saved baseline model to: /var/folders/j9/t7f_l7rd20101rcpcynrvqyh0000gn/T/tmpse81zmjp.h5




# Fine-tune pre-trained model with pruning

    Define the model

    You will apply pruning to the whole model and see this in the model summary.

    In this example, you start the model with 50% sparsity (50% zeros in weights) and end with 80% sparsity.
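The ramp from 50% to 80% sparsity can be sketched as a plain function of the training step. The formula below is a polynomial ramp with exponent 3, which is my understanding of the default used by the toolkit's PolynomialDecay schedule; treat the exponent and the function name `polynomial_sparsity` as assumptions of this illustration. The real schedule also updates the pruning masks only every so many steps rather than on every step.

```python
import math

def polynomial_sparsity(step, initial_sparsity, final_sparsity,
                        begin_step, end_step, power=3):
    """Target sparsity at `step` for a polynomial ramp from initial to final."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))   # clamp outside the pruning window
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

# 60,000 MNIST training images, 10% held out, batch size 128, 2 epochs.
steps_per_epoch = math.ceil(60000 * (1 - 0.1) / 128)
end_step = steps_per_epoch * 2

for step in (0, end_step // 2, end_step):
    print(step, round(polynomial_sparsity(step, 0.50, 0.80, 0, end_step), 3))
```

With this shape, sparsity rises quickly at the start, while the network still has many redundant weights, and levels off as it approaches the final target.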

In [7]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                        

# Train and evaluate the model against baseline

## Fine tune with pruning for two epochs.

    tfmot.sparsity.keras.UpdatePruningStep is required during training, and 
    tfmot.sparsity.keras.PruningSummaries provides logs for tracking progress and debugging.

In [9]:
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

# For this example, there is minimal loss in test accuracy after pruning, compared to the baseline.


_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

Epoch 1/2
Epoch 2/2
Baseline test accuracy: 0.9781000018119812
Pruned test accuracy: 0.9685999751091003


### The logs show the progression of sparsity on a per-layer basis.

In [None]:
%tensorboard --logdir={logdir}

# Create 3x smaller models from pruning

    Both tfmot.sparsity.keras.strip_pruning and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of pruning.

    strip_pruning is necessary because it removes every tf.Variable that pruning needs only during training and that would otherwise add to the model size during inference.

    Applying a standard compression algorithm is necessary since the serialized weight matrices are the same size as they were before pruning. However, pruning makes most of the weights zeros, which is added redundancy that algorithms can utilize to further compress the model.
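The point about redundancy can be checked without any model at all. The sketch below serializes two synthetic weight lists as float32, one dense and one with a random 80% of its values set to zero, and compresses both with DEFLATE, the algorithm behind gzip and zip. The random zeroing is only a stand-in for real pruning, which removes the smallest-magnitude weights.

```python
import random
import struct
import zlib

random.seed(0)
n = 10_000
dense = [random.gauss(0.0, 1.0) for _ in range(n)]
# Stand-in for pruning: zero out a random 80% of the weights.
sparse = [w if random.random() >= 0.8 else 0.0 for w in dense]

def compressed_size(values):
    """Pack values as float32 and return (raw_bytes, deflate_compressed_bytes)."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(raw), len(zlib.compress(raw))

dense_raw, dense_zipped = compressed_size(dense)
sparse_raw, sparse_zipped = compressed_size(sparse)

print("dense :", dense_raw, "->", dense_zipped)
print("sparse:", sparse_raw, "->", sparse_zipped)
```

Both lists take the same number of bytes before compression, but the long runs of zeros in the sparse one compress far better.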

    First, create a compressible model for TensorFlow.

In [11]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

Saved pruned Keras model to: /var/folders/j9/t7f_l7rd20101rcpcynrvqyh0000gn/T/tmpjojz6g0h.h5




# Create a compressible model for TFLite.

In [12]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)

Saved pruned TFLite model to: /var/folders/j9/t7f_l7rd20101rcpcynrvqyh0000gn/T/tmpgpi4fy_k.tflite



# Define a helper function to actually compress the models via gzip and measure the zipped size.

In [13]:
def get_gzipped_model_size(file):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

# Compare and see that the models are 3x smaller from pruning.

In [14]:
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))

Size of gzipped baseline Keras model: 78292.00 bytes
Size of gzipped pruned Keras model: 25819.00 bytes
Size of gzipped pruned TFlite model: 24967.00 bytes


# Create a 10x smaller model from combining pruning and quantization

    You can apply post-training quantization to the pruned model for additional benefits.

In [15]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)

print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))

Saved quantized and pruned TFLite model to: /var/folders/j9/t7f_l7rd20101rcpcynrvqyh0000gn/T/tmpx3vplcmm.tflite
Size of gzipped baseline Keras model: 78292.00 bytes
Size of gzipped pruned and quantized TFlite model: 8685.00 bytes
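The "3x" and "10x" figures in the headings are rounded. The short calculation below derives the actual ratios from the gzipped sizes reported in the cell outputs of this run.

```python
# Gzipped sizes in bytes, as reported in the cell outputs above.
sizes = {
    "baseline Keras": 78292,
    "pruned Keras": 25819,
    "pruned TFLite": 24967,
    "pruned + quantized TFLite": 8685,
}

baseline = sizes["baseline Keras"]
ratios = {name: baseline / size for name, size in sizes.items()}

for name, ratio in ratios.items():
    if name != "baseline Keras":
        print(f"{name}: {ratio:.2f}x smaller than baseline")
```

For this run, pruning alone gives roughly a 3x reduction, and pruning plus quantization gives roughly 9x.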



# See persistence of accuracy from TF to TFLite

    Define a helper function to evaluate the TF Lite model on the test dataset.

In [16]:
import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))

    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.

    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy


Evaluating the pruned and quantized model shows that the accuracy from TensorFlow carries over to the TFLite backend.

In [17]:
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy: 0.9688
Pruned TF test accuracy: 0.9685999751091003


INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


# Conclusion

    We saw how to create sparse models with the TensorFlow Model Optimization Toolkit API for both TensorFlow and TFLite. We then combined pruning with post-training quantization for additional benefits.

    We created a model for MNIST that is about 9x smaller after compression (78,292 bytes down to 8,685 bytes), while test accuracy dropped from about 97.8% to about 96.9%.

    We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.
