From 7d8e2a004f36429a0bf8d902212997856be57bd0 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Tue, 27 May 2025 10:10:24 +0200 Subject: [PATCH] Arm backend: Refactor Conv1dUnsqueezePass Simplifies the pass structure and removes the need for using the exported_program. Signed-off-by: Adrian Lundell Change-Id: I60d62bd42b41f7114b1b8d1697ab8097af2e6839 --- backends/arm/_passes/arm_pass_manager.py | 4 +- backends/arm/_passes/conv1d_unsqueeze_pass.py | 165 +++++------------- 2 files changed, 44 insertions(+), 125 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 5051950fce7..668d7a36257 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -118,7 +118,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(DecomposeSumPass()) - self.add_pass(Conv1dUnsqueezePass(exported_program)) + self.add_pass(Conv1dUnsqueezePass()) self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) @@ -173,7 +173,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(DecomposeSumPass()) - self.add_pass(Conv1dUnsqueezePass(exported_program)) + self.add_pass(Conv1dUnsqueezePass()) self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 16c6f6b209f..56f674e9066 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -1,22 +1,13 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe - -import torch -from executorch.backends.arm._passes.arm_pass_utils import ( - create_node, - get_param_tensor, - is_param_node, -) -from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import ExportPass class Conv1dUnsqueezePass(ExportPass): @@ -24,125 +15,53 @@ class Conv1dUnsqueezePass(ExportPass): This pass is used to change conv1d ops into conv2d since TOSA only supports 2d and 3d convolution. This is done by modifying the graph to do the following: - 1) unsqueeze the convolution's input from 3d to 4d + 1a) unsqueeze the convolution's input from 3d to 4d + 1b) unsqueeze the convolution's weight from 3d to 4d 2) perform a conv2d (with a modified version of the original conv1d args) 3) squeeze the output back down to 3d. """ - def __init__(self, exported_program: ExportedProgram) -> None: - super().__init__() - self.exported_program = exported_program - - def unsqueeze_kernel_weights(self, kernel_node): - """ - Unsqueezes the weights of a conv1d to make it 4 dimensional. - - Args: - kernel_node: the weights of conv1d node to be unsqueezed - """ - kernel_param_3d = get_param_tensor(self.exported_program, kernel_node) - if kernel_param_3d is None: - raise AssertionError("Expected param tensor for the kernel node") - - kernel_param_4d = torch.nn.Parameter( - data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1), - requires_grad=False, + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.convolution.default: + return super().call_operator(op, args, kwargs, meta) + stride = list(args[3]) + if len(stride) != 1: + return super().call_operator(op, args, kwargs, meta) + + x = args[0] + x_unsqueezed_shape = list(x.data.shape) + [1] + x = super().call_operator( + exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta ) - if torch._export.utils.is_param(self.exported_program, kernel_node): - parameter_name = self.exported_program.graph_signature.inputs_to_parameters[ - kernel_node.name - ] - self.exported_program.state_dict[parameter_name] = kernel_param_4d - kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1) - elif torch._export.utils.is_buffer(self.exported_program, kernel_node): - buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ - kernel_node.name - ] - self.exported_program.state_dict[buffer_name] = kernel_param_4d - kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1) - elif torch._export.utils.is_lifted_tensor_constant( - self.exported_program, kernel_node - ): - buffer_name = ( - self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[ - kernel_node.name - ] - ) - self.exported_program.constants[buffer_name] = kernel_param_4d - kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1) - else: - setattr( - kernel_node.graph.owning_module, - kernel_node.target, - kernel_param_4d, - ) - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - node_list = list(graph.nodes) - for node in node_list: - if node.op == "call_function": - if node.target == exir_ops.edge.aten.convolution.default: - stride = list(node.args[3]) - if len(stride) != 1: - # skip conv if it is not 1d - continue - - kernel_node = node.args[1] - - if not is_param_node(self.exported_program, kernel_node): - raise AssertionError( - "Expected op for convolution weight node to be a get_attr node or a parameter" - ) + w_meta = meta.copy() + w_meta.data["input_qparams"] = {} + w_meta.data["output_qparams"] = {} - # Modify graph such that the conv changes from 1d to 2d - self.unsqueeze_kernel_weights(kernel_node) - - # (b) Extend stride, padding, and dilation for extra dim - node.args = ( - node.args[0], - node.args[1], - node.args[2], - node.args[3] + [1], # stride - node.args[4] + [0], # padding - node.args[5] + [1], # dilation - node.args[6], - node.args[7] + [0], - node.args[8], - ) - - # c. Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d) - # unsqueeze -> conv2d -> squeeze - with graph.inserting_before(node): - input_node = node.args[0] - unsqueeze_before = create_node( - graph, exir_ops.edge.aten.unsqueeze_copy.default - ) - unsqueeze_before.args = ( - input_node, # Input is node's original input - -1, # Last Dimension - ) - node.replace_input_with(input_node, unsqueeze_before) + w = args[1] + w_unsqueezed_shape = list(w.data.shape) + [1] + w = super().call_operator( + exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta + ) - with graph.inserting_after(node): - squeeze_after = create_node( - graph, - exir_ops.edge.aten.squeeze_copy.dims, - ) - squeeze_after.args = ( - node, # Input is the conv node - [-1], # Last dimension - ) - original_users = [ - user for user in node.users if user != squeeze_after - ] - for user in original_users: - user.replace_input_with(node, squeeze_after) + new_args = ( + x, + w, + args[2], + args[3] + [1], # stride + args[4] + [0], # padding + args[5] + [1], # dilation + args[6], + args[7] + [0], + args[8], + ) + x = super().call_operator( + exir_ops.edge.aten.convolution.default, new_args, kwargs, meta + ) - graph_module.recompile() - # Since we are overriding "call", we need to call the parent's "call" - # to retrace the graph and regenerate metadata - graph_module = super().call(graph_module).graph_module + x_squeezed_shape = list(x.data.shape)[:-1] + x = super().call_operator( + exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta + ) - return PassResult(graph_module, True) + return x