From 68c038b63b9904cf20f759e9aff087e0e9f8d8fe Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 1 May 2026 21:30:40 -0700 Subject: [PATCH] Replace AVG_POOL2D with REDUCE_SUM in DecomposeMeanDimPass (#19242) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Replace the avg_pool2d decomposition path in DecomposeMeanDimPass with REDUCE_SUM + MUL(1/N) for all mean.dim reductions. AVG_POOL2D can only pool over spatial (H×W) axes in TOSA/NHWC layout, which forces the compiler to insert TRANSPOSE ops when the reduction is over channels (common in LayerNorm). REDUCE_SUM works on any axis without layout constraints, avoiding those transposes entirely. Reviewed By: 3l1 Differential Revision: D101418199 --- .../arm/_passes/decompose_meandim_pass.py | 77 ++----------------- .../arm/test/misc/test_transpose_counts.py | 10 +-- backends/arm/test/ops/test_cond.py | 2 +- backends/arm/test/ops/test_layer_norm.py | 2 - .../passes/test_decompose_meandim_pass.py | 2 +- 5 files changed, 14 insertions(+), 79 deletions(-) diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 701d5337636..c7d3bc0a04d 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -16,7 +16,6 @@ ) from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.backends.arm.constants import DQ_OPS, Q_OPS -from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -51,14 +50,6 @@ def get_dynamic_meandim_decomposition(op) -> tuple: raise RuntimeError(f"Can't get meandim decomposition for op {op}") -def get_avgpool(op): - if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): - return exir_ops.edge.aten.avg_pool2d.default - if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): - return torch.ops.aten.avg_pool2d.default - raise RuntimeError(f"Can't get meandim decomposition for op {op}") - - def get_view(op): if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return exir_ops.edge.aten.view_copy.default @@ -79,12 +70,11 @@ def get_quantization(op): class DecomposeMeanDimPass(ArmPass): - """Decomposes a meandim into avg_pool and/or sum + mul (1/N). - - :: + """Decomposes a meandim into sum + mul (1/N). - h, w -> avg_pool - n, c -> sum + mul(1/N) + Each reduction dimension is handled via REDUCE_SUM followed by + multiplication by 1/N, which works on any axis without layout + constraints (unlike AVG_POOL2D which only pools over spatial H×W). For rank < 4, the input is reshaped to 4D by padding with dim=1 from the left. @@ -92,10 +82,9 @@ class DecomposeMeanDimPass(ArmPass): Example: x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w) Becomes: - x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool - x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool - x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum - x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean + x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to 4D + x = sum.dim_IntList(x, dim=(1,3), keepdims=True) # Reduce c,w with sum + x = mul.Tensor(x, 1/(c*w)) # Divide by number of elements to get mean x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False """ @@ -110,14 +99,6 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs): super().__init__(*args, **kwargs) self._graph_module = graph_module self._tosa_spec = tosa_spec - # Lazy import to avoid circular dependency with operator_support - from executorch.backends.arm.operator_support.pool_2d_support import ( - AvgPool2dSupported, - ) - - self._avg_pool_checker = AvgPool2dSupported( - self._tosa_spec, WhyNoPartitionReporter() - ) def call_operator(self, op, args, kwargs, meta, updated=False): if op not in ( @@ -168,12 +149,6 @@ def call_operator(self, op, args, kwargs, meta, updated=False): x = super().call_operator(view_op, (x, new_shape), {}, meta, True) x = self._maybe_insert_q_dq_after(x, meta) - # Reduce (h,w) dims by avg pool if possible - if not has_symbolic_reduce_dim: - x, dims_to_reduce = self._reduce_by_average_pool( - op, x, dims_to_reduce, meta - ) - # Reshape back to 5D if necessary if len(input_shape) > 4: original_dims = input_shape[:-3] @@ -259,44 +234,6 @@ def _reduce_by_sum(self, op, input_node, dims, meta): return super().call_operator(mul_op, (sum, divisor), {}, meta, True) - def _reduce_by_average_pool(self, op, input_node, dims, meta): - dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2] - if len(dims_to_reduce_by_avgpool) == 0: - return input_node, dims - - dims_to_reduce_by_sum = [dim for dim in dims if dim < 2] - - avgpool_op = get_avgpool(op) - input_shape = input_node.data.size() - - stride = [1, 1] - if dims_to_reduce_by_avgpool in ([2, 3], [3, 2]): - kernel_size = [input_shape[2], input_shape[3]] - elif dims_to_reduce_by_avgpool == [3]: - kernel_size = [1, input_shape[3]] - elif dims_to_reduce_by_avgpool == [2]: - kernel_size = [input_shape[2], 1] - else: - raise RuntimeError( - f"Bad dims {dims_to_reduce_by_avgpool} for {op} decomposition of mean_dim." - ) - - args = (input_node, kernel_size, stride) - - avg_pool_node = self._graph_module.graph.create_node( - "call_function", avgpool_op, args - ) - is_supported = self._avg_pool_checker.is_node_tosa_supported( - avg_pool_node, self._tosa_spec - ) - - if is_supported: - out = super().call_operator(avgpool_op, args, {}, meta, True) - out = self._maybe_insert_q_dq_after(out, meta) - return out, dims_to_reduce_by_sum - - return input_node, dims - def _maybe_insert_q_dq_after(self, op, meta): """If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters. diff --git a/backends/arm/test/misc/test_transpose_counts.py b/backends/arm/test/misc/test_transpose_counts.py index 55496c8a9b5..068dd28cabc 100644 --- a/backends/arm/test/misc/test_transpose_counts.py +++ b/backends/arm/test/misc/test_transpose_counts.py @@ -404,7 +404,7 @@ def forward(self, x): "groupnorm": TransposeCountCase( GroupNormModule(), (torch.randn(1, 4, 4, 4),), - 1, + 0, ), "multihead_attention_rank2": TransposeCountCase( MultiheadAttentionModule(), @@ -430,16 +430,16 @@ def forward(self, x): Model1ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5 ), "model_2_conv_mha_linear_layernorm": TransposeCountCase( - Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 11 + Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 9 ), "model_3_lstm_linear": TransposeCountCase( Model3LstmLinear(), (torch.randn(2, 16, 8),), 2 ), "model_4_conv_lstm_linear_layernorm": TransposeCountCase( - Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 5 + Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3 ), "model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase( - Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 6 + Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4 ), "model_6_gru_linear": TransposeCountCase( Model6GruLinear(), (torch.randn(2, 16, 8),), 2 @@ -521,7 +521,7 @@ def forward(self, x): "groupnorm_channels_last": TransposeCountCase( GroupNormModule(), (torch.randn(1, 4, 4, 4).to(memory_format=torch.channels_last),), - 3, + 2, ), "cumsum_rank4_dim3_channels_last": TransposeCountCase( CumsumModule(), diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py index d4f856ec761..8c6d9ef329c 100644 --- a/backends/arm/test/ops/test_cond.py +++ b/backends/arm/test/ops/test_cond.py @@ -82,7 +82,7 @@ def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return arg + torch.sin(arg), arg - torch.sin(arg) def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - return arg - arg.mean(), arg + arg.mean() + return arg - torch.cos(arg), arg + torch.cos(arg) predicate = x.flatten().sum() > 0 return torch.cond(predicate, true_branch, false_branch, [x]) diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 9ecb44e690c..c51789aea65 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -204,8 +204,6 @@ def test_native_layer_norm_16a8w_u55_INT(test_data): u85_xfails_16a8w = { "randn_last_dim": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.", - "randn_last_three_dims": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.", - "randn_last_three_dims_no_bias": "MLETORCH-1834 - 16A8W native_layer_norm output diff for certain configurations.", } diff --git a/backends/arm/test/passes/test_decompose_meandim_pass.py b/backends/arm/test/passes/test_decompose_meandim_pass.py index df7c78f3525..3da94a5fb98 100644 --- a/backends/arm/test/passes/test_decompose_meandim_pass.py +++ b/backends/arm/test/passes/test_decompose_meandim_pass.py @@ -56,12 +56,12 @@ class MeanDimTensor(torch.nn.Module): ops_after_pass = { "torch.ops.aten.sum.dim_IntList": 2, "torch.ops.aten.mul.Tensor": 1, - "torch.ops.aten.avg_pool2d.default": 1, "torch.ops.aten.reshape.default": 1, } ops_not_after_pass = [ "torch.ops.aten.mean.dim", + "torch.ops.aten.avg_pool2d.default", ] u55_ops_after_pass = {