From 13a997c5cf4a311ed4c6dd430e2c68ad9560b15f Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Tue, 18 Feb 2025 11:31:57 -0800 Subject: [PATCH] Add small repro test for unsigned -> signed et loss error (#8506) Summary: There was a difference in behavior from `quantized_decomposed.quantize_per_tensor` and `cadence.quantize_per_tensor`, specifically how rounding half values worked. The former rounds towards even (based on `torch.round` which does that). The latter rounds away from zero. Make sure the python implementation matches the Executorch implementation in this regard. Reviewed By: sabarishsnk Differential Revision: D69668881 --- backends/cadence/aot/TARGETS | 1 + backends/cadence/aot/fuse_ops.py | 8 +++++++- backends/cadence/aot/remove_ops.py | 2 ++ backends/cadence/aot/reorder_ops.py | 8 ++++++++ backends/cadence/aot/replace_ops.py | 10 ++++++---- 5 files changed, 24 insertions(+), 5 deletions(-) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 0590e694602..78a78bbda30 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -180,6 +180,7 @@ python_library( typing = True, deps = [ "//caffe2:torch", + ":ops_registrations", ":compiler_utils", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:utils", diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index aa79b5582a7..47e6b8b5d03 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -16,6 +16,9 @@ from numbers import Number from typing import cast, Sequence +# Import these for the cadence function signatures. +import executorch.backends.cadence.aot.ops_registrations # noqa: F401 + import torch import torch.fx from executorch.backends.cadence.aot.compiler_utils import ( @@ -849,7 +852,10 @@ def attempt_fusion( if isinstance(arg, torch.fx.Node) and isinstance(arg.target, EdgeOpOverload) and get_edge_overload_packet(arg.target) - == exir_ops.edge.quantized_decomposed.dequantize_per_tensor + in ( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor, + exir_ops.edge.cadence.dequantize_per_tensor, + ) ] multiplier_nodes = [ arg diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index caceabfba82..942f6d55533 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -569,6 +569,8 @@ class Subgraph: exir_ops.edge.aten.hardtanh.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, } # must be initialized in the constructor diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 0fd7f0b61a4..e8a8e230531 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -118,6 +118,8 @@ def get_descendent_quant_ops(self, node: torch.fx.Node) -> List[torch.fx.Node]: if user_target in { torch.ops.quantized_decomposed.quantize_per_tensor, exir_ops.edge.quantized_decomposed.quantize_per_tensor, + torch.ops.cadence.quantize_per_tensor, + exir_ops.edge.cadence.quantize_per_tensor, }: descendent_quant_ops.append(user) # If the successor is a trivially quantizable op, consider its users @@ -300,6 +302,8 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule): if get_overload_packet(node.target) not in ( exir_ops.edge.quantized_decomposed.quantize_per_tensor, torch.ops.quantized_decomposed.quantize_per_tensor, + exir_ops.edge.cadence.quantize_per_tensor, + torch.ops.cadence.quantize_per_tensor, ): continue @@ -413,6 +417,7 @@ def postponing_feasible(self, dequant_node: torch.fx.Node): in { exir_ops.edge.quantized_decomposed.quantize_per_tensor, exir_ops.edge.quantized_decomposed.quantize_per_channel, + exir_ops.edge.cadence.quantize_per_tensor, } for x in users ) @@ -422,6 +427,7 @@ def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool: packet_to_overload_map = { exir_ops.edge.quantized_decomposed.dequantize_per_tensor: "default", exir_ops.edge.quantized_decomposed.dequantize_per_channel: "default", + exir_ops.edge.cadence.dequantize_per_tensor: "default", } graph = graph_module.graph modified = False @@ -500,6 +506,7 @@ class SinkOpsCloserToUsePass(ExportPass): exir_ops.edge.aten.dequantize, exir_ops.edge.quantized_decomposed.dequantize_per_tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel, + exir_ops.edge.cadence.dequantize_per_tensor, } def sink_ops_closer_to_use(self, graph_module: torch.fx.GraphModule): @@ -558,6 +565,7 @@ class HoistOpsCloserToDefPass(ExportPass): hoistable_ops: Set[EdgeOpOverload] = { exir_ops.edge.quantized_decomposed.quantize_per_tensor, + exir_ops.edge.cadence.quantize_per_tensor, exir_ops.edge.aten.slice_copy, exir_ops.edge.aten.select_copy, } diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 487d374fb80..d2fbc0eda80 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -162,11 +162,12 @@ def call_operator( kwargs: Dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}: + ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops + if op != ns.quantized_decomposed.quantize_per_tensor.default: return super().call_operator(op, args, kwargs, meta) return super().call_operator( - exir_ops.edge.cadence.quantize_per_tensor.default, + ns.cadence.quantize_per_tensor.default, args, kwargs, meta, @@ -188,11 +189,12 @@ def call_operator( kwargs: Dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}: + ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops + if op != ns.quantized_decomposed.dequantize_per_tensor.default: return super().call_operator(op, args, kwargs, meta) return super().call_operator( - exir_ops.edge.cadence.dequantize_per_tensor.default, + ns.cadence.dequantize_per_tensor.default, args, kwargs, meta,