19 changes: 17 additions & 2 deletions backends/qualcomm/builders/op_mean_dim.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

@@ -40,7 +40,22 @@ def define_node(
)

# mean dims and keep dims
mean_dims = cast(List[int], node.args[1])
rank = len(input_node.meta["val"].shape)

if rank == 0:
raise RuntimeError(
"Mean doesn't support 0d input, please report a bug in https://github.com/pytorch/executorch/issues"
)

dim_arg = node.args[1]

if dim_arg is None or len(dim_arg) == 0:
mean_dims = list(range(rank)) # reduce over all dims
elif isinstance(dim_arg, int):
mean_dims = [dim_arg]
else:
mean_dims = list(dim_arg)
print("mean_dims: ", mean_dims, "rank: ", rank)
mean_dims = [
mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
]
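For context, the hunk above replaces the old unconditional cast of node.args[1] with explicit handling of every form that torch.mean's dim argument can take. A minimal standalone sketch of that normalization (a hypothetical helper, not part of the PR; unlike the PR code it also guards the len() call so a bare int dim is safe, which the exported graph normally guarantees anyway by passing dim as a list):

from typing import List, Optional, Sequence, Union

def normalize_mean_dims(
    dim_arg: Optional[Union[int, Sequence[int]]], rank: int
) -> List[int]:
    # 0-d inputs have no axis to reduce over; the builder rejects them.
    if rank == 0:
        raise RuntimeError("Mean doesn't support 0d input")
    # dim=None and dim=[] both mean "reduce over every axis".
    if dim_arg is None or (not isinstance(dim_arg, int) and len(dim_arg) == 0):
        return list(range(rank))
    if isinstance(dim_arg, int):
        return [dim_arg % rank]
    # Wrap negative axes, e.g. -1 on a rank-4 tensor becomes 3.
    return [d % rank for d in dim_arg]

assert normalize_mean_dims(None, 3) == [0, 1, 2]
assert normalize_mean_dims([], 3) == [0, 1, 2]
assert normalize_mean_dims(0, 3) == [0]
assert normalize_mean_dims((-1, -2), 4) == [3, 2]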
25 changes: 13 additions & 12 deletions backends/qualcomm/tests/models.py
@@ -4,8 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from typing import List, Optional, Tuple, Union

import torch

# module with related operator only

@@ -1323,20 +1324,20 @@ def forward(self, x):
return self.max_pool2d(x)


class MeanWKeppDim(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.mean(x, (-1, -2), keepdim=True)


class MeanWOKeppDim(torch.nn.Module):
def __init__(self):
class Mean(torch.nn.Module):
def __init__(
self,
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
keepdim: bool = False,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.dim = dim
self.keepdim = keepdim
self.dtype = dtype

def forward(self, x):
return torch.mean(x, (-1, -2))
return torch.mean(x, dim=self.dim, keepdim=self.keepdim, dtype=self.dtype)


class MaskedFill(torch.nn.Module):
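The single parameterized Mean module above subsumes both removed classes. A quick equivalence check (a sketch, assuming Mean is imported from the post-PR models.py):

import torch

from executorch.backends.qualcomm.tests.models import Mean

x = torch.randn(2, 5, 1, 3)
# Old MeanWKeppDim behavior.
assert torch.allclose(
    Mean(dim=(-1, -2), keepdim=True)(x), torch.mean(x, (-1, -2), keepdim=True)
)
# Old MeanWOKeppDim behavior.
assert torch.allclose(Mean(dim=(-1, -2))(x), torch.mean(x, (-1, -2)))
# New default: reduce over all dims.
assert torch.allclose(Mean()(x), torch.mean(x))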
132 changes: 113 additions & 19 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1011,12 +1011,61 @@ def test_qnn_backend_max_pool2d(self):
sample_input = (torch.randn(4, 3, 24, 24),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_mean_dim(self):
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
sample_input = (torch.randn([2, 5, 1, 3]),)
for i, module in enumerate(modules):
def test_qnn_backend_mean(self):
test_comb = [
# Reduce over last two dims, keepdim=True
{
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
},
# Reduce over last two dims, keepdim=False
{
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
},
# Default: reduce all dims
{
QCOM_MODULE: Mean(), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
},
# TODO: Enable by reshaping the input to a 1d tensor
# # Scalar case
# {
# QCOM_MODULE: Mean(),
# QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
# },
# Edge case: dim is an empty list
{
QCOM_MODULE: Mean(dim=[]), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
},
# Edge case: reduce along dim=0 (batch dimension)
{
QCOM_MODULE: Mean(dim=0), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
},
# Edge case: reduce along dim=0 with keepdim=True
{
QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
},
# Edge case: reduce along multiple dims
{
QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
},
# Edge case: high-dimensional tensor
{
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
},
]

for i, test in enumerate(test_comb):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)
self.lower_module_and_test_output(
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
)
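These edge cases mirror torch.mean's eager semantics, which the lowered graph must reproduce; in particular, the builder change treats an empty dim list like dim=None, i.e. reduce over every axis. A quick eager-mode illustration of the expected shapes (plain PyTorch, not part of the test suite):

import torch

x = torch.randn(4, 6, 8)
assert torch.mean(x, dim=None).shape == torch.Size([])  # all dims reduced
assert torch.mean(x, dim=0).shape == torch.Size([6, 8])  # batch dim reduced
assert torch.mean(x, dim=0, keepdim=True).shape == torch.Size([1, 6, 8])

y = torch.randn(2, 3, 4, 5, 6)
assert torch.mean(y, dim=(1, 3), keepdim=True).shape == torch.Size([2, 1, 4, 1, 6])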

@unittest.skip("failed to lower in QNN 2.26")
def test_qnn_backend_mha(self):
@@ -1209,10 +1258,8 @@ def test_qnn_backend_slice_scatter(self):
],
QCOM_SAMPLE_INPUTS: [
(
(
torch.zeros(8, 8),
torch.ones(8, 2),
)
torch.zeros(8, 8),
torch.ones(8, 2),
)
],
},
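The fixture change above removes a stray level of tuple nesting, so QCOM_SAMPLE_INPUTS now holds one flat tuple of two tensors rather than a 1-element tuple wrapping that tuple. A minimal illustration of the difference (plain Python; the exact unpacking done by the harness is assumed):

import torch

nested = ((torch.zeros(8, 8), torch.ones(8, 2)),)  # len == 1: one tuple element
flat = (torch.zeros(8, 8), torch.ones(8, 2))       # len == 2: one tensor per arg

assert len(nested) == 1 and len(flat) == 2
# A harness that calls module(*sample_input) needs the flat form to pass two
# tensors; the nested form would pass a single tuple argument instead.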
@@ -2641,13 +2688,62 @@ def test_qnn_backend_max_pool2d(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_mean_dim(self):
modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405
sample_input = (torch.randn([2, 5, 1, 3]),)
for i, module in enumerate(modules):
def test_qnn_backend_mean(self):
test_comb = [
# Reduce over last two dims, keepdim=True
{
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
},
# Reduce over last two dims, keepdim=False
{
QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),),
},
# Default: reduce all dims
{
QCOM_MODULE: Mean(), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),),
},
# TODO: Enable by reshaping the input to a 1d tensor
# Scalar case
# {
# QCOM_MODULE: Mean(),
# QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),),
# },
# Edge case: dim is an empty list
{
QCOM_MODULE: Mean(dim=[]), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
},
# Edge case: reduce along dim=0 (batch dimension)
{
QCOM_MODULE: Mean(dim=0), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
},
# Edge case: reduce along dim=0 with keepdim=True
{
QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),),
},
# Edge case: reduce along multiple dims
{
QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),),
},
# Edge case: high-dimensional tensor
{
QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),),
},
]

for i, test in enumerate(test_comb):
with self.subTest(i=i):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)
module = self.get_qdq_module(
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
)
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
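Unlike the FP16 suite above, each case here is first converted to a quantize-dequantize (QDQ) module via the harness's get_qdq_module before lowering. For orientation, a generic PT2E-style sketch of what such a conversion involves (illustrative only; this is not the harness's actual implementation, and the export/prepare pairing varies across PyTorch versions):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

def make_qdq(module: torch.nn.Module, sample_inputs: tuple) -> torch.nn.Module:
    graph = torch.export.export(module, sample_inputs).module()
    prepared = prepare_pt2e(graph, QnnQuantizer())  # attach observers
    prepared(*sample_inputs)                        # calibrate on the sample inputs
    return convert_pt2e(prepared)                   # fold observers into Q/DQ pairs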

def test_qnn_backend_mha(self):
module = MultiheadAttention() # noqa: F405
@@ -2872,10 +2968,8 @@ def test_qnn_backend_slice_scatter(self):
],
QCOM_SAMPLE_INPUTS: [
(
(
torch.zeros(8, 8),
torch.ones(8, 2),
)
torch.zeros(8, 8),
torch.ones(8, 2),
)
],
},