##### Copyright 2021 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Pruning for on-device inference w/ XNNPACK

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/model_optimization/guide/pruning/pruning_for_on_device_inference"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_for_on_device_inference.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_for_on_device_inference.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/model-optimization/tensorflow_model_optimization/g3doc/guide/pruning/pruning_for_on_device_inference.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

Welcome to the guide on pruning Keras model weights to improve the latency of on-device inference via [XNNPACK](https://github.com/google/XNNPACK).

This guide introduces the new `tfmot.sparsity.keras.PruningPolicy` API and demonstrates how to use it to accelerate mostly convolutional models on modern CPUs using [XNNPACK Sparse inference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#sparse-inference).

The guide covers the following steps of the model creation process:
* Build and train the dense baseline
* Fine-tune model with pruning
* Convert to TFLite
* On-device benchmark

The guide doesn't cover best practices for fine-tuning with pruning. For more detailed information on this topic, please check out our [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md).

## Setup

In [None]:
! pip install -q tensorflow
! pip install -q tensorflow-model-optimization

In [None]:
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

%load_ext tensorboard

## Build and train the dense model

We build and train a simple baseline CNN for a classification task on the [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.

In [None]:
# Load CIFAR10 dataset.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'cifar10',
    split=['train[:90%]', 'train[90%:]', 'test'],
    as_supervised=True,
    with_info=True,
)

# Normalize the input image so that each pixel value is between 0 and 1.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.image.convert_image_dtype(image, tf.float32), label

# Load the data in batches of 128 images.
batch_size = 128
def prepare_dataset(ds, buffer_size=None):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  if buffer_size:
    ds = ds.shuffle(buffer_size)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds

ds_train = prepare_dataset(ds_train,
                           buffer_size=ds_info.splits['train'].num_examples)
ds_val = prepare_dataset(ds_val)
ds_test = prepare_dataset(ds_test)

# Build the dense baseline model.
dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Compile and train the dense model for 10 epochs.
dense_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

dense_model.fit(
  ds_train,
  epochs=10,
  validation_data=ds_val)

# Evaluate the dense model.
_, dense_model_accuracy = dense_model.evaluate(ds_test, verbose=0)

## Build the sparse model

Using the instructions from the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md), we apply the `tfmot.sparsity.keras.prune_low_magnitude` function with parameters that target on-device acceleration via pruning, i.e. the `tfmot.sparsity.keras.PruneForLatencyOnXNNPack` policy.
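To make the idea behind magnitude-based pruning concrete, here is a minimal NumPy sketch that zeroes out the smallest-magnitude entries of a weight tensor. It is an illustration only, not the `tfmot` implementation: the helper name `magnitude_prune` and the example kernel are hypothetical, and the real wrapper applies masks gradually during training and lets the pruning policy decide which layers are eligible.

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Returns a copy of `weights` with the smallest-magnitude entries zeroed.

    Hypothetical helper for illustration; not part of tfmot.
    """
    weights = np.asarray(weights, dtype=np.float32)
    num_to_prune = int(round(sparsity * weights.size))
    pruned = weights.ravel().copy()
    if num_to_prune > 0:
        # Indices of the `num_to_prune` entries with the smallest absolute value.
        prune_idx = np.argpartition(np.abs(pruned), num_to_prune - 1)[:num_to_prune]
        pruned[prune_idx] = 0.0
    return pruned.reshape(weights.shape)


# A made-up 2x4 kernel; half of its entries are removed at 50% sparsity.
kernel = np.array([[0.9, -0.05, 0.3, 0.01],
                   [-0.7, 0.2, -0.02, 0.6]])
pruned_kernel = magnitude_prune(kernel, sparsity=0.5)
print(pruned_kernel)
```

Only the four largest-magnitude entries (0.9, 0.3, -0.7 and 0.6) survive; the other four are set to zero.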

In [None]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 5 epochs.
end_epoch = 5

num_iterations_per_epoch = len(ds_train)
end_step = num_iterations_per_epoch * end_epoch

# Define parameters for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.25,
                                                               final_sparsity=0.75,
                                                               begin_step=0,
                                                               end_step=end_step),
      'pruning_policy': tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
}

# Try to apply pruning wrapper with pruning policy parameter.
try:
  model_for_pruning = prune_low_magnitude(dense_model, **pruning_params)
except ValueError as e:
  print(e)

The call to `prune_low_magnitude` raises a `ValueError` with the message `Could not find a GlobalAveragePooling2D layer with keepdims = True in all output branches`. The message indicates that the model isn't supported for pruning with the `tfmot.sparsity.keras.PruneForLatencyOnXNNPack` policy; specifically, the `GlobalAveragePooling2D` layer requires the parameter `keepdims=True`. Let's fix that and reapply the `prune_low_magnitude` function.
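As a rough illustration of what `keepdims` changes, the NumPy sketch below averages a feature map over its spatial axes with and without keeping the reduced dimensions. It mimics global average pooling rather than calling Keras, and the batch of 8x8x32 feature maps is a hypothetical stand-in for the output of the last convolutional block.

```python
import numpy as np

# Hypothetical batch of 4 feature maps in NHWC layout: 8x8 spatial, 32 channels.
features = np.ones((4, 8, 8, 32), dtype=np.float32)

# Averaging over height and width, analogous to GlobalAveragePooling2D.
pooled_default = features.mean(axis=(1, 2))                  # like keepdims=False
pooled_keepdims = features.mean(axis=(1, 2), keepdims=True)  # like keepdims=True

print(pooled_default.shape)   # (4, 32)
print(pooled_keepdims.shape)  # (4, 1, 1, 32)
```

With `keepdims=True` the output stays a 4-D tensor, and a following `Flatten` layer reduces it back to one vector per example, so the `Dense` layer sees the same input size in both variants.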

In [None]:
fixed_dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(keepdims=True),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Use the pretrained model for pruning instead of training from scratch.
fixed_dense_model.set_weights(dense_model.get_weights())

# Try to reapply pruning wrapper.
model_for_pruning = prune_low_magnitude(fixed_dense_model, **pruning_params)

The invocation of `prune_low_magnitude` has finished without any errors, meaning that the model is fully supported for the `tfmot.sparsity.keras.PruneForLatencyOnXNNPack` policy and can be accelerated using [XNNPACK Sparse inference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#sparse-inference).

### Fine-tune the sparse model

Following the [pruning example](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras.md), we fine-tune the sparse model using the weights of the dense model. We start fine-tuning of the model with 25% sparsity (25% of the weights are set to zero) and end with 75% sparsity.
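To give a feel for how the target sparsity ramps from 25% to 75%, here is a plain-Python sketch of a polynomial decay schedule. It is an approximation under stated assumptions, not the library code: the function name `polynomial_sparsity` is hypothetical, the cubic exponent is assumed to match the default of `PolynomialDecay`, and `end_step=1760` assumes 45,000 training images in batches of 128 (352 steps per epoch) for 5 epochs. The library also updates the masks only at a fixed step frequency, which this sketch ignores.

```python
def polynomial_sparsity(step, initial_sparsity=0.25, final_sparsity=0.75,
                        begin_step=0, end_step=1760, power=3):
    """Target sparsity at a given training step (illustrative sketch)."""
    # Fraction of the pruning window that has elapsed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    # Sparsity rises quickly at first and flattens out near the end step.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


for step in (0, 440, 880, 1320, 1760):
    print(step, round(polynomial_sparsity(step), 4))
```

Under these assumptions the target is already at about 69% halfway through the pruning window. Since pruning ends after epoch 5 while fine-tuning runs for 15 epochs, the remaining epochs train the model at the final 75% sparsity.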

In [None]:
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

model_for_pruning.fit(
  ds_train,
  epochs=15,
  validation_data=ds_val,
  callbacks=callbacks)

# Evaluate the pruned model.
_, pruned_model_accuracy = model_for_pruning.evaluate(ds_test, verbose=0)

print('Dense model test accuracy:', dense_model_accuracy)
print('Pruned model test accuracy:', pruned_model_accuracy)

The logs show the progression of sparsity on a per-layer basis.

In [None]:
#docs_infra: no_execute
%tensorboard --logdir={logdir}

After fine-tuning with pruning, the test accuracy is slightly higher than that of the dense model (44% versus 43%). Let's compare on-device latency using the [TFLite benchmark](https://www.tensorflow.org/lite/performance/measurement).

## Model conversion and benchmarking

To convert the pruned model into TFLite, we need to replace the `PruneLowMagnitude` wrappers with the original layers via the `strip_pruning` function. Also, since the weights of the pruned model (`model_for_pruning`) are mostly zeros, we may apply the `tf.lite.Optimize.EXPERIMENTAL_SPARSITY` optimization to efficiently store the resulting TFLite model. This optimization flag is not required for the dense model.
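Before converting, it can be useful to check how many weights are actually zero. The sketch below computes the fraction of zeros for each tensor in a name-to-array mapping; the helper names and the toy tensors are hypothetical, chosen only to keep the example self-contained.

```python
import numpy as np


def zero_fraction(array):
    """Fraction of entries in `array` that are exactly zero."""
    array = np.asarray(array)
    return 1.0 - np.count_nonzero(array) / array.size


def sparsity_report(named_weights):
    """Maps each tensor name to its fraction of zero entries."""
    return {name: zero_fraction(w) for name, w in named_weights.items()}


# Toy tensors standing in for real model weights (names are made up).
example_weights = {
    'conv_1x1/kernel': np.array([[0.0, 0.4], [0.0, 0.0]]),
    'dense/kernel': np.array([[0.1, -0.2], [0.3, 0.5]]),
}
report = sparsity_report(example_weights)
for name, fraction in report.items():
    print(f'{name}: {fraction:.0%} zeros')
```

With a real model, the mapping could be built from the stripped model, for example `{w.name: w.numpy() for w in model_for_export.weights}`. Not every tensor is expected to be sparse, since the policy decides which layers are pruned.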

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(dense_model)
dense_tflite_model = converter.convert()

# Write the dense TFLite model to a temporary file that is kept on disk.
with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as f:
  f.write(dense_tflite_model)
  dense_tflite_file = f.name

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.EXPERIMENTAL_SPARSITY]
pruned_tflite_model = converter.convert()

# Write the pruned TFLite model to a temporary file that is kept on disk.
with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as f:
  f.write(pruned_tflite_model)
  pruned_tflite_file = f.name

Following the instructions of [TFLite Model Benchmarking Tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark), we build the tool, upload it to the Android device together with dense and pruned TFLite models, and benchmark both models on the device.

In [None]:
! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/dense_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1

In [None]:
! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/pruned_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1

Benchmarks on a Pixel 4 resulted in an average inference time of *17us* for the dense model and *12us* for the pruned model. The on-device benchmarks demonstrate a clear latency improvement of **5us**, or roughly **30%**, even for such small models. In our experience, larger models based on [MobileNetV3](https://www.tensorflow.org/api_docs/python/tf/keras/applications/mobilenet_v3) or [EfficientNet-lite](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) show similar performance improvements. The speed-up varies based on the relative contribution of 1x1 convolutions to the overall model.
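The arithmetic behind these figures can be reproduced in a few lines. The snippet uses only the two latencies quoted above; the variable names are chosen for illustration.

```python
# Average inference times reported for the Pixel 4, in microseconds.
dense_latency_us = 17.0
pruned_latency_us = 12.0

saved_us = dense_latency_us - pruned_latency_us
relative_reduction = saved_us / dense_latency_us   # share of dense latency removed
speedup = dense_latency_us / pruned_latency_us     # how many times faster

print(f'{saved_us:.0f}us saved, {relative_reduction:.1%} lower latency, '
      f'{speedup:.2f}x speed-up')
# prints "5us saved, 29.4% lower latency, 1.42x speed-up"
```

The reduction is measured relative to the dense model's latency, which is why 5us out of 17us rounds to about 30%.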


## Conclusion

In this tutorial, we showed how to create sparse models for faster on-device performance using the new functionality introduced by the TF MOT API and XNNPACK. These sparse models are smaller and faster than their dense counterparts while retaining or even surpassing their quality.

We encourage you to try this new capability, which can be particularly important for deploying your models on device.