diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py
index a47422d3dc6..9dc695c68af 100644
--- a/backends/cadence/aot/remove_ops.py
+++ b/backends/cadence/aot/remove_ops.py
@@ -6,9 +6,8 @@
 
 # pyre-strict
 
-import logging
 from dataclasses import dataclass, field
-from typing import cast, List, Optional, Sequence, Set, Type
+from typing import cast, List, Optional, Set, Type
 
 # Import these for the cadence function signatures.
 import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
@@ -69,45 +68,57 @@ class RemoveRedundantOps:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveZeroSizedCatArgsPass(ExportPass):
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.cat.default:
-            return super().call_operator(op, args, kwargs, meta)
-
-        # Remove any zero-sized tensor arg to form a new args list.
-        cat_inputs: list[ProxyValue] = []
-        for arg in cast(Sequence[ProxyValue], args[0]):
-            if arg.to_tensor().numel() > 0:
-                cat_inputs.append(arg)
+class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.cat.default]
 
-        # If all the tensors were empty, we just return an empty tensor with
-        # the right shape.
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        # Get the cat inputs (first argument is a list of tensors)
+        cat_inputs_arg = node.args[0]
+
+        # Assert that cat_inputs_arg is iterable
+        assert isinstance(
+            cat_inputs_arg, (list, tuple)
+        ), "cat_inputs_arg must be a sequence type"
+
+        # Filter out zero-sized tensors
+        cat_inputs: list[Node] = []
+        for arg in cat_inputs_arg:
+            if isinstance(arg, Node) and arg.meta.get("val") is not None:
+                if arg.meta["val"].numel() > 0:
+                    cat_inputs.append(arg)
+
+        # If all tensors were empty, create a full op with the right shape
         if not cat_inputs:
-            empty_shape = meta["val"].shape
-            dtype = meta["val"].dtype
-            return super().call_operator(
-                exir_ops.edge.aten.full.default,
-                (tuple(empty_shape), 0),
-                {"dtype": dtype},
-                meta,
-            )
+            empty_shape = node.meta["val"].shape
+            dtype = node.meta["val"].dtype
+            # Create a new full node
+            with node.graph.inserting_before(node):
+                full_node = node.graph.call_function(
+                    exir_ops.edge.aten.full.default,
+                    args=(tuple(empty_shape), 0),
+                    kwargs={"dtype": dtype},
+                )
+                full_node.meta = node.meta.copy()
+            node.replace_all_uses_with(full_node)
+            return True
 
-        # If there was only one tensor in the cat_inputs list,
-        # we can safely erase this cat op.
+        # If only one tensor remains, replace with it
         if len(cat_inputs) == 1:
-            return cat_inputs[0]
+            node.replace_all_uses_with(cat_inputs[0])
+            return True
+
+        # If the number of inputs changed, update the cat args
+        if len(cat_inputs) < len(cat_inputs_arg):
+            # Update the first argument with filtered inputs
+            new_args = list(node.args)
+            new_args[0] = cat_inputs
+            node.args = tuple(new_args)
+            return True
 
-        # Otherwise, we replace args[0] with cat_inputs.
-        new_args = list(args)
-        # pyre error introduced after D66937105
-        new_args[0] = cat_inputs  # pyre-ignore[6]
-        return super().call_operator(op, tuple(new_args), kwargs, meta)
+        # No changes needed
+        return False
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
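Note on the hunk above: moving off ExportPass's call_operator hook onto RemoveOrReplacePassInterface means the pass now edits torch.fx nodes in place instead of rebuilding values through an interpreter. Below is a minimal, self-contained sketch of the same cat cleanup in plain torch.fx (not the ExecuTorch edge dialect); the module M, the example shapes, and the use of ShapeProp's tensor_meta in place of the "val" metadata are illustrative assumptions:

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class M(torch.nn.Module):
    def forward(self, x, y):
        empty = torch.zeros(0, 4)
        return torch.cat([x, empty, y], dim=0)


gm = torch.fx.symbolic_trace(M())
# Populate node.meta["tensor_meta"] so each cat input can report its numel.
ShapeProp(gm).propagate(torch.ones(2, 4), torch.ones(3, 4))

modified = False
for node in list(gm.graph.nodes):
    if node.op != "call_function" or node.target is not torch.cat:
        continue
    # Keep only the inputs that contribute elements to the output.
    inputs = [
        a
        for a in node.args[0]
        if isinstance(a, torch.fx.Node) and a.meta["tensor_meta"].shape.numel() > 0
    ]
    if len(inputs) == 1:
        # A single surviving input makes the cat a no-op.
        node.replace_all_uses_with(inputs[0])
        gm.graph.erase_node(node)
        modified = True
    elif len(inputs) < len(node.args[0]):
        node.args = (inputs,) + node.args[1:]
        modified = True

if modified:
    gm.graph.eliminate_dead_code()
    gm.recompile()
print(gm.code)  # cat now sees only the two non-empty inputs

Run as-is, the printed code shows cat receiving only the two non-empty placeholders, with the dead zero-sized constant removed by eliminate_dead_code.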
@@ -151,25 +162,29 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class RemoveZeroSizedConstantPadNd(ExportPass):
-    def call_operator(
-        self,
-        op,  # pyre-ignore
-        args: tuple[ProxyValue, tuple[int, ...], Argument],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != exir_ops.edge.aten.constant_pad_nd.default:
-            return super().call_operator(op, args, kwargs, meta)
+class RemoveZeroSizedConstantPadNd(RemoveOrReplacePassInterface):
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.constant_pad_nd.default]
 
-        input_tensor = args[0]
-        padding = args[1]
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        # Get padding argument (second argument)
+        if len(node.args) < 2:
+            return False
+
+        padding = node.args[1]
+        if not isinstance(padding, (list, tuple)):
+            return False
 
+        # If any padding value is non-zero, keep the node
         if any(x != 0 for x in padding):
-            return super().call_operator(op, args, kwargs, meta)
+            return False
 
-        logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}")
-        return input_tensor
+        # All padding is zero, replace with input
+        input_node = node.args[0]
+        assert isinstance(input_node, Node)
+        node.replace_all_uses_with(input_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -721,17 +736,17 @@ def get_squeeze_indices(self, view_node: Node) -> List[int]:
 
         return squeeze_indices
 
-    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None:
+    def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> bool:
         if view_node in visited_view_nodes:
-            return
+            return False
 
         squeeze_indices = self.get_squeeze_indices(view_node)
         if not squeeze_indices:
-            return
+            return False
 
         # Only handle simple chains for now.
         if len(view_node.users) != 1:
-            return
+            return False
         node = next(iter(view_node.users))
 
         # Traverse down from the node until finding another view op.
@@ -739,9 +754,9 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None
         while node.target != exir_ops.edge.aten.view_copy.default:
             # Only handle simple chains for now
             if len(node.users) != 1:
-                return
+                return False
             if node.target not in self.intermediate_ops:
-                return
+                return False
             if node.target == exir_ops.edge.aten.slice_copy.Tensor:
                 intermediate_slices.append(node)
             node = next(iter(node.users))
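The None-to-bool change in handle_squeeze above, together with the call rewrite in the next hunk, follows the same modified-tracking convention the converted passes use: per-node handlers report whether they touched the graph, and cleanup runs only when something changed. A minimal skeleton of that pattern, assuming the ExportPass/PassResult imports this file already uses; TrackModifiedPass and its _maybe_rewrite handler are hypothetical names:

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class TrackModifiedPass(ExportPass):
    def _maybe_rewrite(self, node: torch.fx.Node) -> bool:
        # Hypothetical per-node rewrite; returns True only on change.
        return False

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        for node in list(graph_module.graph.nodes):
            modified |= self._maybe_rewrite(node)

        if modified:
            # Only pay for cleanup and retracing when something changed.
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            return super().call(graph_module)

        return PassResult(graph_module, False)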
@@ -764,18 +779,22 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None
 
         # Skip the initial view node.
         input_node = cast(Node, get_arg(view_node, "input"))
         view_node.replace_all_uses_with(input_node)
+        return True
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         visited_view_nodes = set()
+        modified = False
         for view_node in graph_module.graph.find_nodes(
             op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True
         ):
-            self.handle_squeeze(view_node, visited_view_nodes)
+            modified |= self.handle_squeeze(view_node, visited_view_nodes)
 
-        graph_module.graph.eliminate_dead_code()
-        graph_module.recompile()
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            return super().call(graph_module)
 
-        return super().call(graph_module)
+        return PassResult(graph_module, False)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -798,23 +817,27 @@ class RemoveBranchedQuantDequant(ExportPass):
     }
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.remove_branched(
+        modified = self.remove_branched(
             graph_module, self.quantize_op_packets, self.dequantize_op_packets
         )
-        self.remove_branched(
+        modified |= self.remove_branched(
             graph_module, self.dequantize_op_packets, self.quantize_op_packets
         )
 
-        graph_module.graph.eliminate_dead_code()
-        result = super().call(graph_module)
-        return result
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            result = super().call(graph_module)
+            return result
+
+        return PassResult(graph_module, False)
 
     def remove_branched(
         self,
         graph_module: torch.fx.GraphModule,
         producer_pkts: set[EdgeOpOverloadPacket],
         consumer_pkts: set[EdgeOpOverloadPacket],
-    ) -> None:
+    ) -> bool:
+        modified = False
         for node in graph_module.graph.nodes:
             if (
                 node.op != "call_function"
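For context on remove_branched below: it forwards a branched dequantize (or quantize) user straight to the producer's input, and only when the user's quantization parameters match the producer's, because quantize followed by dequantize with identical scale and zero_point is the identity up to rounding. A quick plain-torch check of that identity, with arbitrary illustrative values:

import torch

x = torch.linspace(-1.0, 1.0, 8)
scale, zero_point = 0.05, 0
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
roundtrip = q.dequantize()
# quantize -> dequantize differs from x by at most one quantization step,
# so forwarding the producer's input past the branch preserves semantics.
assert torch.allclose(x, roundtrip, atol=scale)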
@@ -838,61 +861,62 @@ def remove_branched(
                     continue
 
                 user.replace_all_uses_with(node.args[0])
+                modified = True
+
+        return modified
 
 
-class RemoveCatFromSliceCopyPass(ExportPass):
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class RemoveCatFromSliceCopyPass(RemoveOrReplacePassInterface):
     """
     Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
     to the slice_copy.
     """
 
-    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
-        for slice_copy_node in graph_module.graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
-        ):
-            cat_node = cast(Node, get_arg(slice_copy_node, "input"))
-            slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
-            start_idx = cast(int, get_arg(slice_copy_node, "start"))
-            end_idx = cast(int, get_arg(slice_copy_node, "end"))
-            step = cast(int, get_arg(slice_copy_node, "step"))
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.slice_copy.Tensor]
 
-            if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
-                continue
+
+    def maybe_remove_or_replace(self, node: Node) -> bool:
+        cat_node = cast(Node, get_arg(node, "input"))
+        slice_dim = cast(int, get_arg(node, "dim"))
+        start_idx = cast(int, get_arg(node, "start"))
+        end_idx = cast(int, get_arg(node, "end"))
+        step = cast(int, get_arg(node, "step"))
 
-            # Make sure cat and slice happens on the same dimension.
-            cat_dim = cast(Node, get_arg(cat_node, "dim"))
-            if cat_dim != slice_dim:
-                continue
+        if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
+            return False
+
+        # Make sure cat and slice happen on the same dimension.
+        cat_dim = cast(int, get_arg(cat_node, "dim"))
+        if cat_dim != slice_dim:
+            return False
 
-            # Canonicalize slice indices.
-            cat_output_shape = cat_node.meta["val"].shape
-            if start_idx is None:
-                start_idx = 0
-            elif start_idx < 0:
-                start_idx += cat_output_shape[cat_dim]
-            if end_idx is None or end_idx > cat_output_shape[cat_dim]:
-                end_idx = cat_output_shape[cat_dim]
-            elif end_idx < 0:
-                end_idx += cat_output_shape[cat_dim]
-
-            offset = 0
-            for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
-                cat_input_shape = cat_input_node.meta["val"].shape
-
-                # Check if the slice range overlaps with the cat input range.
-                if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
-                    slice_copy_node.replace_input_with(cat_node, cat_input_node)
-                    set_arg(slice_copy_node, "start", start_idx - offset)
-                    set_arg(slice_copy_node, "end", end_idx - offset)
-                    break
-
-                offset += cat_input_shape[cat_dim]
+        # Canonicalize slice indices.
+        cat_output_shape = cat_node.meta["val"].shape
+        if start_idx is None:
+            start_idx = 0
+        elif start_idx < 0:
+            start_idx += cat_output_shape[cat_dim]
+        if end_idx is None or end_idx > cat_output_shape[cat_dim]:
+            end_idx = cat_output_shape[cat_dim]
+        elif end_idx < 0:
+            end_idx += cat_output_shape[cat_dim]
+
+        offset = 0
+        for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
+            cat_input_shape = cat_input_node.meta["val"].shape
+
+            # Check if the slice range overlaps with the cat input range.
+            if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
+                node.replace_input_with(cat_node, cat_input_node)
+                set_arg(node, "start", start_idx - offset)
+                set_arg(node, "end", end_idx - offset)
+                return True
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self._remove_unused_cat(graph_module)
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
-        return super().call(graph_module)
+            offset += cat_input_shape[cat_dim]
+
+        return False
 
 
 class CommonRemovePasses:
@@ -901,7 +925,6 @@ class CommonRemovePasses:
         RemoveAliasCopyOpPass,
         RemoveNopExpandOpPass,
         RemoveNopSliceOrViewOpPass,
-        RemoveNopSelectOpPass,
         RemoveToOpsPass,
         RemoveZeroSizedCatArgsPass,
         RemovePermutesAroundElementwiseOps,
diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py
index 483d737f97d..158ec73cf27 100644
--- a/backends/cadence/aot/tests/test_remove_ops_passes.py
+++ b/backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -196,13 +196,13 @@ def test_remove_zero_sized_constant_pad_nd(
         )
         builder.output([pad])
         original = builder.get_graph_module()
-        graph_after_passes = cast(
-            PassResult, RemoveZeroSizedConstantPadNd()(original)
-        ).graph_module
+        pass_result = cast(PassResult, RemoveZeroSizedConstantPadNd()(original))
+        graph_after_passes = pass_result.graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
             0,
         )
+        self.assertTrue(pass_result.modified)
 
     def test_remove_expand(self) -> None:
         builder = GraphBuilder()
@@ -228,12 +228,12 @@ def test_remove_zero_arg_cat(self) -> None:
         )
         builder.output([concat])
         original = builder.get_graph_module()
-        graph_after_passes = cast(
-            PassResult, RemoveZeroSizedCatArgsPass()(original)
-        ).graph_module
+        pass_result = cast(PassResult, RemoveZeroSizedCatArgsPass()(original))
+        graph_after_passes = pass_result.graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0
         )
+        self.assertTrue(pass_result.modified)
 
     def test_remove_clone(self) -> None:
         builder = GraphBuilder()
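To make the index canonicalization in RemoveCatFromSliceCopyPass above concrete: missing or negative start/end are first normalized against the cat output length, then the pass walks the cat inputs, accumulating an offset, until one input fully covers the slice range. A worked example in plain Python with illustrative sizes:

sizes = [2, 3, 4]      # lengths of the cat inputs along the cat dim
total = sum(sizes)     # cat output length: 9
start, end = -7, 5     # slice_copy arguments, step == 1

# Canonicalize exactly as the pass does.
if start is None:
    start = 0
elif start < 0:
    start += total     # -7 -> 2
if end is None or end > total:
    end = total
elif end < 0:
    end += total

offset = 0
for i, size in enumerate(sizes):
    # Input i serves the slice iff [start, end) lies within [offset, offset + size).
    if offset <= start and end <= offset + size:
        print(f"rewire slice to input {i}: start={start - offset}, end={end - offset}")
        break
    offset += size
else:
    print("slice straddles inputs; keep the cat")

# Output: rewire slice to input 1: start=0, end=3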
@@ -611,7 +611,9 @@ def test_remove_squeeze_view_before_elemwise_ops(self) -> None:
         original = deepcopy(model)
         p = RemoveSqueezeViewBeforeElementwiseOps()
-        transformed = cast(PassResult, p(model)).graph_module
+        pass_result = cast(PassResult, p(model))
+        self.assertTrue(pass_result.modified)
+        transformed = pass_result.graph_module
 
         # First view should be eliminated and second view should be trivial.
         views = transformed.graph.find_nodes(
@@ -872,9 +874,9 @@ def test_remove_dequant_on_branch(self) -> None:
         )
         builder.output([x1_output, y1_output])
         original = builder.get_graph_module()
-        graph_after_passes = cast(
-            PassResult, RemoveBranchedQuantDequant()(original)
-        ).graph_module
+        pass_result = cast(PassResult, RemoveBranchedQuantDequant()(original))
+        self.assertTrue(pass_result.modified)
+        graph_after_passes = pass_result.graph_module
         self.assertEqual(
             count_node(
                 graph_after_passes,
@@ -904,9 +906,9 @@ def test_remove_cat_from_slice_copy(self) -> None:
         )
         builder.output([output])
         original = builder.get_graph_module()
-        graph_after_passes = cast(
-            PassResult, RemoveCatFromSliceCopyPass()(original)
-        ).graph_module
+        pass_result = cast(PassResult, RemoveCatFromSliceCopyPass()(original))
+        self.assertTrue(pass_result.modified)
+        graph_after_passes = pass_result.graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0
         )
@@ -922,9 +924,9 @@ def test_keep_cat_from_slice_copy(self) -> None:
         )
         builder.output([output])
         original = builder.get_graph_module()
-        graph_after_passes = cast(
-            PassResult, RemoveCatFromSliceCopyPass()(original)
-        ).graph_module
+        pass_result = cast(PassResult, RemoveCatFromSliceCopyPass()(original))
+        self.assertFalse(pass_result.modified)
+        graph_after_passes = pass_result.graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 1
         )
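All the updated tests follow one shape: run the pass once, then assert on both the rewritten graph and the new PassResult.modified flag, including a negative case. A self-contained illustration of that pattern with a toy pass in plain torch.fx and unittest, outside the Cadence test harness:

import unittest

import torch
import torch.fx
from torch.fx.passes.infra.pass_base import PassResult


def remove_single_input_cat(gm: torch.fx.GraphModule) -> PassResult:
    """Toy pass: collapse cat([t]) into t, reporting whether it changed anything."""
    modified = False
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.cat:
            if len(node.args[0]) == 1:
                node.replace_all_uses_with(node.args[0][0])
                gm.graph.erase_node(node)
                modified = True
    if modified:
        gm.graph.eliminate_dead_code()
        gm.recompile()
    return PassResult(gm, modified)


class TestModifiedFlag(unittest.TestCase):
    def test_removes_trivial_cat(self) -> None:
        gm = torch.fx.symbolic_trace(lambda x: torch.cat([x]))
        result = remove_single_input_cat(gm)
        # Positive case: the graph changed and the flag says so.
        self.assertTrue(result.modified)

    def test_keeps_real_cat(self) -> None:
        gm = torch.fx.symbolic_trace(lambda x: torch.cat([x, x]))
        result = remove_single_input_cat(gm)
        # Negative case: nothing to rewrite, so modified must be False.
        self.assertFalse(result.modified)


if __name__ == "__main__":
    unittest.main()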