Skip to content

Commit

Permalink
Implement output quantization of tensors.
Browse files Browse the repository at this point in the history
QuantizeWrapper now supports quantizing the output result
from the layers it wraps. Tested by using an activation
layer.

PiperOrigin-RevId: 276582600
  • Loading branch information
nutsiepully authored and tensorflower-gardener committed Oct 24, 2019
1 parent 452b898 commit 681c284
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,26 @@ def build(self, input_shape):

self._quantize_activations.append(quantize_activation)

self._output_quantizers = self.quantize_provider.get_output_quantizers(
self.layer)
if self._output_quantizers:
self._output_min_max = self._add_range_weights('output')

def compute_output_shape(self, input_shape):
  """Compute the output shape of the wrapped layer.

  Args:
    input_shape: Shape tuple (or list of shape tuples) for the input.

  Returns:
    The output shape the wrapped layer produces for `input_shape`.
  """
  # Bug fix: the Keras contract requires delegating with the *given*
  # input_shape, not the shape the layer was previously built with
  # (`self.layer.input_shape`). The old form returned a stale/wrong shape
  # whenever the query shape differed from the built one, and raised if
  # the layer had never been called.
  return self.layer.compute_output_shape(input_shape)

def _dict_vars(self, min_var, max_var):
return {'min_var': min_var, 'max_var': max_var}

def _make_quantizer_fn(self, quantizer, x, training, min_var, max_var):
"""Use currying to return True/False specialized fns to the cond."""

def quantizer_fn():
return quantizer(x, self.optimizer_step, training,
**self._dict_vars(min_var, max_var))

return quantizer_fn

def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
Expand All @@ -125,21 +139,12 @@ def call(self, inputs, training=None):

quantized_weights = []
for unquantized_weight, quantizer, min_var, max_var in self._weight_vars:

def make_quantizer_fn(training):
"""Use currying to return True/False specialized fns to the cond."""

def quantizer_fn(unquantized_weight=unquantized_weight,
quantizer=quantizer,
min_var=min_var,
max_var=max_var):
return quantizer(unquantized_weight, self.optimizer_step, training,
**self._dict_vars(min_var, max_var))

return quantizer_fn

quantized_weight = tf_utils.smart_cond(
training, make_quantizer_fn(True), make_quantizer_fn(False))
training,
self._make_quantizer_fn(
quantizer, unquantized_weight, True, min_var, max_var),
self._make_quantizer_fn(
quantizer, unquantized_weight, False, min_var, max_var))
quantized_weights.append(quantized_weight)

self.quantize_provider.set_quantize_weights(self.layer, quantized_weights)
Expand All @@ -153,7 +158,23 @@ def quantizer_fn(unquantized_weight=unquantized_weight,
self.quantize_provider.set_quantize_activations(
self.layer, self._quantize_activations)

return self.layer.call(inputs)
outputs = self.layer.call(inputs)

if not self._output_quantizers:
return outputs

# Assuming outputs is a single tensor. There might be some rare layers
# where this is not true. Handle them when enabling such a layer.
if isinstance(outputs, list) or isinstance(outputs, tuple):
raise RuntimeError('Multiple output tensors not handled currently.')

output_quantizer = self._output_quantizers[0]
return tf_utils.smart_cond(
training,
self._make_quantizer_fn(output_quantizer, outputs, True,
*self._output_min_max),
self._make_quantizer_fn(output_quantizer, outputs, False,
*self._output_min_max))

def get_config(self):
base_config = super(QuantizeWrapper, self).get_config()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,22 @@ def _get_quantized_weights(shape, dtype): # pylint: disable=unused-argument
model.predict(inputs), -6.0, 6.0, num_bits=8, narrow_range=False)
self.assertAllClose(expected_output, quantized_model.predict(inputs))

def testQuantizesOutputsFromLayer(self):
  """Verifies QuantizeWrapper fake-quantizes the wrapped layer's output."""
  # TODO(pulkitb): Increase coverage by adding other output quantize layers
  # such as AveragePooling etc.

  layer = layers.ReLU()
  # Bug fix: wrap the same `layer` instance the quantize provider was
  # derived from. Previously a second, unrelated `layers.ReLU()` was
  # wrapped, so the provider described a layer that was never used.
  quantized_model = keras.Sequential([QuantizeWrapper(
      layer,
      quantize_provider=self.quantize_registry.get_quantize_provider(layer))])

  model = keras.Sequential([layers.ReLU()])

  inputs = np.random.rand(1, 2, 1)
  # NOTE(review): -6.0/6.0 with 8 bits presumably mirrors the registry's
  # default output-quantization parameters for ReLU — confirm against the
  # registry if those defaults change.
  expected_output = tf.fake_quant_with_min_max_vars(
      model.predict(inputs), -6.0, 6.0, num_bits=8, narrow_range=False)
  self.assertAllClose(expected_output, quantized_model.predict(inputs))

def testSerializationQuantizeWrapper(self):
input_shape = (2,)
layer = keras.layers.Dense(3)
Expand Down

0 comments on commit 681c284

Please sign in to comment.