Skip to content

Commit 6948323

Browse files
cccclaifacebook-github-bot
authored andcommitted
support qnn mean (dim=None) (#14675)
Summary: Address mean op lower failure. When dim is not specified, it will take mean across all axes. For QNN, we need to get axes based on input shape Differential Revision: D83520776
1 parent 68b2d3c commit 6948323

File tree

3 files changed

+74
-24
lines changed

3 files changed

+74
-24
lines changed

backends/qualcomm/builders/op_mean_dim.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,16 @@ def define_node(
4040
)
4141

4242
# mean dims and keep dims
43-
mean_dims = cast(List[int], node.args[1])
43+
rank = len(input_node.meta["val"].shape)
44+
dim_arg = node.args[1]
45+
46+
if dim_arg is None:
47+
mean_dims = list(range(rank)) # reduce over all dims
48+
elif isinstance(dim_arg, int):
49+
mean_dims = [dim_arg]
50+
else:
51+
mean_dims = list(dim_arg)
52+
4453
mean_dims = [
4554
mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
4655
]

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
8+
from typing import Optional, Union, Tuple, List
99

1010
# module with related operator only
1111

@@ -1332,20 +1332,20 @@ def forward(self, x):
13321332
return self.max_pool2d(x)
13331333

13341334

1335-
class MeanWKeppDim(torch.nn.Module):
1336-
def __init__(self):
1337-
super().__init__()
1338-
1339-
def forward(self, x):
1340-
return torch.mean(x, (-1, -2), keepdim=True)
1341-
1342-
1343-
class MeanWOKeppDim(torch.nn.Module):
1344-
def __init__(self):
1335+
class Mean(torch.nn.Module):
1336+
def __init__(
1337+
self,
1338+
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
1339+
keepdim: bool = False,
1340+
dtype: Optional[torch.dtype] = None,
1341+
):
13451342
super().__init__()
1343+
self.dim = dim
1344+
self.keepdim = keepdim
1345+
self.dtype = dtype
13461346

13471347
def forward(self, x):
1348-
return torch.mean(x, (-1, -2))
1348+
return torch.mean(x, dim=self.dim, keepdim=self.keepdim, dtype=self.dtype)
13491349

13501350

13511351
class MaskedFill(torch.nn.Module):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,12 +1018,32 @@ def test_qnn_backend_max_pool2d(self):
10181018
sample_input = (torch.randn(4, 3, 24, 24),)
10191019
self.lower_module_and_test_output(module, sample_input)
10201020

1021-
def test_qnn_backend_mean_dim(self):
1022-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
1023-
sample_input = (torch.randn([2, 5, 1, 3]),)
1024-
for i, module in enumerate(modules):
1021+
def test_qnn_backend_mean(self):
1022+
test_comb = [
1023+
{
1024+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # keepdim=True
1025+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1026+
},
1027+
{
1028+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # keepdim=False
1029+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
1030+
},
1031+
{
1032+
QCOM_MODULE: Mean(), # default: reduce all dims
1033+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
1034+
},
1035+
{
1036+
QCOM_MODULE: Mean(), # scalar case
1037+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
1038+
},
1039+
]
1040+
1041+
for i, test in enumerate(test_comb):
10251042
with self.subTest(i=i):
1026-
self.lower_module_and_test_output(module, sample_input)
1043+
module = self.get_qdq_module(
1044+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
1045+
)
1046+
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
10271047

10281048
@unittest.skip("failed to lower in QNN 2.26")
10291049
def test_qnn_backend_mha(self):
@@ -2666,13 +2686,34 @@ def test_qnn_backend_max_pool2d(self):
26662686
module = self.get_qdq_module(module, sample_input)
26672687
self.lower_module_and_test_output(module, sample_input)
26682688

2669-
def test_qnn_backend_mean_dim(self):
2670-
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
2671-
sample_input = (torch.randn([2, 5, 1, 3]),)
2672-
for i, module in enumerate(modules):
2689+
def test_qnn_backend_mean(self):
2690+
test_comb = [
2691+
{
2692+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # keepdim=True
2693+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2694+
},
2695+
{
2696+
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # keepdim=False
2697+
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
2698+
},
2699+
{
2700+
QCOM_MODULE: Mean(), # default: reduce all dims
2701+
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
2702+
},
2703+
{
2704+
QCOM_MODULE: Mean(), # scalar case
2705+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
2706+
},
2707+
]
2708+
2709+
for i, test in enumerate(test_comb):
26732710
with self.subTest(i=i):
2674-
module = self.get_qdq_module(module, sample_input)
2675-
self.lower_module_and_test_output(module, sample_input)
2711+
module = self.get_qdq_module(
2712+
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
2713+
)
2714+
module = self.get_qdq_module(module, test[QCOM_SAMPLE_INPUTS])
2715+
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
2716+
26762717

26772718
def test_qnn_backend_mha(self):
26782719
module = MultiheadAttention() # noqa: F405

0 commit comments

Comments
 (0)