Commit e267a61
Add support for legacy weight-only per-channel for Conv2D and DepthwiseConv2D

PiperOrigin-RevId: 531951144
kyuyeunk authored and tensorflower-gardener committed May 14, 2023
1 parent 494e069 commit e267a61
Showing 3 changed files with 20 additions and 40 deletions.
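
For orientation before the diffs: the test changes below exercise the newly supported combination roughly as follows. This is a minimal sketch, not part of the commit; the import paths and the SavedModel path are assumptions based on the TensorFlow repository layout at the time, while the QuantizationOptions fields mirror the test diff.

    # Minimal sketch (not part of the commit). Import paths are assumptions;
    # the SavedModel path is a placeholder.
    from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2
    from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model

    _ExperimentalMethod = quant_opts_pb2.QuantizationMethod.ExperimentalMethod

    options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.WEIGHT_ONLY
        ),
        op_set=quant_opts_pb2.XLA,
        # Before this commit, enabling both of these raised ValueError.
        enable_per_channel_quantization=True,
        enable_legacy_weight_only=True,
    )

    converted_model = quantize_model.quantize(
        '/tmp/input_saved_model',  # placeholder
        quantization_options=options,
    )

With enable_legacy_weight_only=True, QuantizeFunctionsPattern is not run, so the lifted function keeps its composite_ prefix instead of being rewritten to quantized_; that is why the C++ pattern below now matches both prefixes, and why the tests skip the XlaConvV2 assertion in the legacy case.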
@@ -1010,7 +1010,7 @@ class RestoreWeightShapePattern
     auto new_shape_const_attr =
         DenseElementsAttr::get(shape_spec_type, new_shape.getShape());
     rewriter.setInsertionPointAfter(weight_op);
-    auto new_shape_const = rewriter.create<arith::ConstantOp>(
+    auto new_shape_const = rewriter.create<TF::ConstOp>(
         weight_op->getLoc(), shape_spec_type, new_shape_const_attr);
     auto reshape_op = rewriter.create<TF::ReshapeOp>(
         weight_op->getLoc(), new_shape, weight_op->getResult(0),
@@ -1026,7 +1026,10 @@ class RestoreWeightShapePattern
     StringRef function_name = f_attr.getValue();
     // TODO(b/228928859): Improve the getter function to match attributes rather
     // than function name.
-    if (!function_name.startswith("quantized_")) {
+    // If enable_legacy_weight_only is enabled, QuantizeFunctionsPattern
+    // does not get called and the function remains composite.
+    if (!function_name.startswith("quantized_") &&
+        !function_name.startswith("composite_")) {
       return failure();
     }
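
For context on the prefix checks above, a small illustration of the matching rule in Python (the function names here are hypothetical; real lifted-function names will differ):

    # Hypothetical illustration of the prefix rule: the rewrite now proceeds
    # for functions left as composite_* (legacy weight-only) as well as for
    # already-rewritten quantized_* functions.
    def pattern_applies(function_name: str) -> bool:
      return function_name.startswith('quantized_') or function_name.startswith(
          'composite_'
      )

    assert pattern_applies('quantized_conv2d_fn')    # rewritten path
    assert pattern_applies('composite_conv2d_fn_1')  # legacy weight-only path
    assert not pattern_applies('other_fn')           # pattern returns failure()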

@@ -308,27 +308,6 @@ def test_drq_per_channel_for_non_uniform_opset_raises_value_error(
           self._input_saved_model_path, quantization_options=options
       )
 
-  def test_weight_only_per_channel_with_legacy_weight_only_raises_value_error(
-      self,
-  ):
-    model = self.SimpleModel()
-
-    saved_model_save.save(model, self._input_saved_model_path)
-
-    options = quant_opts_pb2.QuantizationOptions(
-        quantization_method=quant_opts_pb2.QuantizationMethod(
-            experimental_method=_ExperimentalMethod.WEIGHT_ONLY
-        ),
-        op_set=quant_opts_pb2.XLA,
-        enable_per_channel_quantization=True,
-        enable_legacy_weight_only=True,
-    )
-
-    with self.assertRaises(ValueError):
-      quantize_model.quantize(
-          self._input_saved_model_path, quantization_options=options
-      )
-
   @parameterized.named_parameters(
       ('weight_only_per_tensor', False),
       ('legacy_weight_only_per_tensor', True),
@@ -4779,14 +4758,16 @@ def test_matmul_model(
   @parameterized.named_parameters(
       # TODO(b/269421880): Enable legacy weight-only scheme with the uniform
       # quantized opset
-      ('to_xla_per_tensor', quant_opts_pb2.XLA, False),
-      ('to_xla_per_channel', quant_opts_pb2.XLA, True),
+      ('to_xla_per_tensor', quant_opts_pb2.XLA, False, False),
+      ('to_xla_per_channel', quant_opts_pb2.XLA, True, False),
+      ('to_xla_per_channel_legacy', quant_opts_pb2.XLA, True, True),
   )
   @test_util.run_in_graph_and_eager_modes
   def test_conv_model(
       self,
       target_opset: quant_opts_pb2.OpSet,
       enable_per_channel_quantization: bool,
+      enable_legacy_weight_only: bool,
   ):
     input_shape = (1, 3, 4, 512)
     filter_shape = (2, 3, 512, 2)
@@ -4807,6 +4788,7 @@ def test_conv_model(
         ),
         op_set=target_opset,
         enable_per_channel_quantization=enable_per_channel_quantization,
+        enable_legacy_weight_only=enable_legacy_weight_only,
     )
 
     converted_model = quantize_model.quantize(
@@ -4827,8 +4809,10 @@ def test_conv_model(
     )
     output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def
 
+    if not enable_legacy_weight_only:
+      self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2'))
+
     # Due to other meta data, the compression is not exactly 1/4.
-    self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2'))
     self.assertSizeRatioLessThan(
         self._output_saved_model_path,
         self._input_saved_model_path,
@@ -4870,14 +4854,16 @@ def test_conv_model(
   @parameterized.named_parameters(
       # TODO(b/269421880): Enable legacy weight-only scheme with the uniform
       # quantized opset
-      ('to_xla_per_tensor', quant_opts_pb2.XLA, False),
-      ('to_xla_per_channel', quant_opts_pb2.XLA, True),
+      ('to_xla_per_tensor', quant_opts_pb2.XLA, False, False),
+      ('to_xla_per_channel', quant_opts_pb2.XLA, True, False),
+      ('to_xla_per_channel_legacy', quant_opts_pb2.XLA, True, True),
   )
   @test_util.run_in_graph_and_eager_modes
   def test_depthwise_conv2d_model(
       self,
       target_opset: quant_opts_pb2.OpSet,
       enable_per_channel_quantization: bool,
+      enable_legacy_weight_only: bool,
   ):
     input_shape = (1, 3, 4, 512)
     filter_shape = (2, 3, 512, 2)
@@ -4897,6 +4883,7 @@ def test_depthwise_conv2d_model(
         ),
         op_set=target_opset,
         enable_per_channel_quantization=enable_per_channel_quantization,
+        enable_legacy_weight_only=enable_legacy_weight_only,
     )
 
     converted_model = quantize_model.quantize(
@@ -4918,7 +4905,8 @@ def test_depthwise_conv2d_model(
     output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def
 
     # Due to other meta data, the compression is not exactly 1/4.
-    self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2'))
+    if not enable_legacy_weight_only:
+      self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2'))
 
     size_threshold = 0.5 if enable_per_channel_quantization else 0.3
     self.assertSizeRatioLessThan(
@@ -1033,17 +1033,6 @@ def _populate_quantization_options_default_values(
         'Quantized opset and Weight-only.'
     )
 
-  if (
-      quantization_options.enable_per_channel_quantization
-      and quantization_options.enable_legacy_weight_only
-      and quantization_options.quantization_method.experimental_method
-      == _ExperimentalMethod.WEIGHT_ONLY
-  ):
-    raise ValueError(
-        'Weight-only per-channel is only supported with'
-        'enable_legacy_weight_only disabled'
-    )
-
   if (
       quantization_options.quantization_method.experimental_method
       == _ExperimentalMethod.WEIGHT_ONLY
