Skip to content

Commit

Permalink
Refactor and update TFLiteQuantizeRegistry.
Browse files Browse the repository at this point in the history
Refactor TFLiteQuantizeRegistry to make management of
various classes and params easier.

Also, update the Registry by removing unsupported classes,
and updating properties of supported classes.

PiperOrigin-RevId: 276577345
  • Loading branch information
nutsiepully authored and tensorflower-gardener committed Oct 24, 2019
1 parent 3e65026 commit 452b898
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,32 @@
QuantizeProvider = quantize_provider.QuantizeProvider


class _QuantizeInfo(object):
"""QuantizeInfo."""

def __init__(self,
layer_type,
weight_attrs,
activation_attrs,
quantize_output=False):
"""QuantizeInfo.
Args:
layer_type: Type of keras layer.
weight_attrs: List of quantizable weight attributes of layer.
activation_attrs: List of quantizable activation attributes of layer.
quantize_output: Bool. Should we quantize the output of the layer.
"""
self.layer_type = layer_type
self.weight_attrs = weight_attrs
self.activation_attrs = activation_attrs
self.quantize_output = quantize_output


def _no_quantize(layer_type):
return _QuantizeInfo(layer_type, [], [], False)


class _RNNHelper(object):
"""Helper functions for working with RNNs."""

Expand All @@ -45,191 +71,118 @@ def _get_rnn_cells(self, rnn_layer):
class TFLiteQuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
"""QuantizationRegistry for built-in Keras classes for TFLite scheme."""

# Layer Attribute definitions.
# TODO(pulkitb): Double check all attributes used are correct.

_LAYERS_WEIGHTS_MAP = {
layers.advanced_activations.ELU: [],
layers.advanced_activations.LeakyReLU: [],
layers.advanced_activations.ReLU: [],
layers.advanced_activations.Softmax: [],
layers.advanced_activations.ThresholdedReLU: [],
layers.convolutional.Conv1D: ['kernel'],
layers.convolutional.Conv2D: ['kernel'],
layers.convolutional.Conv2DTranspose: ['kernel'],
layers.convolutional.Conv3D: ['kernel'],
layers.convolutional.Conv3DTranspose: ['kernel'],
layers.convolutional.Cropping1D: [],
layers.convolutional.Cropping2D: [],
layers.convolutional.Cropping3D: [],
layers.convolutional.DepthwiseConv2D: [],
layers.convolutional.SeparableConv1D: ['pointwise_kernel'],
layers.convolutional.SeparableConv2D: ['pointwise_kernel'],
layers.convolutional.UpSampling1D: [],
layers.convolutional.UpSampling2D: [],
layers.convolutional.UpSampling3D: [],
layers.convolutional.ZeroPadding1D: [],
layers.convolutional.ZeroPadding2D: [],
layers.convolutional.ZeroPadding3D: [],
layers.core.Activation: [],
layers.core.ActivityRegularization: [],
layers.core.Dense: ['kernel'],
layers.core.Dropout: [],
layers.core.Flatten: [],
layers.core.Lambda: [],
layers.core.Masking: [],
layers.core.Permute: [],
layers.core.RepeatVector: [],
layers.core.Reshape: [],
layers.core.SpatialDropout1D: [],
layers.core.SpatialDropout2D: [],
layers.core.SpatialDropout3D: [],
layers.embeddings.Embedding: ['embeddings'],
layers.local.LocallyConnected1D: ['kernel'],
layers.local.LocallyConnected2D: ['kernel'],
layers.merge.Add: [],
layers.merge.Average: [],
layers.merge.Concatenate: [],
layers.merge.Dot: [],
layers.merge.Maximum: [],
layers.merge.Minimum: [],
layers.merge.Multiply: [],
layers.merge.Subtract: [],
layers.noise.AlphaDropout: [],
layers.noise.GaussianDropout: [],
layers.noise.GaussianNoise: [],
layers.normalization.BatchNormalization: [],
layers.normalization.LayerNormalization: [],
layers.pooling.AveragePooling1D: [],
layers.pooling.AveragePooling2D: [],
layers.pooling.AveragePooling3D: [],
layers.pooling.GlobalAveragePooling1D: [],
layers.pooling.GlobalAveragePooling2D: [],
layers.pooling.GlobalAveragePooling3D: [],
layers.pooling.GlobalMaxPooling1D: [],
layers.pooling.GlobalMaxPooling2D: [],
layers.pooling.GlobalMaxPooling3D: [],
layers.pooling.MaxPooling1D: [],
layers.pooling.MaxPooling2D: [],
layers.pooling.MaxPooling3D: [],
# TODO(tf-mot): if more transforms handle quantization instead of using
# wrapper, add way for transforms to indicate that without modifying
# registry.
conv_batchnorm._ConvBatchNorm2D: [], # pylint: disable=protected-access
conv_batchnorm._DepthwiseConvBatchNorm2D: [], # pylint: disable=protected-access
}

