Welcome to the comprehensive guide for Keras quantization-aware training.

Use this page to quickly find the APIs you need for your use case via the navigation sidebar. Once you know which APIs you need, find the parameters and the low-level details in the
[API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization).

*  If you want to see the benefits of quantization-aware training and what's supported, see the [overview](https://www.tensorflow.org/model_optimization/guide/quantization/training.md). 
*  For a single end-to-end example, see the [quantization-aware training example](https://www.tensorflow.org/model_optimization/guide/quantization/quantization_aware_training_guide.md).

The following corresponds to the navigation sidebar:

You will either want to **deploy with quantization** or **research quantization**.
* In both cases, you must first **define** a quantization-aware model. Training
  the model is standard Keras.
* For Keras HDF5 models only, you need special **checkpointing and deserialization**.

Run the boilerplate code below once to start.

# Boilerplate: run once per Colab session

In [0]:
# Run this section once per Colab session.

! pip uninstall -y tensorflow
! pip install -q tf-nightly==2.2.0.dev20200305
! pip install -q --extra-index-url=https://test.pypi.org/simple/ tensorflow-model-optimization==0.3.0.dev3

import tempfile

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

# Tiny synthetic dataset: one 20-feature example, just enough to run `fit`
# and produce "pre-trained" weights for the examples below.
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
# Draw a valid integer class id in [0, 20) before one-hot encoding; a
# `np.random.randn` draw is a float that may be negative, not a class id.
y_train = tf.keras.utils.to_categorical(
    np.random.randint(0, 20, size=1), num_classes=20)

def setup_model():
  """Build the small float Keras model used throughout this guide.

  The model is a Dense(20) layer taking `input_shape` inputs, followed by a
  Flatten layer.
  """
  layer_stack = [
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten(),
  ]
  return tf.keras.Sequential(layer_stack)

def setup_pretrained_weights():
  """Briefly train a fresh model and save its weights to a temporary .h5 file.

  Returns:
    Path of the HDF5 weights file, suitable for `model.load_weights`.
  """
  import os  # Only needed to close the descriptor returned by mkstemp.

  model = setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  # mkstemp returns an already-open OS file descriptor as well as the path.
  # Keras reopens the file by path, so close the descriptor instead of
  # leaking one on every call.
  fd, pretrained_weights = tempfile.mkstemp('.h5')
  os.close(fd)

  model.save_weights(pretrained_weights)

  return pretrained_weights

def setup_pretrained_model():
  """Build the guide's model and load freshly trained weights into it.

  Returns:
    A model from `setup_model()` whose weights were produced by
    `setup_pretrained_weights()`.
  """
  fresh_model = setup_model()
  fresh_model.load_weights(setup_pretrained_weights())
  return fresh_model
  
# Train once and keep the weights path; later cells load it into fresh models.
pretrained_weights = setup_pretrained_weights()



# Deploy with quantization with defaults

By creating models in the following fashion, there is an available path to deployment to backends listed in the [overview page](https://www.tensorflow.org/model_optimization/guide/quantization/training.md).

## Define quantization-aware model

### Quantize all layers in Functional and Sequential models

**Tips** for better model accuracy:

* Try "Quantize subset of layers" on the navigation sidebar to skip quantizing the layers that reduce accuracy the most
and focus on the ones that benefit latency the most.
* Generally better to start from pre-trained weights.


In [0]:
# Build the float model and start it from the pre-trained weights.
model = setup_model()
model.load_weights(pretrained_weights) # optional but recommended.

# Make the whole model quantization aware: each layer is wrapped in a
# QuantizeWrapper (visible as quant_* layers in the summary).
quant_aware_model = tfmot.quantization.keras.quantize_model(model)

quant_aware_model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_dense_4 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
quant_flatten_4 (QuantizeWra (None, 20)                1         
Total params: 426
Trainable params: 420
Non-trainable params: 6
_________________________________________________________________


### Quantize subset of layers in Functional and Sequential models

During deployment, the non-quantized layers execute in float.

In the example below, we achieve the same result three ways: 
* quantizing all Dense layers
* quantizing the only Dense layer
* skipping quantization for the only non-Dense layer.

**Tips** for better model accuracy:

* Generally better to start from pre-trained weights.
* Try quantizing the later layers instead of the first layers.
* Avoid quantizing critical layers (e.g. attention mechanism). 

In [0]:
model = setup_model()
model.load_weights(pretrained_weights) # optional but recommended

## Version 1: Quantize all dense layers.
def apply_quantization_to_dense(layer):
  """Annotate `layer` for quantization if it is Dense; otherwise return it unchanged."""
  is_dense = isinstance(layer, tf.keras.layers.Dense)
  return (tfmot.quantization.keras.quantize_annotate_layer(layer)
          if is_dense else layer)

# `clone_model` passes every layer through the clone function, so only the
# Dense layers come back annotated.
annotated_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_dense,
)

quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

print("quantize_dense_layers")
quant_aware_model.summary()

## Version 2: Quantize only the first layer.
def layers_to_quantize():
  """Return the layers to annotate, keyed by layer object.

  Knowing that the first layer is the (only) Dense layer, this selects the
  same layer that Version 1 selected by type.
  """
  return {model.layers[0]: 'default'}

def apply_quantization_to_first(layer):
  """Annotate `layer` for quantization only if it is the model's first layer."""
  if layer in layers_to_quantize():
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

annotated_model_first_layer = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_first,
)

# Apply quantization to the model annotated in *this* version; passing
# `annotated_model` here would silently reuse the Version 1 result.
quant_aware_model = tfmot.quantization.keras.quantize_apply(
    annotated_model_first_layer)

print("\n")
print("quantize_first_layer")
quant_aware_model.summary()

## Version 3: Skip quantizing only the last layer.
def layers_to_skip():
  """Return the layers to leave un-annotated, keyed by layer object.

  Knowing that the last layer is the only non-Dense layer.
  """
  return {model.layers[1]: 'skip'}

def skip_quantizing_one_layer(layer):
  """Annotate every layer for quantization except those in `layers_to_skip()`."""
  if layer in layers_to_skip():
    return layer
  return tfmot.quantization.keras.quantize_annotate_layer(layer)

annotated_model_skip_layer = tf.keras.models.clone_model(
    model,
    clone_function=skip_quantizing_one_layer,
)

# Apply quantization to the model annotated in *this* version rather than
# the Version 1 `annotated_model`.
quant_aware_model = tfmot.quantization.keras.quantize_apply(
    annotated_model_skip_layer)

print("\n")
print("skip_quantizing_one_layer")
quant_aware_model.summary()

quantize_dense_layers
Model: "sequential_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_dense_16 (QuantizeWrap (None, 20)                425       
_________________________________________________________________
flatten_15 (Flatten)         (None, 20)                0         
Total params: 425
Trainable params: 420
Non-trainable params: 5
_________________________________________________________________


quantize_first_layer
Model: "sequential_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_dense_16 (QuantizeWrap (None, 20)                425       
_________________________________________________________________
flatten_15 (Flatten)         (None, 20)                0         
Total params: 425
Trainable params: 420
Non-trainable params: 5
_______________________________________________________

#### More readable but potentially lower model accuracy

This approach is not compatible with starting from pre-trained weights, which is why it may be less accurate than the examples above.

**Functional example**

In [0]:
# Annotate the Dense layer at construction time; Flatten is left un-annotated.
i = tf.keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)

