From 3abf464c9734a81d3770a83868fcc9c95f7cf109 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Thu, 28 Nov 2024 16:49:27 +0100
Subject: [PATCH] Add stricter transpose condition for TOSA reshape lowering

Removes transposes from the lowered graph for reshapes that only affect
the H and W dimensions, and clarifies the logic in the
annotate_channels_last_dim_order_pass.

Change-Id: I87e8575d7da8ad56a1f4e937837d7549c05aa11e
---
 .../annotate_channels_last_dim_order_pass.py  | 154 +++++++++++++-----
 backends/arm/test/ops/test_view.py            |   1 +
 2 files changed, 118 insertions(+), 37 deletions(-)

diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
index 786117e6457..80c5f3c442d 100644
--- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
+++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -12,6 +12,7 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    get_node_arg,
     insert_q_dq_pair,
 )
 from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
@@ -83,14 +84,48 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
 
         return False
 
-    def insert_input_transpose(self, node, input_node, graph_module):
+    @staticmethod
+    def memory_format_differs(shape):
+        """Returns True if the shape has a different memory layout in NCHW and NHWC format."""
+        if len(shape) >= 4:
+            C = shape[1]
+            H = shape[2]
+            W = shape[3]
+        elif len(shape) == 3:
+            C = shape[0]
+            H = shape[1]
+            W = shape[2]
+        if len(shape) <= 2:
+            return False
+
+        return C > 1 and (H > 1 or W > 1)
+
+    @staticmethod
+    def is_channel_reshape(input_shape, output_shape):
+        """Returns True if the reshape changes the batch or channel dimension."""
+        if not len(input_shape) == len(output_shape) == 4:
+            return False
+
+        C_old = input_shape[1]
+        C_new = output_shape[1]
+
+        N_new = output_shape[0]
+        N_old = input_shape[0]
+
+        return (N_old != N_new) or (C_old != C_new)
+
+    @staticmethod
+    def insert_input_transpose(node, input_node, graph_module):
         quantize = input_node.target == dq_op
         q_params = input_node.args[1:] if quantize else None
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(input_node, list(self.NHWC_inverse_order)),
+                args=(
+                    input_node,
+                    list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                ),
                 quantize=quantize,
                 q_params=q_params,
             )
@@ -100,14 +135,17 @@ def insert_input_transpose(self, node, input_node, graph_module):
                 range(len(input_node.meta["val"].size()))
             )
 
-    def insert_output_transpose(self, node, graph_module):
+    @staticmethod
+    def insert_output_transpose(node, graph_module):
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(node, list(self.NHWC_order)),
+                args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+            )
+            permute_node.meta["tosa_dim_order"] = (
+                AnnotateChannelsLastDimOrder.NHWC_order
             )
-            permute_node.meta["tosa_dim_order"] = self.NHWC_order
             node.meta["tosa_dim_order"] = (0, 1, 2, 3)
             users = [user for user in node.users if user != permute_node]
             for user in users:
@@ -118,54 +156,96 @@ def insert_output_transpose(self, node, graph_module):
             q_params = node.args[0].args[1:]
             insert_q_dq_pair(graph_module.graph, node, q_params)
 
+    @staticmethod
+    def _insert_squeeze_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
+
+        if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            input_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+
+    @staticmethod
+    def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
+        nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
+        if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            output_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
+    @staticmethod
+    def _insert_view_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
+        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+            output_shape, input_shape
+        )
+
+        if (
+            channel_reshape or nhwc_to_nchw
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+        if (
+            channel_reshape or nchw_to_nhwc
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
-        Reshape operations are not equivalent in NCHW and NHWC.
-        To get around this, transposes need to be added if the previous or new shape
-        fulfil the following condition:
-            C > 1 and (H or W > 1)
-
-        This is relevant for the following operations;
-        squeeze:    4D ->  3D
-        unsqueeze: <4D ->  4D
-        view:      <4D ->  4D
-        view:       4D -> <4D
-        view:       4D ->  4D
-        """
-
-        def transpose_condition(shape):
-            if len(shape) != 4:
-                return False
-            C = shape[1]
-            H = shape[2]
-            W = shape[3]
-            return C > 1 and (H > 1 or W > 1)
+        Transposes are needed for operators transforming the input to a different rank, as 4D tensors are assumed to be in NHWC format, whereas all others are in NCHW format.
+        This is relevant for the following cases:
+        - squeeze:      4D -> <4D
+        - unsqueeze:    3D ->  4D
+        - view:        <4D ->  4D
+        - view:         4D -> <4D
+        Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leading to one extra input and output transpose for this case.
+        Transposes can be avoided for shapes where there is no difference in the actual memory layout, e.g. for
+        - H == W == 1
+        - C == 1
+        - 1D/2D tensors
+        """
 
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
+
             if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
-                if transpose_condition(input_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
+                output_shape = node.meta["val"].shape
+
+                self._insert_squeeze_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )
             elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
+                input_node = get_node_arg(node.args, 0, default_value=False)
+                if input_node:
+                    input_shape = input_node.meta["val"].shape
+                else:
+                    input_shape = ()
                 output_shape = node.meta["val"].shape
-                if transpose_condition(output_shape):
-                    self.insert_output_transpose(node, graph_module)
+
+                self._insert_unsqueeze_transpose(
+                    input_shape, output_shape, node, graph_module
+                )
             elif node.target == exir_ops.edge.aten.view_copy.default:
                 input_node = node.args[0]
+                input_shape = input_node.meta["val"].shape
+                output_shape = node.meta["val"].shape
 
-                old_shape = input_node.meta["val"].shape
-                new_shape = node.meta["val"].shape
-
-                if transpose_condition(old_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
-
-                if transpose_condition(new_shape):
-                    self.insert_output_transpose(node, graph_module)
+                self._insert_view_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )
 
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py
index 09a8f57bd39..07a32fe5951 100644
--- a/backends/arm/test/ops/test_view.py
+++ b/backends/arm/test/ops/test_view.py
@@ -43,6 +43,7 @@ class View(torch.nn.Module):
         (torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
         (torch.rand(5, 10, 1, 1), (1, 25, 2)),
         (torch.rand(2, 50, 1, 1), (1, 100)),
+        (torch.rand(2, 3, 2, 3), (2, 3, 3, 2)),
     ]
 
     def forward(self, x: torch.Tensor, new_shape):
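
Note for reviewers (not part of the patch): the standalone sketch below mirrors the transpose conditions introduced above, i.e. the memory_format_differs and is_channel_reshape checks and the way _insert_view_transpose combines them for view operations. The needs_view_transposes wrapper and the (1, 25, 2, 2) -> (1, 50, 2, 1) shape are made up for illustration; the real logic lives in AnnotateChannelsLastDimOrder.

# Standalone illustration only; not part of the executorch pass.


def memory_format_differs(shape):
    """True if NCHW and NHWC give different memory layouts for this shape."""
    if len(shape) <= 2:
        return False
    if len(shape) >= 4:
        C, H, W = shape[1], shape[2], shape[3]
    else:  # rank 3: no batch dimension
        C, H, W = shape[0], shape[1], shape[2]
    return C > 1 and (H > 1 or W > 1)


def is_channel_reshape(input_shape, output_shape):
    """True if a 4D->4D reshape changes the batch or channel dimension."""
    if not len(input_shape) == len(output_shape) == 4:
        return False
    return input_shape[0] != output_shape[0] or input_shape[1] != output_shape[1]


def needs_view_transposes(input_shape, output_shape):
    """Return (needs_input_transpose, needs_output_transpose) for a view op."""
    nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
    nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
    channel_reshape = is_channel_reshape(input_shape, output_shape)
    return (
        (channel_reshape or nhwc_to_nchw) and memory_format_differs(input_shape),
        (channel_reshape or nchw_to_nhwc) and memory_format_differs(output_shape),
    )


if __name__ == "__main__":
    # New test case above: pure H/W reshape, so no transposes are inserted.
    print(needs_view_transposes((2, 3, 2, 3), (2, 3, 3, 2)))  # (False, False)
    # Existing test case: rank drops, but H == W == 1 so the layouts match.
    print(needs_view_transposes((2, 50, 1, 1), (1, 100)))  # (False, False)
    # Made-up shape: channel change with H > 1, so both transposes are needed.
    print(needs_view_transposes((1, 25, 2, 2), (1, 50, 2, 1)))  # (True, True)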