diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py index a9545a5b8..9875384f3 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py @@ -111,12 +111,26 @@ def build(self, input_shape): self._quantize_activations.append(quantize_activation) + self._output_quantizers = self.quantize_provider.get_output_quantizers( + self.layer) + if self._output_quantizers: + self._output_min_max = self._add_range_weights('output') + def compute_output_shape(self, input_shape): return self.layer.compute_output_shape(self.layer.input_shape) def _dict_vars(self, min_var, max_var): return {'min_var': min_var, 'max_var': max_var} + def _make_quantizer_fn(self, quantizer, x, training, min_var, max_var): + """Use currying to return True/False specialized fns to the cond.""" + + def quantizer_fn(): + return quantizer(x, self.optimizer_step, training, + **self._dict_vars(min_var, max_var)) + + return quantizer_fn + def call(self, inputs, training=None): if training is None: training = K.learning_phase() @@ -125,21 +139,12 @@ def call(self, inputs, training=None): quantized_weights = [] for unquantized_weight, quantizer, min_var, max_var in self._weight_vars: - - def make_quantizer_fn(training): - """Use currying to return True/False specialized fns to the cond.""" - - def quantizer_fn(unquantized_weight=unquantized_weight, - quantizer=quantizer, - min_var=min_var, - max_var=max_var): - return quantizer(unquantized_weight, self.optimizer_step, training, - **self._dict_vars(min_var, max_var)) - - return quantizer_fn - quantized_weight = tf_utils.smart_cond( - training, make_quantizer_fn(True), make_quantizer_fn(False)) + training, + self._make_quantizer_fn( + quantizer, unquantized_weight, True, min_var, max_var), + self._make_quantizer_fn( + quantizer, unquantized_weight, False, min_var, max_var)) quantized_weights.append(quantized_weight) self.quantize_provider.set_quantize_weights(self.layer, quantized_weights) @@ -153,7 +158,23 @@ def quantizer_fn(unquantized_weight=unquantized_weight, self.quantize_provider.set_quantize_activations( self.layer, self._quantize_activations) - return self.layer.call(inputs) + outputs = self.layer.call(inputs) + + if not self._output_quantizers: + return outputs + + # Assuming outputs is a single tensor. There might be some rare layers + # where this is not true. Handle them when enabling such a layer. + if isinstance(outputs, list) or isinstance(outputs, tuple): + raise RuntimeError('Multiple output tensors not handled currently.') + + output_quantizer = self._output_quantizers[0] + return tf_utils.smart_cond( + training, + self._make_quantizer_fn(output_quantizer, outputs, True, + *self._output_min_max), + self._make_quantizer_fn(output_quantizer, outputs, False, + *self._output_min_max)) def get_config(self): base_config = super(QuantizeWrapper, self).get_config() diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py index b2e0332b4..8d628d178 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py @@ -145,6 +145,22 @@ def _get_quantized_weights(shape, dtype): # pylint: disable=unused-argument model.predict(inputs), -6.0, 6.0, num_bits=8, narrow_range=False) self.assertAllClose(expected_output, quantized_model.predict(inputs)) + def testQuantizesOutputsFromLayer(self): + # TODO(pulkitb): Increase coverage by adding other output quantize layers + # such as AveragePooling etc. + + layer = layers.ReLU() + quantized_model = keras.Sequential([QuantizeWrapper( + layers.ReLU(), + quantize_provider=self.quantize_registry.get_quantize_provider(layer))]) + + model = keras.Sequential([layers.ReLU()]) + + inputs = np.random.rand(1, 2, 1) + expected_output = tf.fake_quant_with_min_max_vars( + model.predict(inputs), -6.0, 6.0, num_bits=8, narrow_range=False) + self.assertAllClose(expected_output, quantized_model.predict(inputs)) + def testSerializationQuantizeWrapper(self): input_shape = (2,) layer = keras.layers.Dense(3)