From 32b1c6be7ba37eb6b54dc27249216ea774ff4e59 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Wed, 18 Jun 2025 10:07:07 -0700 Subject: [PATCH] Make the requant pass call the per_tensor overload Summary: As titled. No need to introduce tensors. Differential Revision: D74216340 --- backends/cadence/aot/fuse_ops.py | 28 ++++--------------- backends/cadence/aot/remove_ops.py | 2 +- .../aot/tests/test_fusion_ops_passes.py | 10 +++---- .../aot/tests/test_reorder_ops_passes.py | 2 +- 4 files changed, 12 insertions(+), 30 deletions(-) diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index aaf7f051b09..5c7f10729cc 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -712,32 +712,14 @@ def _create_requantize_node( out_dtype: torch.dtype, graph: torch.fx.Graph, ) -> torch.fx.Node: - in_scale_tensor = graph.call_function( - exir_ops.edge.aten.full.default, args=((1,), in_scale) - ) - in_zero_point_tensor = graph.call_function( - exir_ops.edge.aten.full.default, - args=((1,), in_zero_point), - kwargs={"dtype": torch.int32}, - ) - out_scale_tensor = graph.call_function( - exir_ops.edge.aten.full.default, args=((1,), out_scale) - ) - out_zero_point_tensor = graph.call_function( - exir_ops.edge.aten.full.default, - args=((1,), out_zero_point), - kwargs={"dtype": torch.int32}, - ) - # cadence::requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype) -> Tensor Y - # TODO(hardiksharma): Add support for per-tensor requantize. return graph.call_function( - exir_ops.edge.cadence.requantize.default, + exir_ops.edge.cadence.requantize.per_tensor, args=( in_tensor, - in_scale_tensor, - in_zero_point_tensor, - out_scale_tensor, - out_zero_point_tensor, + in_scale, + in_zero_point, + out_scale, + out_zero_point, out_dtype, ), ) diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 996dfa43f8f..fe23ea73754 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -447,7 +447,7 @@ def call_operator( kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - if op != exir_ops.edge.cadence.requantize.default: + if op != exir_ops.edge.cadence.requantize.per_tensor: return super().call_operator(op, args, kwargs, meta) # Parse the args diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 30ea91bafb5..7609d972377 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -298,7 +298,7 @@ def test_force_quant_dequant_fusion(self): # Verify that dequant/quant pair was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) @@ -328,7 +328,7 @@ def test_no_replace_quant_permute_dequant_with_requantize(self): # quantize -> permute -> dequantize should not be replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, - exir_ops.edge.cadence.requantize.default: 0, + exir_ops.edge.cadence.requantize.per_tensor: 0, }, ) @@ -357,7 +357,7 @@ def test_replace_quant_view_dequant_with_requantize(self): # Verify that dequant/quant pair was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) @@ -383,7 +383,7 @@ def test_replace_dequant_quant_with_requantize(self): # Verify that dequant -> quant was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) @@ -413,7 +413,7 @@ def test_replace_dequant_permute_quant_with_requantize(self): exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, exir_ops.edge.aten.permute_copy.default: 1, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index 3e64a0ecd7c..062d188be1d 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -214,7 +214,7 @@ def test_advance_branched_quantize(self): self.assertEqual( count_node( graph_module, - exir_ops.edge.cadence.requantize.default, + exir_ops.edge.cadence.requantize.per_tensor, ), 1, )