_LAYERS_ACTIVATIONS_MAP = {
layers.advanced_activations.ELU: [],
layers.advanced_activations.LeakyReLU: [],
layers.advanced_activations.ReLU: [],
layers.advanced_activations.Softmax: [],
layers.advanced_activations.ThresholdedReLU: [],
layers.convolutional.Conv1D: ['activation'],
layers.convolutional.Conv2D: ['activation'],
layers.convolutional.Conv2DTranspose: ['activation'],
layers.convolutional.Conv3D: ['activation'],
layers.convolutional.Conv3DTranspose: ['activation'],
layers.convolutional.Cropping1D: [],
layers.convolutional.Cropping2D: [],
layers.convolutional.Cropping3D: [],
layers.convolutional.DepthwiseConv2D: [],
layers.convolutional.SeparableConv1D: ['activation'],
layers.convolutional.SeparableConv2D: ['activation'],
layers.convolutional.UpSampling1D: [],
layers.convolutional.UpSampling2D: [],
layers.convolutional.UpSampling3D: [],
layers.convolutional.ZeroPadding1D: [],
layers.convolutional.ZeroPadding2D: [],
layers.convolutional.ZeroPadding3D: [],
layers.core.Activation: [],
layers.core.ActivityRegularization: [],
layers.core.Dense: ['activation'],
layers.core.Dropout: [],
layers.core.Flatten: [],
layers.core.Lambda: [],
layers.core.Masking: [],
layers.core.Permute: [],
layers.core.RepeatVector: [],
layers.core.Reshape: [],
layers.core.SpatialDropout1D: [],
layers.core.SpatialDropout2D: [],
layers.core.SpatialDropout3D: [],
layers.embeddings.Embedding: [],
layers.local.LocallyConnected1D: ['activation'],
layers.local.LocallyConnected2D: ['activation'],
layers.merge.Add: [],
layers.merge.Average: [],
layers.merge.Concatenate: [],
layers.merge.Dot: [],
layers.merge.Maximum: [],
layers.merge.Minimum: [],
layers.merge.Multiply: [],
layers.merge.Subtract: [],
layers.noise.AlphaDropout: [],
layers.noise.GaussianDropout: [],
layers.noise.GaussianNoise: [],
layers.normalization.BatchNormalization: [],
layers.normalization.LayerNormalization: [],
layers.pooling.AveragePooling1D: [],
layers.pooling.AveragePooling2D: [],
layers.pooling.AveragePooling3D: [],
layers.pooling.GlobalAveragePooling1D: [],
layers.pooling.GlobalAveragePooling2D: [],
layers.pooling.GlobalAveragePooling3D: [],
layers.pooling.GlobalMaxPooling1D: [],
layers.pooling.GlobalMaxPooling2D: [],
layers.pooling.GlobalMaxPooling3D: [],
layers.pooling.MaxPooling1D: [],
layers.pooling.MaxPooling2D: [],
layers.pooling.MaxPooling3D: [],
conv_batchnorm._ConvBatchNorm2D: [], # pylint: disable=protected-access
conv_batchnorm._DepthwiseConvBatchNorm2D: [], # pylint: disable=protected-access
}

_RNN_CELLS_WEIGHTS_MAP = {
layers.recurrent.GRUCell: ['kernel', 'recurrent_kernel'],
layers.recurrent.LSTMCell: ['kernel', 'recurrent_kernel'],
layers.recurrent.PeepholeLSTMCell: ['kernel', 'recurrent_kernel'],
layers.recurrent.SimpleRNNCell: ['kernel', 'recurrent_kernel'],
}

_RNN_CELLS_ACTIVATIONS_MAP = {
layers.recurrent.GRUCell: ['activation', 'recurrent_activation'],
layers.recurrent.LSTMCell: ['activation', 'recurrent_activation'],
layers.recurrent.PeepholeLSTMCell: ['activation', 'recurrent_activation'],
layers.recurrent.SimpleRNNCell: ['activation'],
}

_RNN_LAYERS = {
layers.recurrent.GRU,
layers.recurrent.LSTM,
layers.recurrent.RNN,
layers.recurrent.SimpleRNN,
}

