Skip reordering dq-q patterns when the new quantization dimension is unknown.

It also adds canonicalization patterns that remove redundant TF reshape ops, so
that the quantization dimension can be inferred from the user ops (e.g., conv2d,
depthwise conv2d) instead of being blocked by those redundant reshape ops.

PiperOrigin-RevId: 463080509
abattery authored and tensorflower-gardener committed Jul 25, 2022
1 parent 66b1a1f commit aa0b852
Showing 3 changed files with 50 additions and 6 deletions.
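
As a conceptual companion to the quantization_utils.cc change below, here is a minimal Python sketch of the decision this commit implements; the helper name and structure are assumptions for illustration, not the TF implementation. Reordering only proceeds when the reshape's effect on the per-axis quantization dimension is known.

def reset_quant_axis(old_shape, new_shape, new_axis):
  """Sketch of the guarded axis reset (hypothetical helper, not TF code).

  Returns the quantization axis valid after the reshape, or None when the
  dq-q reordering must be skipped.
  """
  if new_axis == -1:
    # Pass-through op with an unknown target dimension: skip reordering.
    return None
  if list(new_shape) == list(old_shape):
    return new_axis
  if (len(new_shape) == len(old_shape) + 1 and new_shape[0] == 1 and
      list(new_shape[1:]) == list(old_shape) and
      new_axis == len(new_shape) - 1):
    # The [A, B, C] -> [1, A, B, C] case keeps the innermost quantized axis.
    return new_axis
  return None  # Any other reshape: give up rather than mis-assign the axis.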
14 changes: 8 additions & 6 deletions tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
@@ -168,12 +168,10 @@ quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
         BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
       return {};
     }
-  } else if ((new_shape.size() == shape.size() + 1) && new_shape.back() == 1) {
-    // This is a trivial shift left, then we shift the quant_dim as well.
-    if (std::equal(shape.begin(), shape.end(), new_shape.begin()) &&
-        quant_dim == -1) {
-      quant_dim = shape.size() + quant_dim;
-    } else {
+  } else if ((new_shape.size() == shape.size() + 1) && new_shape.front() == 1) {
+    // Handle the [A, B, C] -> [1, A, B, C] reshape case.
+    if (!(std::equal(shape.begin(), shape.end(), new_shape.begin() + 1) &&
+          quant_dim == new_shape.size() - 1)) {
       return {};
     }
   } else {
@@ -343,6 +341,10 @@ TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
   // Reset the quantization dimension if it is per-axis.
   if (auto per_axis =
           qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
+    // For pass-through ops, we do not know which dimension will become the
+    // new quantization dimension. It is only safe to reset the per-axis
+    // quantized type when the new quantization dimension can be inferred.
+    if (axis == -1) return {};
     qtype =
         ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
   }
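
For context, ResetAxisAndBroadcast also broadcasts the per-channel scales and zero points to the size of the new quantized dimension. A rough Python equivalent of that broadcast rule follows; it is an assumption based on how BroadcastVector is used in the hunk above, not the actual implementation.

def broadcast_params(params, target_size):
  # A size-1 parameter list may be tiled up to the target size; any other
  # mismatch means the quantized axes cannot be reconciled.
  if len(params) == target_size:
    return list(params)
  if len(params) == 1:
    return list(params) * target_size
  return None

assert broadcast_params([0.5], 3) == [0.5, 0.5, 0.5]
assert broadcast_params([0.5, 0.25], 3) is None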
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -1370,6 +1370,8 @@ void PrepareTFPass::runOnOperation() {

   patterns.add<RemoveIdentity>(ctx);
   TFL::populateWithGenerated(patterns);
+  // Remove redundant reshape ops.
+  TF::ReshapeOp::getCanonicalizationPatterns(patterns, ctx);
   // TODO(karimnosseir): Split to separate pass probably after
   // deciding on long term plan for this optimization.
   // This will allow optimizing any TF_Mul->TF_Conv in the graph
@@ -1399,6 +1401,8 @@ void PrepareTFPass::runOnOperation() {
       ConvertRfftToRfft2d, RemoveIdentity>(ctx);
   phase_2_patterns.add<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
       ctx, allow_bf16_and_f16_type_legalization_);
+  // Remove redundant reshape ops.
+  TF::ReshapeOp::getCanonicalizationPatterns(phase_2_patterns, ctx);

   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
 }
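
The registered canonicalization patterns are what fold chains like reshape(reshape(x)) down to a single op. A TF-level illustration of the redundancy being targeted is below; it is illustrative only, and whether a given chain folds is up to TF::ReshapeOp's canonicalizers.

import tensorflow as tf

@tf.function
def redundant(x):
  # Back-to-back reshapes: only the last one matters, so the canonicalizer
  # can collapse the chain and the quantization dimension can be inferred
  # from the real user op behind it.
  y = tf.reshape(x, [1, 6, 8, 48])
  return tf.reshape(y, [1, 6 * 8, 48])

print(redundant(tf.zeros([6, 8, 48])).shape)  # (1, 48, 48)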
38 changes: 38 additions & 0 deletions tensorflow/lite/python/lite_v2_test.py
@@ -2311,6 +2311,44 @@ def testKerasFullyConnectedOutputShape3D(self):
         list(output_details[0]['shape_signature']),
         list(model.layers[-1].output_shape))

+  @test_util.run_v2_only
+  def testKerasConv2DTransposedWithMismatchQuantizedAxes(self):
+
+    class QuantConv2DTransposed(tf.keras.layers.Layer):
+
+      def build(self, input_shape):
+        self.kernel = self.add_weight('kernel', [3, 3, input_shape[-1], 24])
+
+      def call(self, inputs):
+        filters = tf.quantization.fake_quant_with_min_max_vars_per_channel(
+            self.kernel,
+            -3.0 * tf.ones([24]),
+            3.0 * tf.ones([24]),
+            narrow_range=True)
+        filters = tf.transpose(filters, (0, 1, 3, 2))
+        return tf.nn.conv2d_transpose(inputs, filters,
+                                      [*inputs.shape[:-1], 24], 1)
+
+    inp = tf.keras.Input(shape=(6, 8, 48), batch_size=1)
+    x = tf.quantization.fake_quant_with_min_max_vars(
+        inp, -3.0, 3.0, narrow_range=True)
+    x = QuantConv2DTransposed()(x)
+    x = tf.quantization.fake_quant_with_min_max_vars(
+        x, -3.0, 3.0, narrow_range=True)
+
+    model = tf.keras.Model(inp, x)
+
+    saved_model_dir = os.path.join(self.get_temp_dir(),
+                                   'keras_conv2d_transpose')
+    model.save(saved_model_dir)
+    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+
+    with self.assertRaises(convert.ConverterError) as error:
+      _ = converter.convert()
+    self.assertIn('mismatched quantized axes of input and output',
+                  str(error.exception))
+
   def _createModelWithInputShape(self, shape):
     """Create a simple SavedModel with a certain shape."""
     saved_model_dir = os.path.join(self.get_temp_dir(), 'input_shape_model')
