From 37828e472c1d94da47ee6bc3ce638f07f9dddb65 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Tue, 2 Dec 2025 10:43:04 -0800 Subject: [PATCH] More passes updated to be more efficient and correctly set their modified bit (#16044) Summary: Updated - FuseCascadedTransposeOrPermuteOps - ReplaceSplitWithSlicePass - ReplacePowWithMulPass - ReplaceMatmulWithTransposedMatmulPass Reviewed By: hsharma35, eigen-k Differential Revision: D87812526 --- backends/cadence/aot/fuse_ops.py | 101 ++++--- backends/cadence/aot/replace_ops.py | 257 ++++++++++-------- .../aot/tests/test_fusion_ops_passes.py | 52 +++- .../aot/tests/test_replace_ops_passes.py | 36 ++- 4 files changed, 269 insertions(+), 177 deletions(-) diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index dbd19e1d3af..25afdf7ee47 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -34,6 +34,7 @@ from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, register_cadence_pass, + RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops @@ -454,7 +455,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedTransposeOrPermuteOps(ExportPass): +class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): """ Fuse a cascaded chain of transpose and permute ops """ @@ -464,63 +465,61 @@ class FuseCascadedTransposeOrPermuteOps(ExportPass): exir_ops.edge.aten.permute_copy.default, } - # Find a chain of transpose or permute ops, and fuse them into a single permute op. + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.transpose_or_permute_target) - def fuse_cascaded_transpose_or_permute_ops( - self, graph_module: torch.fx.GraphModule - ): - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in permute/transpose ops - if node.target not in self.transpose_or_permute_target: - continue - # Get the cascaded chain of transpose/permute ops starting at node - cascaded_transpose_or_permute_ops = get_cascaded_ops( - [node], self.transpose_or_permute_target + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the cascaded chain of transpose/permute ops starting at node + cascaded_transpose_or_permute_ops = get_cascaded_ops( + [node], self.transpose_or_permute_target + ) + # The chain must have more than 1 node + if len(cascaded_transpose_or_permute_ops) == 1: + return False + + # Get shape from node metadata + val = node.meta.get("val") + if val is None: + return False + out_shape = val.shape + out_dims = len(out_shape) + + # This is the trivial dimension order + dims = list(range(out_dims)) + # Compute the effect of the chain on dims + for tp in cascaded_transpose_or_permute_ops: + dims = ( + get_transposed_dims(tp, dims) + if tp.target == exir_ops.edge.aten.transpose_copy.int + else get_permuted_dims(tp, dims) ) - # The chain must have more than 1 node - if len(cascaded_transpose_or_permute_ops) == 1: - continue - out_shape = get_shape(graph_module, node) - assert out_shape is not None - out_dims = len(out_shape) - # This is the trivial dimension order - dims = list(range(out_dims)) - # Compute the effect of the chain on dims - for tp in cascaded_transpose_or_permute_ops: - dims = ( - get_transposed_dims(tp, dims) - if tp.target == exir_ops.edge.aten.transpose_copy.int - else get_permuted_dims(tp, dims) - ) + graph = node.graph - # In case the permute chain cancelled each other, the final dims will - # be the same as the initial order. In that case, the chain was nop. - # Otherwise create a new permute op that encompasses the effect of the - # chain. - if dims == list(range(out_dims)): - cascaded_transpose_or_permute_ops[-1].replace_all_uses_with( - node.args[0] + # In case the permute chain cancelled each other, the final dims will + # be the same as the initial order. In that case, the chain was nop. + # Otherwise create a new permute op that encompasses the effect of the + # chain. + if dims == list(range(out_dims)): + cascaded_transpose_or_permute_ops[-1].replace_all_uses_with( + cast(torch.fx.Node, node.args[0]) + ) + else: + with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]): + new_permute = graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(node.args[0], dims), ) - else: - with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]): - new_permute = graph.call_function( - exir_ops.edge.aten.permute_copy.default, - args=(node.args[0], dims), - ) - cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute) + new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta + cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute) - # Now erase the chain - for tp in reversed(cascaded_transpose_or_permute_ops): - graph.erase_node(tp) - - graph_module.recompile() + # Now erase the chain (except the first node which will be handled by the interface) + for tp in reversed(cascaded_transpose_or_permute_ops[1:]): + graph.erase_node(tp) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.fuse_cascaded_transpose_or_permute_ops(graph_module) - result = super().call(graph_module) - return result + # Return True to indicate the first node in the chain should be removed + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 27dfe34d90f..7a3a3c90ede 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -20,7 +20,6 @@ import torch import torch.fx from executorch.backends.cadence.aot.compiler_utils import ( - get_shape, get_zero_point, is_node_with_op, quantize_tensor_multiplier, @@ -1807,23 +1806,59 @@ def call_operator( # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceSplitWithSlicePass(ExportPass): +class ReplaceSplitWithSlicePass(RemoveOrReplacePassInterface): """ split_with_sizes() delegates to slice() op, so perform this replacement here. This avoids the expense of delegation from ATen. """ - # For split_with_sizes, return the slice dim and extent for each split. - def get_split_sizes( - self, graph_module: torch.fx.GraphModule, node: torch.fx.Node - ) -> Optional[list[tuple[int, ...]]]: + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.split_with_sizes_copy.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # All the users of this split_with_sizes op must be getitem ops + if any(user.target != operator.getitem for user in node.users): + return False + + # Get the slice dim and extent for each split + slice_ops = self._get_split_sizes(node) + if slice_ops is None: + return False + + graph = node.graph + + # Go over each getitem user, and replace it with slice op + for user in list(node.users.keys()): + assert user.target == operator.getitem + item_idx = int(user.args[1]) + assert item_idx < len(slice_ops) + cur_slice = slice_ops[item_idx] + with graph.inserting_before(user): + cur_slice_node = graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, + (node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1), + ) + # Metadata copy important + cur_slice_node.meta = user.meta + user.replace_all_uses_with(cur_slice_node) + + # Return True to indicate the split node should be removed + return True + + def _get_split_sizes(self, node: torch.fx.Node) -> Optional[list[tuple[int, ...]]]: + """For split_with_sizes, return the slice dim and extent for each split.""" # Parse the args of the split_with_sizes op tensor_arg, split_sizes = node.args[0:2] assert isinstance(tensor_arg, torch.fx.Node) - in_shape = get_shape(graph_module, tensor_arg) - split_dim = 0 if len(node.args) < 3 else node.args[2] - if in_shape is None: + + # Get shape from node metadata + val = tensor_arg.meta.get("val") + if val is None: return None + in_shape = val.shape + + split_dim = 0 if len(node.args) < 3 else node.args[2] # Canonicalize the split dimension assert isinstance(split_dim, int) @@ -1841,103 +1876,69 @@ def get_split_sizes( return slice_ops - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph = graph_module.graph - for node in graph.nodes: - if not isinstance(node.target, EdgeOpOverload): - continue - if ( - get_edge_overload_packet(node.target) - != exir_ops.edge.aten.split_with_sizes_copy - ): - continue - # All the users of this split_with_sizes op must be getitem ops - if any(user.target != operator.getitem for user in node.users): - continue - - # Get the slice dim and extent for each split - slice_ops = self.get_split_sizes(graph_module, node) - if slice_ops is None: - continue - # Go over each getitem user, and replace it with slice op - for user in list(node.users.keys()): - assert user.target == operator.getitem - item_idx = user.args[1] - assert item_idx < len(slice_ops) - cur_slice = slice_ops[item_idx] - with graph.inserting_before(user): - cur_slice_node = graph.call_function( - exir_ops.edge.aten.slice_copy.Tensor, - (node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1), - ) - user.replace_all_uses_with(cur_slice_node) - graph.erase_node(user) +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplacePowWithMulPass(RemoveOrReplacePassInterface): + """ + Replace the pow op with successive mul ops when the exponent is an + integer between 2 and 4 (inclusive). + """ - graph.erase_node(node) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.pow.Tensor_Scalar] - graph_module.recompile() - result = super().call(graph_module) - return result + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if we have at least 2 args and the exponent is an int + if len(node.args) < 2 or not isinstance(node.args[1], int): + return False + exponent = cast(int, node.args[1]) -@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplacePowWithMulPass(ExportPass): - """ - Replace the pow op for a mul op. - """ + # Only replace if exponent is between 2 and 4 (inclusive) + if exponent < 2 or exponent > 4: + return False - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if not ( - len(args) > 1 - and isinstance(args[1], int) - and cast(int, args[1]) > 1 - and cast(int, args[1]) < 5 - and op - in { - exir_ops.edge.aten.pow.Tensor_Scalar, - } - ): - return super().call_operator(op, args, kwargs, meta) + x = node.args[0] + assert isinstance(x, torch.fx.Node) - x = args[0] - exponent = cast(int, args[1]) + graph = node.graph + result_node = x - if exponent > 2: - for _ in range(exponent, 2, -1): - x = super().call_operator( + # Create successive mul operations + # For exponent=2: x * x (1 mul) + # For exponent=3: (x * x) * x (2 muls) + # For exponent=4: ((x * x) * x) * x (3 muls) + for _ in range(exponent - 1): + with graph.inserting_before(node): + result_node = graph.call_function( exir_ops.edge.aten.mul.Tensor, - (x, args[0]), - {}, - meta, + args=(result_node, x), ) - return super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (x, args[0]), - {}, - meta, - ) + result_node.meta = node.meta + + node.replace_all_uses_with(result_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceMatmulWithTransposedMatmulPass(ExportPass): +class ReplaceMatmulWithTransposedMatmulPass(RemoveOrReplacePassInterface): """ For certain backends, we have efficient kernels for transposed matmul. We replace AxB with AxB' for such backends. """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.cadence.quantized_matmul.default or args[-1] is True: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.cadence.quantized_matmul.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # If already transposed, bail + if len(node.args) >= 9 and node.args[-1] is True: + return False # Get the args - if len(args) == 9: + if len(node.args) == 9: ( X_arg, X_zero_point, @@ -1948,8 +1949,8 @@ def call_operator(self, op, args, kwargs, meta): out_shift, out_zero_point, transposed, - ) = args - elif len(args) == 8: + ) = node.args + elif len(node.args) == 8: ( X_arg, X_zero_point, @@ -1959,37 +1960,43 @@ def call_operator(self, op, args, kwargs, meta): out_multiplier, out_shift, out_zero_point, - ) = args + ) = node.args transposed = False else: raise AssertionError( - f"Unexpected number of args for quantized_matmul: {len(args)}" + f"Unexpected number of args for quantized_matmul: {len(node.args)}" ) # If the matmul is already transposed, bail if transposed: - return super().call_operator(op, args, kwargs, meta) + return False - # Get the second tensor - Y_tensor = Y_arg.to_tensor() - # Concretize the bias - zero_bias = super().call_operator( - exir_ops.edge.aten.full.default, - ([Y_tensor.size(-1)], 0), - {"dtype": torch.int32}, - meta, - ) + # Get the second tensor from metadata + assert isinstance(Y_arg, torch.fx.Node) + Y_tensor_val = Y_arg.meta.get("val") + if Y_tensor_val is None: + return False - # Y_arg is always a ProxyValue, so we insert a transpose node - transpose_args = (Y_arg, -1, -2) - Y_arg_t = super().call_operator( - exir_ops.edge.aten.transpose_copy.int, - transpose_args, - {}, - meta, - ) + graph = node.graph + + # Create zero bias + with graph.inserting_before(node): + zero_bias = graph.call_function( + exir_ops.edge.aten.full.default, + args=([Y_tensor_val.size(-1)], 0), + kwargs={"dtype": torch.int32}, + ) + zero_bias.meta = node.meta + + # Transpose Y_arg + with graph.inserting_before(node): + Y_arg_t = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(Y_arg, -1, -2), + ) + Y_arg_t.meta = node.meta - # Construct the new args, and return the transposed matmult op + # Construct the new args, and create the transposed matmul op new_args = ( X_arg, X_zero_point, @@ -2001,18 +2008,32 @@ def call_operator(self, op, args, kwargs, meta): out_zero_point, True, ) - return super().call_operator(op, new_args, kwargs, meta) + + with graph.inserting_before(node): + new_node = graph.call_function( + exir_ops.edge.cadence.quantized_matmul.default, + args=new_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False result = super().call(graph_module) - # Fuse any inserted transpose node with transpose/permute nodes - # surrounding it. - result = FuseCascadedTransposeOrPermuteOps()(result.graph_module) - assert result is not None - # Replace permute with transpose. - result = ReplacePermuteWithTransposePass()(result.graph_module) - assert result is not None - return result + modified = modified or result.modified + if modified: + # Fuse any inserted transpose node with transpose/permute nodes + # surrounding it. + result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module) + modified = modified or result.modified + # Replace permute with transpose. + result = ReplacePermuteWithTransposePass().call(result.graph_module) + modified = modified or result.modified + + return PassResult(result.graph_module, modified) @register_cadence_pass(CadencePassAttribute(opt_level=1)) diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index d160a02721a..980acee5b66 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -7,6 +7,7 @@ # pyre-strict +import copy import unittest from typing import cast, Final, List, Tuple @@ -29,6 +30,46 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import PassResult, ProxyValue +from torch.utils import _pytree as pytree + + +def validate_numerics( + original: torch.fx.GraphModule, + modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> None: + """Validate that two graph modules produce numerically equivalent outputs. + + Args: + original: The original graph module before the pass + modified: The modified graph module after the pass + inputs: Input tensors to run through both graphs + pass_name: Name of the pass being validated (for error messages) + rtol: Relative tolerance for allclose comparison + atol: Absolute tolerance for allclose comparison + """ + original.eval() + modified.eval() + with torch.no_grad(): + orig_out = original(*inputs) + mod_out = modified(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + # Check that outputs match within tolerance + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. " + f"Expected rtol={rtol}, atol={atol}. " + f"Original output: {orig_tensor}, Modified output: {mod_tensor}" + ) class TestFusionPassesBase(unittest.TestCase): @@ -202,7 +243,8 @@ def test_keep_mm_add_with_multiple_users(self) -> None: class TestFusionPasses(TestFusionPassesBase): def test_permute_transpose_fusion(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32)) + x_input = torch.randn(3, 1, 3, 1, 4, dtype=torch.float32) + x = builder.placeholder("x", x_input) permute = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3]) ) @@ -212,8 +254,11 @@ def test_permute_transpose_fusion(self) -> None: ) builder.output([output]) original_graph = builder.get_graph_module() + graph_copy = copy.deepcopy(original_graph) p = FuseCascadedTransposeOrPermuteOps() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = p.call(original_graph) + self.assertTrue(result.modified) + converted_graph = result.graph_module converted_graph.graph.eliminate_dead_code() # Assert that permute op was fused with transpose op self.assertEqual( @@ -222,6 +267,9 @@ def test_permute_transpose_fusion(self) -> None: self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0 ) + validate_numerics( + graph_copy, converted_graph, (x_input,), "FuseCascadedTransposeOrPermuteOps" + ) def test_view_fusion(self) -> None: builder = GraphBuilder() diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 6143faedc14..663342b4e0e 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -165,8 +165,13 @@ def test_replace_matmul_with_transposed_matmul( ) builder.output([matmul]) original_gm = builder.get_graph_module() + + gm_before = copy.deepcopy(original_gm) p = ReplaceMatmulWithTransposedMatmulPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = p.call(original_gm) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1, @@ -178,7 +183,7 @@ def test_replace_matmul_with_transposed_matmul( 1, ) validate( - original_gm, + gm_before, graph_after_passes, (x_, y_), "ReplaceMatmulWithTransposedMatmulPass", @@ -1439,7 +1444,8 @@ def test_replace_aten_where_with_cadence_broadcast( def test_replace_split_with_sizes_with_slice(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(1, 16, 8, 4)) + x_input = torch.randn(1, 16, 8, 4) + x = builder.placeholder("x", x_input) split = builder.call_operator( exir_ops.edge.aten.split_with_sizes_copy.default, (x, [8, 8], 1) ) @@ -1449,8 +1455,18 @@ def test_replace_split_with_sizes_with_slice(self) -> None: builder.output([out0, out1]) graph_module = builder.get_graph_module() + gm_before = copy.deepcopy(graph_module) p = ReplaceSplitWithSlicePass() - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + result = cast(PassResult, p(graph_module)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + validate( + gm_before, + graph_after_passes, + [x_input], + "ReplaceSplitWithSlicePass", + ) self.assertEqual( count_node( @@ -1465,14 +1481,22 @@ def test_replace_split_with_sizes_with_slice(self) -> None: @expand([[2], [3], [4]]) def test_replace_pow_with_mul(self, exponent: int) -> None: - x = torch.randn(2, 1, 64) + x_input = torch.randn(2, 1, 64) + x = x_input original_gm = single_op_builder( placeholders=(x,), op=exir_ops.edge.aten.pow.Tensor_Scalar, args=(x, exponent), ) + + gm_before = copy.deepcopy(original_gm) p = ReplacePowWithMulPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + validate(gm_before, graph_after_passes, [x_input], "ReplacePowWithMulPass") + self.assertEqual( count_node( graph_after_passes,