From 3cf70b3d7bd8621dd39c452bddfc2f337b60764b Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Tue, 11 Nov 2025 10:09:05 -0800 Subject: [PATCH] Correctly track modified bit for variety of replace ops (#15727) Summary: Ops updated in this diff - ReplaceSafeSoftmaxWithSoftmax - ReplacePT2QuantWithCadenceQuantPass - ReplacePT2DequantWithCadenceDequantPass - ReplaceFunctionallyEquivalentOpTargets - ReplaceScalarTensorWithFullPass - ReplaceFullLikeWithFullPass - ReplaceInfArgInFullWithValuePass - ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass Also, removed ReplaceAtenApproxGeluWithApproxGeluPass since it was a NO-OP (not doing what its name suggested). Reviewed By: ethansfng Differential Revision: D86725366 --- backends/cadence/aot/replace_ops.py | 282 +++++++++--------- .../aot/tests/test_replace_ops_passes.py | 39 +-- 2 files changed, 161 insertions(+), 160 deletions(-) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index d430e95c470..a05548a028c 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -34,6 +34,7 @@ CadencePassAttribute, none_throws, register_cadence_pass, + RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet @@ -115,84 +116,84 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep +class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface): # keep """ Replace _safe_softmax with _softmax """ - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != torch.ops.aten._safe_softmax.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [torch.ops.aten._safe_softmax.default] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Add False for the half_to_float argument of softmax - softmax_args = list(args) + [False] + softmax_args = tuple(list(node.args) + [False]) - return super().call_operator( - torch.ops.aten._softmax.default, - tuple(softmax_args), - kwargs, - meta, - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + torch.ops.aten._softmax.default, + args=softmax_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2QuantWithCadenceQuantPass(ExportPass): +class ReplacePT2QuantWithCadenceQuantPass(RemoveOrReplacePassInterface): """ Replace the pt2 quantization ops with cadence quantization ops. We do not link kernels to the PT2 quantization ops, so we need to replace them with cadence ops at all optimization levels. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops - if op != ns.quantized_decomposed.quantize_per_tensor.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + ] - return super().call_operator( - ns.cadence.quantize_per_tensor.default, - args, - kwargs, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + ns.cadence.quantize_per_tensor.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2DequantWithCadenceDequantPass(ExportPass): +class ReplacePT2DequantWithCadenceDequantPass(RemoveOrReplacePassInterface): """ Replace the pt2 dequantization ops with cadence dequantization ops. We do not link kernels to the PT2 quantization ops, so we need to replace them with cadence ops at all optimization levels. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops - if op != ns.quantized_decomposed.dequantize_per_tensor.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ] - return super().call_operator( - ns.cadence.dequantize_per_tensor.default, - args, - kwargs, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + ns.cadence.dequantize_per_tensor.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) @@ -232,18 +233,34 @@ def call_operator( @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceFunctionallyEquivalentOpTargets(ExportPass): +class ReplaceFunctionallyEquivalentOpTargets(RemoveOrReplacePassInterface): """ Replace an op with a functionally equivalent op by just switching the op target, but without incurring any change to the op args. """ - def call_operator(self, op, args, kwargs, meta): - if op not in functionally_equivalent_op_targets: - return super().call_operator(op, args, kwargs, meta) - return super().call_operator( - functionally_equivalent_op_targets[op], args, kwargs, meta - ) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(functionally_equivalent_op_targets.keys()) + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert isinstance(node.target, EdgeOpOverload) + target_op = functionally_equivalent_op_targets[node.target] + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + target_op, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + + # RemoveOrReplacePassInterface calls eliminate_dead_code, but this doesn't + # remove impure nodes (nodes which have side effects). Not sure if that is + # generally safe, so instead of modifying the interface, just erasing + # these nodes for this pass. + node.graph.erase_node(node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -1438,82 +1455,95 @@ def call_operator(self, op, args, kwargs, meta): @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceScalarTensorWithFullPass(ExportPass): +class ReplaceScalarTensorWithFullPass(RemoveOrReplacePassInterface): """ aten.scalar_tensor can be replaced by aten.full with a shape of [1]. scalar_tensor is not supported, so this is an opt_level=0 pass. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.scalar_tensor.default, + @property + def targets(self) -> list[EdgeOpOverload]: + return [ torch.ops.aten.scalar_tensor.default, - }: - return super().call_operator(op, args, kwargs, meta) + exir_ops.edge.aten.scalar_tensor.default, + ] - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[0], - ), - {"dtype": torch.float32}, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=( + [1], + node.args[0], + ), + kwargs={"dtype": torch.float32}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceFullLikeWithFullPass(ExportPass): +class ReplaceFullLikeWithFullPass(RemoveOrReplacePassInterface): """ aten.full_like can be replaced by aten.full with the shape of the arg tensor. full_like is not supported, so this is an opt_level=0 pass. """ - def call_operator(self, op, args, kwargs, meta): - if op not in { - exir_ops.edge.aten.full_like.default, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.full_like.default] - # Get the shape of the "like" tensor, and pass that in to the full op. - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - args[0].to_tensor().shape, - args[1], - ), - {}, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + input_arg = node.args[0] + assert isinstance(input_arg, torch.fx.Node) + shape = input_arg.meta["val"].shape + fill_value = node.args[1] + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=(shape, fill_value), + kwargs={}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceInfArgInFullWithValuePass(ExportPass): +class ReplaceInfArgInFullWithValuePass(RemoveOrReplacePassInterface): """ aten.full allows "-inf" and "inf" as inputs. The profiler cannot handle that, so replace them with the maximum value of the type. """ - def call_operator(self, op, args, kwargs, meta): - if op not in { - exir_ops.edge.aten.full.default, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.full.default] - new_args = list(args) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - if args[1] == float("-inf"): + new_args = list(node.args) + fill_value = node.args[1] + if fill_value == float("-inf"): new_args[1] = torch.finfo(torch.float32).min - elif args[1] == float("inf"): + elif fill_value == float("inf"): new_args[1] = torch.finfo(torch.float32).max + else: + return False - return super().call_operator(op, tuple(new_args), kwargs, meta) + new_args = tuple(new_args) + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=new_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) @@ -1713,26 +1743,6 @@ def call_operator( return super().call_operator(op, args, kwargs, meta) -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass): - """ - Replace the aten gelu op with an approximate arg with an approximate gelu op. - """ - - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.gelu.default, - }: - return super().call_operator(op, args, kwargs, meta) - return super().call_operator(op, args, kwargs, meta) - - # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py @register_cadence_pass(CadencePassAttribute(opt_level=2)) class ReplaceSplitWithSlicePass(ExportPass): @@ -2122,18 +2132,25 @@ class CommonReplacePasses: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass): +class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(RemoveOrReplacePassInterface): """ Replace aten linalg svd op with cadence custom op. """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten._linalg_svd.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten._linalg_svd.default] - return super().call_operator( - exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.cadence.linalg_svd.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True # This class encapsulates all the functions that replace/switch one op in the @@ -2165,6 +2182,5 @@ class CadenceReplaceOpsInGraph: ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceWhereWithFullArgsWithWhereScalar, - ReplaceAtenApproxGeluWithApproxGeluPass, ReplaceMulTensorWithMulAndFullOpsPass, ] diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 73964c6c4c4..4bcc8bf371c 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -23,7 +23,6 @@ MakeSliceAndCatDimOutermostPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAddMMWithLinearPass, - ReplaceAtenApproxGeluWithApproxGeluPass, ReplaceAtenConvolutionWithCadenceConvolutionPass, ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceConstantPadNdWithSlicePass, @@ -329,7 +328,9 @@ def test_replace_functionally_equivalent_op_targets_relu( args=(x,), ) p = ReplaceFunctionallyEquivalentOpTargets() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.relu.default), @@ -813,7 +814,9 @@ def test_replace_masked_scalar_tensor_with_full( builder.output([aten_where_self]) original_gm = builder.get_graph_module() p = ReplaceScalarTensorWithFullPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1, @@ -837,7 +840,9 @@ def test_replace_scalar_tensor_with_full( args=(0.123,), ) p = ReplaceScalarTensorWithFullPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1, @@ -1339,28 +1344,6 @@ def test_replace_aten_where_with_cadence_broadcast( 1, ) - def test_no_replace_aten_gelu_with_approximate_gelu(self) -> None: - inputs = torch.randn(2, 1, 64) - - gm = single_op_builder( - placeholders=(inputs,), - op=exir_ops.edge.aten.gelu.default, - args=(inputs,), - ) - gm = ExportPass().call(gm).graph_module - - p = ReplaceAtenApproxGeluWithApproxGeluPass() - graph_after_passes = p.call(gm).graph_module - - # Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument - self.assertEqual( - count_node( - graph_after_passes, - exir_ops.edge.aten.gelu.default, - ), - 1, - ) - def test_replace_split_with_sizes_with_slice(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 16, 8, 4)) @@ -2142,7 +2125,9 @@ def test_replace_aten_linalg_svd_with_cadence_linalg_svd( ) p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module # Assert that the aten linalg_svd op was replaced with cadence linalg_svd op self.assertEqual(