Welcome to the comprehensive guide for Keras weight pruning. 

Use this page to quickly find the APIs you need for your use case via the navigation sidebar. Once you know which APIs you need, find the parameters and the low-level details in the
[API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity). 

*  If you want to see the benefits of pruning and what's supported, see the [pruning overview](https://www.tensorflow.org/model_optimization/guide/pruning). 
*  For a single end-to-end example, see the [pruning example](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras).

The following corresponds to the navigation sidebar:
* If you don't have a pruned model yet, you must **define** and **train** one.
* For the Keras HDF5 model format only, you need special **checkpointing and deserialization** code. Checkpointing cannot be done with Keras HDF5 weights.
* For **deployment**, you must take extra steps to see the compression benefits.

For configuration of the pruning algorithm, refer to the [prune_low_magnitude
API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude).

Run the boilerplate code below once to start.

# Boilerplate: run once per Colab session

In [0]:
# Run this section once per Colab session.

! pip uninstall -y tensorflow
! pip install -q tensorflow==2.1.0
! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

%load_ext tensorboard
import tensorboard

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
# One random integer class label in [0, 20), one-hot encoded.
y_train = tf.keras.utils.to_categorical(np.random.randint(20, size=1), num_classes=20)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.h5')

  model.save_weights(pretrained_weights)

  return pretrained_weights
  
def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5') 
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip') 
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

pretrained_weights = setup_pretrained_weights()

Uninstalling tensorflow-1.15.0:
  Successfully uninstalled tensorflow-1.15.0
Train on 1 samples


# Define model

## Prune all layers in Functional and Sequential models

**Tips** for better model accuracy:

* Try "Prune some layers" on the navigation sidebar to skip pruning the layers that affect accuracy the most.
* It's generally better to start from pre-trained weights.

**More**: the [`prune_low_magnitude`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) API docs provide details on configuring
the pruning algorithm.


In [0]:
model = setup_model()
model.load_weights(pretrained_weights) # optional but recommended.

pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model)

pruned_model.summary()
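Conceptually, each layer wrapped by `prune_low_magnitude` keeps a binary mask over its kernel and zeros out the weights with the smallest absolute values until the schedule's target sparsity is reached. The NumPy sketch below illustrates only that masking idea; `magnitude_prune` is a hypothetical helper, not a tfmot API, it ignores schedules and block sizes, and the library's actual implementation may differ.

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Hypothetical helper (not a tfmot API): zero out the `sparsity`
    fraction of entries with the smallest absolute value."""
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    k = int(round(sparsity * weights.size))  # number of weights to remove
    mask = np.ones(weights.shape, dtype=weights.dtype)
    if k > 0:
        # Flat indices of the k smallest-magnitude weights.
        smallest = np.argsort(np.abs(weights), axis=None)[:k]
        mask.flat[smallest] = 0
    return weights * mask, mask


rng = np.random.default_rng(0)
kernel = rng.standard_normal((20, 20)).astype(np.float32)  # like Dense(20) on 20 inputs
pruned, mask = magnitude_prune(kernel, sparsity=0.5)
print("fraction of zero weights:", np.mean(pruned == 0))
```

Every surviving weight is at least as large in magnitude as every pruned one, which is what "low magnitude" refers to.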

## Prune some layers in Functional and Sequential models

**Tips** for better model accuracy:

* It's generally better to start from pre-trained weights.
* Try pruning the later layers instead of the first layers.
* Avoid pruning critical layers (e.g. attention mechanism). 

**More**: the [`prune_low_magnitude`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) API docs provide details on how to vary the pruning configuration per layer.

In [0]:
model = setup_model()
model.load_weights(pretrained_weights) # optional but recommended

## Version 1: Prune all dense layers.
def apply_pruning_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer
    
pruned_dense_layers_model = tf.keras.models.clone_model(
    model, 
    clone_function=apply_pruning_to_dense,
)

print("pruned_dense_layers")
pruned_dense_layers_model.summary()

## Version 2: Prune the first layer, achieving the same result.
def layers_to_prune():
  # Knowing that the first layer is the Dense layer.
  return {model.layers[0]: 'default'}

def apply_pruning_to_first(layer):
  if layer in layers_to_prune():
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

pruned_first_layer_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_pruning_to_first,
)

print("\n")
print("pruned_first_layer")
pruned_first_layer_model.summary()

pruned_dense_layers
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_6  (None, 20)                822       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________


pruned_first_layer
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_6  (None, 20)                822       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________
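The parameter counts in both summaries can be reconciled by hand. Assuming the pruning wrapper adds one mask with the kernel's shape plus two scalar variables (a threshold and a pruning-step counter) to the `Dense` layer, and that the bias is not pruned, the arithmetic works out as follows:

```python
# Dense(20) applied to 20 input features, as in setup_model().
in_features, units = 20, 20

kernel = in_features * units   # trainable weight matrix
bias = units                   # trainable bias vector
mask = kernel                  # non-trainable mask, same shape as the kernel (assumed)
threshold = 1                  # non-trainable scalar (assumed)
pruning_step = 1               # non-trainable scalar (assumed)

trainable = kernel + bias
non_trainable = mask + threshold + pruning_step
total = trainable + non_trainable
print(trainable, non_trainable, total)  # 420 402 822
```

These extra non-trainable variables exist only while training; `strip_pruning` removes the wrappers before deployment.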

### More readable but potentially less accurate

This approach doesn't support starting from pre-trained weights, which is why it may be less accurate.

**More**: the [`prune_low_magnitude`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) API docs provide details on how to vary the pruning configuration per layer.

Functional example

In [0]:
i = tf.keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)

pruned_model = tf.keras.Model(inputs=i, outputs=o)

pruned_model.summary()

Sequential example

In [0]:
pruned_model = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

pruned_model.summary()

## Prune layers in Subclassed model

**Note**: using pre-trained weights is not supported yet.

**Tips** for better model accuracy:
* Try pruning the later layers instead of the first layers.
* Avoid pruning critical layers (e.g. attention mechanism). 

**More**: the [`prune_low_magnitude`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) API docs provide details on how to vary the pruning configuration per layer.

In [0]:
class MyPrunedModel(tf.keras.Model):
  def __init__(self):
    super(MyPrunedModel, self).__init__()
    self.dense = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))
    self.flatten = tf.keras.layers.Flatten()
    self.dense2 = tfmot.sparsity.keras.prune_low_magnitude(
        tf.keras.Sequential([tf.keras.layers.Dense(10)])
    )

  def call(self, inputs):
    x = self.dense(inputs)
    x = self.flatten(x)
    return self.dense2(x)

pruned_model = MyPrunedModel()

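# Pass the shape directly instead of reassigning the global `input_shape`,
# which setup_model() still relies on in later cells.
pruned_model.build(input_shape=(None, 20))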

pruned_model.summary()

## Prune custom Keras layer or prune different weights from API default 

**Common mistake:** pruning the bias usually harms model accuracy too much.


In [0]:
class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

  def get_prunable_weights(self):
    # Prune bias also, though that usually harms model accuracy too much.
    return [self.kernel, self.bias]

class MyDenseLayer2(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

  def get_prunable_weights(self):
    # Prune nothing.
    return []

# Train model

## Model.fit

In [0]:
# See "Define model" on the navigation sidebar for other ways
# to prune this model.
model = setup_model()
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model)

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]

pruned_model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
)

pruned_model.fit(
    x_train,
    y_train,
    callbacks=callbacks
)

%tensorboard --logdir={log_dir}


## Custom training loop

In [0]:
# See "Define model" on the navigation sidebar for other ways
# to prune this model.
model = setup_model()
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model)

# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 1
batches = 1 # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate.
# UpdatePruningStep reads the current step from the model's optimizer, so attach it.
pruned_model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(pruned_model)

step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
  for _ in range(batches):
    step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback

    with tf.GradientTape() as tape:
      logits = pruned_model(x_train, training=True)
      loss_value = loss(y_train, logits)
    # Compute and apply gradients outside the tape context.
    grads = tape.gradient(loss_value, pruned_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))

  step_callback.on_epoch_end(batch=unused_arg) # run pruning callback
  log_callback.on_epoch_end(batch=unused_arg) # run pruning summaries callback


%tensorboard --logdir={log_dir}

## Improve pruned model accuracy


First, look at the [prune_low_magnitude
API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude)
to understand what a pruning schedule is and the math of
each type of pruning schedule.

**Tips**: 

* Use a learning rate that's neither too high nor too low while the model is pruning. Consider the [pruning schedule](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/PruningSchedule) to be a hyperparameter.

* As a quick test, try pruning a model straight to the final sparsity, starting at step 0, with a
[ConstantSparsity](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/ConstantSparsity) schedule. You might get lucky with good results.

* Don't prune too frequently, so that the model has time to recover between pruning steps. The [pruning schedule](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/PruningSchedule) provides a decent default frequency.

* For general ideas to improve model accuracy,
find your use case(s) under "Define model" on the navigation sidebar and see if there are tips.
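To build intuition for treating the schedule as a hyperparameter, the pure-Python sketch below computes the target sparsity at a given step for a polynomial-decay schedule and for a constant schedule, plus a check for whether masks are updated at that step. It is based on our reading of the `PolynomialDecay` and `ConstantSparsity` docs; the function names are hypothetical, and the library's exact behavior (for example at the boundaries) may differ.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity that ramps from initial to final between the two steps."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


def constant_sparsity(step, target_sparsity, begin_step=0):
    """Jump straight to the target once begin_step is reached (0 before that)."""
    return target_sparsity if step >= begin_step else 0.0


def is_pruning_step(step, begin_step, end_step, frequency):
    """Masks change only every `frequency` steps inside [begin_step, end_step];
    end_step == -1 means no end."""
    in_range = begin_step <= step and (end_step == -1 or step <= end_step)
    return in_range and (step - begin_step) % frequency == 0


for step in (0, 500, 1000, 1500):
    print(step, round(polynomial_decay_sparsity(step, 0.0, 0.8, 0, 1000), 3))
```

With `power=3` most of the sparsity is added early, then the ramp flattens out near `end_step`. Because the target depends on the step, restarting from step 0 would restart the schedule, which illustrates why checkpointing needs to preserve the step.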

# Checkpointing and Deserialization

**Your Use Case:**

* You cannot do checkpointing with Keras HDF5 weights since we need to preserve the step.

* This code is only needed for the HDF5 model format (not HDF5 weights or other formats).

In [0]:
# See "Define model" on the navigation sidebar for other ways
# to prune this model.
model = setup_model()
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model)

_, keras_model_file = tempfile.mkstemp('.h5')
# Saving the optimizer is necessary for checkpointing (True is the default).
pruned_model.save(keras_model_file, include_optimizer=True)

with tfmot.sparsity.keras.prune_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()

# Deployment

## Export model with size compression

**Common mistake**: both `strip_pruning` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression
benefits of pruning.

In [0]:
# See "Define model" and "Train model" on navigation sidebar for how to define
# and train this model in other ways.
model = setup_model()
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model)

pruned_model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
)

pruned_model.fit(
    x_train,
    y_train,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()]
)

final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print("final model")
final_model.summary()

print("\n")
print("Size of gzipped pruned model without stripping: %d bytes" % get_gzipped_model_size(pruned_model))
print("Size of gzipped pruned model with stripping: %d bytes" % get_gzipped_model_size(final_model))
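Why compression is needed can be seen without TensorFlow: zeroed weights still occupy the same number of bytes in an uncompressed file, but a general-purpose compressor such as gzip can exploit them. The NumPy sketch below uses a random matrix as a stand-in for a pruned kernel; `gzipped_size` is a hypothetical helper, and exact sizes depend on the data.

```python
import gzip

import numpy as np


def gzipped_size(array):
    """Size in bytes of the array's raw buffer after gzip compression."""
    return len(gzip.compress(array.tobytes()))


rng = np.random.default_rng(0)
dense = rng.standard_normal((256, 256)).astype(np.float32)

# Zero roughly the 80% of weights with the smallest magnitude.
sparse = dense.copy()
cutoff = np.quantile(np.abs(sparse), 0.8)
sparse[np.abs(sparse) < cutoff] = 0.0

print("raw bytes:    ", dense.nbytes, sparse.nbytes)  # identical
print("gzipped bytes:", gzipped_size(dense), gzipped_size(sparse))
```

The raw sizes are identical; only the compressed size of the sparse array shrinks substantially.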

## Hardware-specific optimizations

Once the framework [enables pruning to improve latency](https://github.com/tensorflow/model-optimization/issues/173), using block sparsity can improve latency on certain hardware. For a target model accuracy, latency can still improve even though increasing the block size lowers the peak sparsity (%) that can be reached.

In [0]:
model = setup_model()

# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

pruned_model.summary()
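To see what `block_size` changes, the NumPy sketch below prunes whole blocks instead of individual weights: it averages |w| over each block (as the `'AVG'` option of `prune_low_magnitude`'s `block_pooling_type` argument suggests) and zeros the lowest-scoring blocks. `block_prune` is a hypothetical helper that requires the weight dimensions to be divisible by the block shape; the library's actual implementation may differ.

```python
import numpy as np


def block_prune(weights, block_shape, sparsity):
    """Hypothetical helper: zero the `sparsity` fraction of blocks with the
    lowest average magnitude. Returns (pruned_weights, mask)."""
    rows, cols = weights.shape
    br, bc = block_shape
    if rows % br or cols % bc:
        raise ValueError("this sketch requires dimensions divisible by block_shape")
    # View as (block_rows, br, block_cols, bc) and average |w| within each block.
    blocks = np.abs(weights).reshape(rows // br, br, cols // bc, bc)
    block_scores = blocks.mean(axis=(1, 3))
    k = int(round(sparsity * block_scores.size))  # number of blocks to remove
    block_mask = np.ones(block_scores.shape, dtype=weights.dtype)
    if k > 0:
        smallest = np.argsort(block_scores, axis=None)[:k]
        block_mask.flat[smallest] = 0
    # Expand the per-block mask back to the full weight shape.
    mask = np.repeat(np.repeat(block_mask, br, axis=0), bc, axis=1)
    return weights * mask, mask


rng = np.random.default_rng(0)
kernel = rng.standard_normal((32, 64)).astype(np.float32)
pruned, mask = block_prune(kernel, block_shape=(1, 16), sparsity=0.5)
print("fraction of zero weights:", np.mean(pruned == 0))
```

Each 1x16 block ends up entirely zero or entirely kept, which is the structure that register-sized kernels can skip over; the trade-off is that whole blocks are removed even when some of their weights are large.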