From 2058d8b1340240bdb8402f7943c3c31d30505ab3 Mon Sep 17 00:00:00 2001
From: Riley Dulin
Date: Fri, 20 Dec 2024 10:38:33 -0800
Subject: [PATCH] Update RemovePermutesAroundElementwiseOps to work with view as well (#7407)

Summary:
The RemovePermutesAroundElementwiseOps pass handled explicit permutes well,
but a permute is sometimes optimized into a `view_copy` when the dimension
being moved doesn't change the byte-level arrangement of the Tensor. Handle
this case so the pass can remove more ops in these chains.
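
As a minimal illustration of the equivalence involved (a standalone sketch
using the shapes from the new test, not code from this diff): moving a
size-1 dimension preserves the row-major element order, so such a permute
can legally be lowered to a view of the permuted shape.

    import torch

    x = torch.randn(1, 4, 4, 1)
    # The permute [0, 3, 1, 2] only moves the size-1 trailing dim, so the
    # byte-level arrangement is unchanged and the result equals a plain
    # reshape to the permuted shape.
    assert torch.equal(x.permute(0, 3, 1, 2).contiguous(), x.reshape(1, 1, 4, 4))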

Reviewed By: zonglinpeng

Differential Revision: D67471456
---
 backends/cadence/aot/remove_ops.py           | 47 +++++++++++++++----
 .../aot/tests/test_remove_ops_passes.py      | 31 ++++++++++++
 2 files changed, 69 insertions(+), 9 deletions(-)

diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py
index 6c25aa6a6fd..19551cc5fcb 100644
--- a/backends/cadence/aot/remove_ops.py
+++ b/backends/cadence/aot/remove_ops.py
@@ -16,7 +16,7 @@
 import itertools
 import logging
 from dataclasses import dataclass, field
-from typing import Callable, cast, Dict, List, Optional, Sequence
+from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
 
 import torch
 import torch.fx
@@ -698,16 +698,45 @@ def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
             sg.is_valid = False
 
     def is_starting_permute(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target == exir_ops.edge.aten.permute_copy.default
-            and cast(list[int], node.args[1]) == self.to_NCHW
-        )
+        return self.is_boundary_permute(node, self.to_NCHW)
 
     def is_ending_permute(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target == exir_ops.edge.aten.permute_copy.default
-            and cast(list[int], node.args[1]) == self.to_NHWC
-        )
+        return self.is_boundary_permute(node, self.to_NHWC)
+
+    @staticmethod
+    def is_boundary_permute(node: torch.fx.Node, permute_dims: Iterable[int]) -> bool:
+        permute_dims = list(permute_dims)
+        if node.target == exir_ops.edge.aten.permute_copy.default:
+            return cast(list[int], node.args[1]) == permute_dims
+        elif node.target == exir_ops.edge.aten.view_copy.default:
+            # If there's a view node, check that it permutes the input
+            # dimensions without splitting or merging any of them.
+            inp = node.args[0]
+            if not isinstance(inp, torch.fx.Node):
+                return False
+            input_shape = inp.meta["val"].shape
+            output_shape = node.args[1]
+            assert isinstance(output_shape, (tuple, list))
+            # If the shapes are equal in length, no dimension is being split or
+            # grouped. Then check if a permute of the input shape results in the output shape.
+            return (
+                len(input_shape) == len(output_shape)
+                and len(input_shape) == len(permute_dims)
+                and RemovePermutesAroundElementwiseOps.permute_shape(
+                    input_shape, permute_dims
+                )
+                == output_shape
+            )
+        else:
+            return False
+
+    @staticmethod
+    def permute_shape(
+        shape: Union[List[int], torch.Size], permute_dims: Iterable[int]
+    ) -> List[int]:
+        permute_dims = list(permute_dims)
+        assert len(shape) == len(permute_dims)
+        return [shape[p] for p in permute_dims]
 
 
 # The following class consolidates functions to remove ops that are redundant
diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py
index f465b55c8d6..25a32a5f077 100644
--- a/backends/cadence/aot/tests/test_remove_ops_passes.py
+++ b/backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -649,6 +649,37 @@ def forward(self, x, y):
         ][0]
         self.assertEqual(cat.args[1], 3)
 
+    def test_remove_permutes_around_concat_with_views(self) -> None:
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                # Mix and match views that are permutes and actual permutes. Both
+                # should be removed.
+                x = x.view(1, 1, 4, 4)
+                y = torch.permute(y, [0, 3, 1, 2])
+                z = torch.cat((x, y), 1)
+                return z.view(1, 4, 4, 8)
+
+        inputs = (torch.randn(1, 4, 4, 1), torch.randn(1, 4, 4, 7))
+        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
+        p = RemovePermutesAroundElementwiseOps()
+        graph_module = cast(PassResult, p(graph_module)).graph_module
+
+        # Expect 0 permutes and views to remain.
+        self.assertEqual(
+            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
+        )
+        self.assertEqual(
+            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 0
+        )
+
+        # Verify that cat was updated correctly (dim 1 in NCHW is dim 3 in NHWC).
+        cat = [
+            n
+            for n in graph_module.graph.nodes
+            if n.target == exir_ops.edge.aten.cat.default
+        ][0]
+        self.assertEqual(cat.args[1], 3)
+
     def test_remove_permutes_around_elemwise_ops_noop(self) -> None:
         class M(torch.nn.Module):
             def __init__(self):
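
A note on the `cat.args[1] == 3` assertions in the tests above (a standalone
sketch with the test's shapes, not part of the patch): concatenating on the
NCHW channel dim 1 between boundary permutes is equivalent to concatenating
the untouched NHWC tensors on dim 3, which is why the pass rewrites the cat
dim when it strips the surrounding permutes/views.

    import torch

    x = torch.randn(1, 4, 4, 1)
    y = torch.randn(1, 4, 4, 7)
    # Reference path: permute inputs to NCHW, cat on dim 1, permute back.
    ref = torch.cat(
        (x.permute(0, 3, 1, 2), y.permute(0, 3, 1, 2)), 1
    ).permute(0, 2, 3, 1)
    # With the boundary permutes removed, the same result is a cat on dim 3.
    assert torch.equal(ref, torch.cat((x, y), 3))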