# Support functions.

def _is_supported_non_rnn_layer(self, layer):
return layer.__class__ in self._LAYERS_WEIGHTS_MAP

def _is_supported_rnn_layer(self, layer):
return layer.__class__ in self._RNN_LAYERS

def _is_supported_rnn_cell(self, layer):
return layer.__class__ in self._RNN_CELLS_WEIGHTS_MAP

def _weight_attrs(self, layer):
return self._LAYERS_WEIGHTS_MAP[layer.__class__]

def _activation_attrs(self, layer):
return self._LAYERS_ACTIVATIONS_MAP[layer.__class__]

def _weight_attrs_rnn_cell(self, cell):
return self._RNN_CELLS_WEIGHTS_MAP[cell.__class__]
_LAYER_QUANTIZE_INFO = [

# Activation Layers
_QuantizeInfo(layers.advanced_activations.ReLU, [], [], True),
_QuantizeInfo(layers.advanced_activations.Softmax, [], [], True),
# Enable once verified.
# layers.advanced_activations.ELU,
# layers.advanced_activations.LeakyReLU,
# layers.advanced_activations.PReLU,
# layers.advanced_activations.ThresholdedReLU,

# Convolution Layers
_QuantizeInfo(layers.convolutional.Conv1D, ['kernel'], ['activation']),
_QuantizeInfo(layers.convolutional.Conv2D, ['kernel'], ['activation']),
_QuantizeInfo(layers.convolutional.Conv3D, ['kernel'], ['activation']),
# TODO(pulkitb): Verify Transpose layers.
_QuantizeInfo(layers.convolutional.Conv2DTranspose,
['kernel'], ['activation']),
_QuantizeInfo(layers.convolutional.Conv3DTranspose,
['kernel'], ['activation']),
_no_quantize(layers.convolutional.Cropping1D),
_no_quantize(layers.convolutional.Cropping2D),
_no_quantize(layers.convolutional.Cropping3D),
_QuantizeInfo(layers.convolutional.DepthwiseConv2D,
['depthwise_kernel'], ['activation']),
_no_quantize(layers.convolutional.UpSampling1D),
_no_quantize(layers.convolutional.UpSampling2D),
_no_quantize(layers.convolutional.UpSampling3D),
_no_quantize(layers.convolutional.ZeroPadding1D),
_no_quantize(layers.convolutional.ZeroPadding2D),
_no_quantize(layers.convolutional.ZeroPadding3D),
# Enable once verified.
# layers.convolutional.SeparableConv1D,
# layers.convolutional.SeparableConv2D,

# Core Layers
_QuantizeInfo(layers.core.Activation, [], ['activation']),
_no_quantize(layers.core.ActivityRegularization),
_QuantizeInfo(layers.core.Dense, ['kernel'], ['activation']),
_no_quantize(layers.core.Dropout),
_no_quantize(layers.core.Flatten),
_no_quantize(layers.core.Masking),
_no_quantize(layers.core.Permute),
_no_quantize(layers.core.RepeatVector),
_no_quantize(layers.core.Reshape),
_no_quantize(layers.core.SpatialDropout1D),
_no_quantize(layers.core.SpatialDropout2D),
_no_quantize(layers.core.SpatialDropout3D),
# layers.core.Lambda needs custom handling by the user.

# Pooling Layers
_QuantizeInfo(layers.pooling.AveragePooling1D, [], [], True),
_QuantizeInfo(layers.pooling.AveragePooling2D, [], [], True),
_QuantizeInfo(layers.pooling.AveragePooling3D, [], [], True),
_QuantizeInfo(layers.pooling.GlobalAveragePooling1D, [], [], True),
_QuantizeInfo(layers.pooling.GlobalAveragePooling2D, [], [], True),
_QuantizeInfo(layers.pooling.GlobalAveragePooling3D, [], [], True),
_no_quantize(layers.pooling.GlobalMaxPooling1D),
_no_quantize(layers.pooling.GlobalMaxPooling2D),
_no_quantize(layers.pooling.GlobalMaxPooling3D),
_no_quantize(layers.pooling.MaxPooling1D),
_no_quantize(layers.pooling.MaxPooling2D),
_no_quantize(layers.pooling.MaxPooling3D),

# TODO(pulkitb): Verify Locally Connected layers.
_QuantizeInfo(layers.local.LocallyConnected1D,
['kernel'], ['activation']),
_QuantizeInfo(layers.local.LocallyConnected2D,
['kernel'], ['activation']),

# Enable once verified with TFLite behavior.
# layers.embeddings.Embedding: ['embeddings'],
# layers.normalization.BatchNormalization: [],

# Merge layers to be added.

# RNN Cells
# TODO(pulkitb): Verify RNN layers behavior.
_QuantizeInfo(layers.recurrent.GRUCell, ['kernel', 'recurrent_kernel'],
['activation', 'recurrent_activation']),
_QuantizeInfo(layers.recurrent.LSTMCell, ['kernel', 'recurrent_kernel'],
['activation', 'recurrent_activation']),
_QuantizeInfo(layers.recurrent.PeepholeLSTMCell,
['kernel', 'recurrent_kernel'],
['activation', 'recurrent_activation']),
_QuantizeInfo(layers.recurrent.SimpleRNNCell,
['kernel', 'recurrent_kernel'],
['activation', 'recurrent_activation']),

# TODO(tf-mot): Move layers out once Transforms indicate quantization.
_no_quantize(conv_batchnorm._ConvBatchNorm2D), # pylint: disable=protected-access
_no_quantize(conv_batchnorm._DepthwiseConvBatchNorm2D), # pylint: disable=protected-access
]