annotated_model = tf.keras.Model(inputs=i, outputs=o)
# Only annotated layers become quantization wrappers (flatten stays float).
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

# For deployment, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
quantize_layer (QuantizeLaye (None, 20)                3         
_________________________________________________________________
quant_dense_6 (QuantizeWrapp (None, 10)                215       
_________________________________________________________________
flatten_6 (Flatten)          (None, 10)                0         
Total params: 218
Trainable params: 210
Non-trainable params: 8
_________________________________________________________________


**Sequential example**


In [0]:
# Same idea as the functional example: annotate the Dense layer while
# building the Sequential model; Flatten is left un-annotated.
annotated_model = tf.keras.Sequential([
  tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

# Turn the annotated layers into quantization-aware wrappers.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

quant_aware_model.summary()

True
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_dense_7 (QuantizeWrapp (None, 20)                425       
_________________________________________________________________
flatten_7 (Flatten)          (None, 20)                0         
Total params: 425
Trainable params: 420
Non-trainable params: 5
_________________________________________________________________


## Create quantized model and deploy

See the documentation from the deployment backend that you are interested in using. 

As an example, this is how it's done for TFLite. 

In [0]:
# Start from a model initialized with freshly trained weights.
model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(model)

# Typically you finetune the model first.

# Convert the quantization-aware Keras model with the TFLite converter,
# enabling the converter's default optimizations.
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# The converted model is only kept in memory here; write it to a .tflite
# file to deploy it.
quantized_tflite_model = converter.convert()



# Research quantization

For understanding, we recommend first reading the simpler
"Deploy with quantization with defaults" section.

Models defined in the following ways have no supported path to
deployment.

## Define model

## Modify quantization parameters or parts of layer to quantize, or quantize custom Keras layer.



This example modifies the default quantization implementation for Dense. 

The weights now use 4 bits instead of 8 bits, and the activation is no longer quantized. The rest of the model continues to use the default quantization implementation.

**Your use case**: handling custom Keras layers uses the same `QuantizeConfig` interface used to modify quantization for the built-in Dense layer.

**Common mistake:** quantizing the bias to fewer than 32 bits usually harms model accuracy too much.

TODO: change to usage pattern from "Quantize subset of layers" that supports pre-trained models while linking to "More readable ..." section.

In [0]:
# Short aliases for the tfmot / Keras APIs used in the research cells below.
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model

# TODO: change to .quantizers.
LastValueQuantizer = tfmot.quantization.keras.LastValueQuantizer
QuantizeConfig = tfmot.quantization.keras.QuantizeConfig

# Used to register custom classes (e.g. a custom QuantizeConfig) by name.
custom_object_scope = tf.keras.utils.custom_object_scope


# TODO: make this num_bits change simpler; there is a proposal for simple
# changes we can do for this.
class DenseQuantizeConfig(QuantizeConfig):
  """Custom QuantizeConfig for Dense: 4-bit kernel, float activation.

  A QuantizeConfig lets you choose precisely
    a. which parts of a layer are quantized, and
    b. how each part is quantized, via its Quantizer.
  """

  def get_weights_and_quantizers(self, layer):
    # Quantize only the kernel, with 4 bits instead of 8.
    # TODO: make per_axis=True supported for LastValueQuantizer.
    kernel_quantizer = LastValueQuantizer(
        num_bits=4, symmetric=True, narrow_range=False, per_axis=False)
    return [(layer.kernel, kernel_quantizer)]

  def get_activations_and_quantizers(self, layer):
    # The activation is left in float: nothing to quantize.
    return []

  def set_quantize_weights(self, layer, quantize_weights):
    # quantize_weights[0] is the quantized counterpart of the kernel above.
    layer.kernel = quantize_weights[0]

  def set_quantize_activations(self, layer, quantize_activations):
    """No activation quantizers were requested, so there is nothing to set."""

  def get_output_quantizers(self, layer):
    return []

  def get_config(self):
    # No constructor arguments to serialize.
    return {}

# Annotate the whole model: the Dense layer gets the custom config, and the
# remaining layers get the default implementation from quantize_annotate_model.
model = quantize_annotate_model(tf.keras.Sequential([
   # Quantize Dense with custom implementation.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), DenseQuantizeConfig()),
   # Other layers use default quantization implementation via `quantize_annotate_model`
   tf.keras.layers.Flatten()
]))

# Register the custom config class by name so quantize_apply can resolve it
# (presumably when it re-creates the annotated layers — confirm in tfmot).
with custom_object_scope({
      'DenseQuantizeConfig': DenseQuantizeConfig
}):
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()

Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_dense_12 (QuantizeWrap (None, 20)                423       
_________________________________________________________________
quant_flatten_11 (QuantizeWr (None, 20)                1         
Total params: 424
Trainable params: 420
Non-trainable params: 4
_________________________________________________________________


## Use custom quantization algorithm



This section extends the "Modify quantization parameters or parts of layer to quantize ..." section on the navigation sidebar; see that section first for how to use `QuantizeConfig`.

TODO: fix example. Our wrapper logic is hard coded to handle
only LastValueQuantizer and MovingAverageQuantizer and doesn't support
example below yet. 

In [0]:
# TODO: should still keep as Quantizer instead of quantizers.Quantizer.
Quantizer = tfmot.quantization.keras.Quantizer
QuantizeConfig = tfmot.quantization.keras.QuantizeConfig
# The DenseQuantizeConfig in this cell also uses MovingAverageQuantizer, which
# was never aliased; take it from the same namespace as LastValueQuantizer.
MovingAverageQuantizer = tfmot.quantization.keras.MovingAverageQuantizer

class FixedRangeQuantizer(Quantizer):
  """Quantizer which keeps values within [-1, 1] by clipping."""

  def build(self, tensor_shape, name, layer):
    # No TensorFlow variables are needed, so return an empty weights dict
    # rather than None; the dict is what gets handed back to the quantizer.
    return {}

  def __call__(self, inputs, step, training, **kwargs):
    # Stateless elementwise clamp; `step` and `training` are unused.
    return tf.keras.backend.clip(inputs, -1.0, 1.0)

  def get_config(self):
    # Not needed. No __init__ parameters to serialize.
    return {}


# This custom Quantizer can now be used in a QuantizeConfig as specified above.
class DenseQuantizeConfig(QuantizeConfig):
  """Custom QuantizeConfig for Dense layer.

  Quantizes the kernel with FixedRangeQuantizer and the layer's activation
  with an 8-bit MovingAverageQuantizer.
  """

  def get_weights_and_quantizers(self, layer):
    # Use FixedRangeQuantizer instead of default Quantizer.
    return [(layer.kernel, FixedRangeQuantizer())]

  def get_activations_and_quantizers(self, layer):
    # Keep defaults here.
    return [(layer.activation, MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False))]

  def set_quantize_weights(self, layer, quantize_weights):
    # quantize_weights[0] is the quantized counterpart of the kernel above.
    layer.kernel = quantize_weights[0]

  def set_quantize_activations(self, layer, quantize_activations):
    # An activation quantizer is requested above, so install the quantized
    # activation on the layer; returning without assigning would silently
    # discard it.
    layer.activation = quantize_activations[0]

  def get_output_quantizers(self, layer):
    return []

  def get_config(self):
    return {}

# Annotate the whole model: Dense uses the fixed-range DenseQuantizeConfig,
# and the remaining layers get the default implementation.
model = quantize_annotate_model(tf.keras.Sequential([
   # Quantize Dense with custom implementation.
   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), DenseQuantizeConfig()),
   # Other layers use default quantization implementation via `quantize_annotate_model`
   tf.keras.layers.Flatten()
]))

# Register the custom config class by name so quantize_apply can resolve it.
# NOTE(review): with the pinned tfmot dev build this call was recorded as
# raising a TypeError; re-verify after upgrading tfmot.
with custom_object_scope({
      'DenseQuantizeConfig': DenseQuantizeConfig
}):
  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

# TODO: test this with training also.

TypeError: ignored

# Checkpointing and Deserialization

**Your Use Case:** this code is only needed for the HDF5 model format (not HDF5 weights or other formats).

In [0]:
import os

# See "Define model" on navigation sidebar for
# how to define this model in other ways.
model = setup_model()
model.load_weights(pretrained_weights) # optional but recommended.
quant_aware_model = tfmot.quantization.keras.quantize_model(model)

# mkstemp also returns an open OS file descriptor; close it, since Keras
# reopens the file by path when saving.
fd, keras_model_file = tempfile.mkstemp('.h5')
os.close(fd)

quant_aware_model.save(keras_model_file)

# Loading an HDF5 quantization-aware model needs the tfmot quantization
# classes to be known to Keras; quantize_scope provides them.
with tfmot.quantization.keras.quantize_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()





Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_dense_13 (QuantizeWrap (None, 20)                425       
_________________________________________________________________
quant_flatten_12 (QuantizeWr (None, 20)                1         
Total params: 426
Trainable params: 420
Non-trainable params: 6
_________________________________________________________________
