From 8210ab6ca87c89805b1626d56df0d116a03c79fc Mon Sep 17 00:00:00 2001 From: cccclai Date: Wed, 1 Oct 2025 21:13:24 -0700 Subject: [PATCH] support qnn mean (dim=None) (#14675) Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776 (cherry picked from commit 9ab5592a6533e9d903d927ff70d9aef83a74f0c6) --- backends/qualcomm/builders/op_mean_dim.py | 19 ++- backends/qualcomm/tests/models.py | 25 ++-- backends/qualcomm/tests/test_qnn_delegate.py | 132 ++++++++++++++++--- 3 files changed, 143 insertions(+), 33 deletions(-) diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index 630b1b0b8de..22cb47ee288 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import cast, Dict, List
+from typing import cast, Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
 
@@ -40,7 +40,21 @@ def define_node(
         )
 
         # mean dims and keep dims
-        mean_dims = cast(List[int], node.args[1])
+        rank = len(input_node.meta["val"].shape)
+
+        if rank == 0:
+            raise RuntimeError(
+                "Mean doesn't support 0d input, please report a bug in https://github.com/pytorch/executorch/issues"
+            )
+
+        dim_arg = node.args[1]
+
+        if dim_arg is None or len(dim_arg) == 0:
+            mean_dims = list(range(rank))  # reduce over all dims
+        elif isinstance(dim_arg, int):
+            mean_dims = [dim_arg]
+        else:
+            mean_dims = list(dim_arg)
         mean_dims = [
             mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
         ]
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 28ea224d747..299ad56038c 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -4,8 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch +from typing import List, Optional, Tuple, Union +import torch # module with related operator only @@ -1323,20 +1324,20 @@ def forward(self, x): return self.max_pool2d(x) -class MeanWKeppDim(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.mean(x, (-1, -2), keepdim=True) - - -class MeanWOKeppDim(torch.nn.Module): - def __init__(self): +class Mean(torch.nn.Module): + def __init__( + self, + dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None, + keepdim: bool = False, + dtype: Optional[torch.dtype] = None, + ): super().__init__() + self.dim = dim + self.keepdim = keepdim + self.dtype = dtype def forward(self, x): - return torch.mean(x, (-1, -2)) + return torch.mean(x, dim=self.dim, keepdim=self.keepdim, dtype=self.dtype) class MaskedFill(torch.nn.Module): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 9cca7b4203d..b339a4d5aa9 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1011,12 +1011,61 @@ def test_qnn_backend_max_pool2d(self): sample_input = (torch.randn(4, 3, 24, 24),) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_mean_dim(self): - modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405 - sample_input = (torch.randn([2, 5, 1, 3]),) - for i, module in enumerate(modules): + def test_qnn_backend_mean(self): + test_comb = [ + # Reduce over last two dims, keepdim=True + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Reduce over last two dims, keepdim=False + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Default: reduce all dims + { + QCOM_MODULE: Mean(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),), + }, + # TODO: To be enabled via reshape input to 1d tensor + # # 
Scalar case + # { + # QCOM_MODULE: Mean(), + # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),), + # }, + # Edge case: dim is a empty list + { + QCOM_MODULE: Mean(dim=[]), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 (batch dimension) + { + QCOM_MODULE: Mean(dim=0), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 with keepdim=True + { + QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along multiple dims + { + QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),), + }, + # Edge case: high-dimensional tensor + { + QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),), + }, + ] + + for i, test in enumerate(test_comb): with self.subTest(i=i): - self.lower_module_and_test_output(module, sample_input) + self.lower_module_and_test_output( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) @unittest.skip("failed to lower in QNN 2.26") def test_qnn_backend_mha(self): @@ -1209,10 +1258,8 @@ def test_qnn_backend_slice_scatter(self): ], QCOM_SAMPLE_INPUTS: [ ( - ( - torch.zeros(8, 8), - torch.ones(8, 2), - ) + torch.zeros(8, 8), + torch.ones(8, 2), ) ], }, @@ -2641,13 +2688,62 @@ def test_qnn_backend_max_pool2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_mean_dim(self): - modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405 - sample_input = (torch.randn([2, 5, 1, 3]),) - for i, module in enumerate(modules): + def test_qnn_backend_mean(self): + test_comb = [ + # Reduce over last two dims, keepdim=True + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Reduce over last two dims, keepdim=False + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # 
noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Default: reduce all dims + { + QCOM_MODULE: Mean(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),), + }, + # TODO: To be enabled via reshape input to 1d tensor + # Scalar case + # { + # QCOM_MODULE: Mean(), + # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),), + # }, + # Edge case: dim is a empty list + { + QCOM_MODULE: Mean(dim=[]), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 (batch dimension) + { + QCOM_MODULE: Mean(dim=0), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 with keepdim=True + { + QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along multiple dims + { + QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),), + }, + # Edge case: high-dimensional tensor + { + QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),), + }, + ] + + for i, test in enumerate(test_comb): with self.subTest(i=i): - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS]) def test_qnn_backend_mha(self): module = MultiheadAttention() # noqa: F405 @@ -2872,10 +2968,8 @@ def test_qnn_backend_slice_scatter(self): ], QCOM_SAMPLE_INPUTS: [ ( - ( - torch.zeros(8, 8), - torch.ones(8, 2), - ) + torch.zeros(8, 8), + torch.ones(8, 2), ) ], },