def __init__(self):
self._layer_quantize_map = {}
for quantize_info in self._LAYER_QUANTIZE_INFO:
self._layer_quantize_map[quantize_info.layer_type] = quantize_info

def _is_supported_layer(self, layer):
return layer.__class__ in self._layer_quantize_map

def _is_rnn_layer(self, layer):
return layer.__class__ in {
layers.recurrent.GRU,
layers.recurrent.LSTM,
layers.recurrent.RNN,
layers.recurrent.SimpleRNN,
}

def _activation_attrs_rnn_cell(self, cell):
return self._RNN_CELLS_ACTIVATIONS_MAP[cell.__class__]
def _get_quantize_info(self, layer):
return self._layer_quantize_map[layer.__class__]

# Interface functions.

Expand All @@ -245,14 +198,13 @@ def supports(self, layer):
True/False whether the layer type is supported.
"""
if self._is_supported_non_rnn_layer(layer):
if self._is_supported_layer(layer):
return True

if self._is_supported_rnn_layer(layer):
if self._is_rnn_layer(layer):
for rnn_cell in self._get_rnn_cells(layer):
# All cells in the RNN layer should be supported. It's possible to use
# custom cells in an RNN layer.
if not self._is_supported_rnn_cell(rnn_cell):
# All cells in the RNN layer should be supported.
if not self._is_supported_layer(rnn_cell):
return False
return True

Expand All @@ -274,18 +226,23 @@ def get_quantize_provider(self, layer):
'can use `QuantizeProvider` to specify a behavior for your layer.'
.format(layer.__class__))

if self._is_supported_non_rnn_layer(layer):
if self._is_supported_layer(layer):
quantize_info = self._get_quantize_info(layer)
return TFLiteQuantizeProvider(
self._weight_attrs(layer), self._activation_attrs(layer))
quantize_info.weight_attrs, quantize_info.activation_attrs,
quantize_info.quantize_output)

if self._is_supported_rnn_layer(layer):
if self._is_rnn_layer(layer):
weight_attrs = []
activation_attrs = []
for rnn_cell in self._get_rnn_cells(layer):
weight_attrs.append(self._weight_attrs_rnn_cell(rnn_cell))
activation_attrs.append(self._activation_attrs_rnn_cell(rnn_cell))
quantize_info = self._get_quantize_info(rnn_cell)
weight_attrs.append(quantize_info.weight_attrs)
activation_attrs.append(quantize_info.activation_attrs)

return TFLiteQuantizeProviderRNN(weight_attrs, activation_attrs)
# Result quantization for RNN isn't straight-forward like regular layers.
# To implement during full RNN support.
return TFLiteQuantizeProviderRNN(weight_attrs, activation_attrs, False)

# Should never come here.
raise ValueError('Invalid Layer type {}'.format(layer.__class__))
Expand All @@ -294,7 +251,7 @@ def get_quantize_provider(self, layer):
class TFLiteQuantizeProvider(QuantizeProvider):
"""QuantizeProvider for non recurrent Keras layers."""

def __init__(self, weight_attrs, activation_attrs, quantize_output=False):
def __init__(self, weight_attrs, activation_attrs, quantize_output):
self.weight_attrs = weight_attrs
self.activation_attrs = activation_attrs
self.quantize_output = quantize_output
Expand Down
Loading

0 comments on commit 452b898

Please sign in to comment.