From 2486faab00516eb201e9f110398f6ad0c69350f9 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Fri, 1 Aug 2025 09:51:24 +0200 Subject: [PATCH] Arm backend: Move is_consumer_node_depthwise_conv2d to the pass using it The function is_consumer_node_depthwise_conv2d is only used by annotate_channels_last_dim_order_pass and can therefore be moved closer to where it is used. Change-Id: I15d97d555b809397a90369568af8ee5c6d607e56 Signed-off-by: Sebastian Larsson --- .../arm/_passes/to_tosa_memory_format_pass.py | 16 ++++++++++++++-- backends/arm/tosa_utils.py | 15 --------------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 49482a70059..e5d810676d1 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -12,7 +12,6 @@ get_first_fake_tensor, is_param_node, ) -from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -43,6 +42,19 @@ def __init__(self, exported_program: ExportedProgram) -> None: self.exported_program = exported_program super().__init__() + @staticmethod + def _is_consumer_node_depthwise_conv2d(node: torch.fx.Node): + consumer_node = list(node.users)[0] + if consumer_node.target == exir_ops.edge.aten.convolution.default: + consumer_node_inputs = consumer_node.all_input_nodes + groups = consumer_node.args[-1] + in_channels = consumer_node_inputs[0].meta["val"].shape[1] + out_channels = consumer_node_inputs[1].meta["val"].shape[0] + if (in_channels == groups) and (out_channels % in_channels) == 0: + return True + + return False + def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): """ returns True for w in the following sequence; @@ -53,7 +65,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): consumer_node = list(node.users)[0] if self.is_weight_node_for_depthwise_conv2d(consumer_node): return True - if is_consumer_node_depthwise_conv2d(node): + if self._is_consumer_node_depthwise_conv2d(node): # Check that node is the weight-argument and not input or bias return consumer_node.args[1] == node diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index fec8f4337a2..983d1f9c023 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -16,9 +16,7 @@ import torch from executorch.backends.arm.tosa_mapping import extract_tensor_meta - from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.exir.dialects._ops import ops as exir_ops from torch._subclasses.fake_tensor import FakeTensor from torch.fx import Node @@ -155,19 +153,6 @@ def build_reshape_tosa_1_0( ) -def is_consumer_node_depthwise_conv2d(node: Node): - consumer_node = list(node.users)[0] - if consumer_node.target == exir_ops.edge.aten.convolution.default: - consumer_node_inputs = consumer_node.all_input_nodes - groups = consumer_node.args[-1] - in_channels = consumer_node_inputs[0].meta["val"].shape[1] - out_channels = consumer_node_inputs[1].meta["val"].shape[0] - if (in_channels == groups) and (out_channels % in_channels) == 0: - return True - - return False - - def tosa_shape(shape, dim_order): reordered = tuple([shape[dim] for dim in dim_order]) # Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes,