From 51f99c3d084aa716b12be1973f5376c532902e7f Mon Sep 17 00:00:00 2001
From: jethroqti
Date: Thu, 30 Oct 2025 04:42:25 -0700
Subject: [PATCH] enable operator avg_pool3d and adaptive_avg_pool3d

Summary
Enable avg_pool3d and adaptive_avg_pool3d operators

Test plan
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_adaptive_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_adaptive_avg_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
---
 backends/qualcomm/_passes/layout_transform.py |   2 +
 backends/qualcomm/builders/README.md          |   2 +-
 backends/qualcomm/builders/__init__.py        |   2 +
 backends/qualcomm/builders/op_avg_pool3d.py   | 299 ++++++++++++++++++
 backends/qualcomm/builders/qnn_constants.py   |  15 +
 backends/qualcomm/partition/common_defs.py    |   2 -
 backends/qualcomm/quantizer/annotators.py     |  12 +
 backends/qualcomm/tests/models.py             |  25 ++
 backends/qualcomm/tests/test_qnn_delegate.py  |  70 ++++
 9 files changed, 426 insertions(+), 3 deletions(-)
 create mode 100644 backends/qualcomm/builders/op_avg_pool3d.py

diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 364ec6b4880..a96c5b21d42 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -42,7 +42,9 @@ class LayoutTransform(ExportPass):
     layout_sensitive_ops = {
         exir_ops.edge.aten.adaptive_avg_pool2d.default,
+        exir_ops.edge.aten._adaptive_avg_pool3d.default,
         exir_ops.edge.aten.avg_pool2d.default,
+        exir_ops.edge.aten.avg_pool3d.default,
         exir_ops.edge.aten.convolution.default,
         exir_ops.edge.aten.instance_norm.default,
         exir_ops.edge.aten.max_pool2d_with_indices.default,
diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md
index 61ae1061214..54cfae6591c 100644
--- a/backends/qualcomm/builders/README.md
+++ b/backends/qualcomm/builders/README.md
@@ -448,7 +448,7 @@ Please help update following table if you are contributing new operators:
 | Pack | ✓ |
 | Pad | ✓ |
 | PoolAvg2d | ✓ |
-| PoolAvg3d | ✗ |
+| PoolAvg3d | ✓ |
 | PoolMax2d | ✓ |
 | Prelu | ✓ |
 | Quantize | ✓ |
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index 3fa8ae067fa..4bf0ea7e210 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -18,6 +18,7 @@
     op_asin,
     op_atan,
     op_avg_pool2d,
+    op_avg_pool3d,
     op_batch_norm,
     op_binary,
     op_bmm,
@@ -123,6 +124,7 @@
     op_asin,
     op_atan,
     op_avg_pool2d,
+    op_avg_pool3d,
     op_batch_norm,
     op_binary,
     op_bmm,
diff --git a/backends/qualcomm/builders/op_avg_pool3d.py b/backends/qualcomm/builders/op_avg_pool3d.py
new file mode 100644
index 00000000000..9d585eaeecb
--- /dev/null
+++ b/backends/qualcomm/builders/op_avg_pool3d.py
@@ -0,0 +1,299 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
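+#
+# Lowers aten.avg_pool3d.default and aten._adaptive_avg_pool3d.default to
+# QNN's PoolAvg3d op. Inputs reach these visitors in channel-last
+# (N, D, H, W, C) layout via the layout-transform pass updated above.
+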
+import warnings
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpPoolAvg3d, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AvgPool3d(NodeVisitor):
+    target = ["aten.avg_pool3d.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        # kernel info; a scalar kernel size is broadcast to all three dims
+        filter_size = cast(List[int], node.args[1])
+        if len(filter_size) == 1:
+            filter_size *= 3
+        filter_size_shape = [len(filter_size)]
+
+        # stride info
+        stride = cast(List[int], node.args[2])
+        if len(stride) == 1:
+            stride *= 3
+        stride_shape = [len(stride)]
+
+        # padding info
+        padding = [0, 0, 0]
+        if len(node.args) > 3:
+            padding = cast(List[int], node.args[3])
+            if len(padding) == 1:
+                padding *= 3
+
+        # if ceil_mode is True, use ceil instead of floor to compute the output shape
+        mode = OpPoolAvg3d.RoundingMode.FLOOR
+        if len(node.args) > 4:
+            ceil_mode = cast(bool, node.args[4])
+            if ceil_mode:
+                mode = OpPoolAvg3d.RoundingMode.CEIL
+
+        count_pad_for_edges = node.args[5] if len(node.args) > 5 else False
+
+        # symmetric padding: the same amount before and after each spatial dim
+        depth_pad_l = padding[0]
+        depth_pad_r = padding[0]
+        height_pad_l = padding[1]
+        height_pad_r = padding[1]
+        width_pad_l = padding[2]
+        width_pad_r = padding[2]
+
+        shape_pad = [
+            [depth_pad_l, depth_pad_r],
+            [height_pad_l, height_pad_r],
+            [width_pad_l, width_pad_r],
+        ]
+        padding_shape = [len(shape_pad), len(shape_pad[0])]
+
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        avg_pool3d_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpPoolAvg3d.op_name,
+        )
+
+        avg_pool3d_op.AddInputTensors([input_tensor_wrapper])
+        avg_pool3d_op.AddOutputTensors([output_tensor_wrapper])
+
+        avg_pool3d_op.AddTensorParam(
+            OpPoolAvg3d.param_filter_size,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(filter_size_shape),
+            filter_size_shape,
+            np.array(
+                filter_size,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        avg_pool3d_op.AddTensorParam(
+            OpPoolAvg3d.param_stride,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(stride_shape),
+            stride_shape,
+            np.array(
+                stride,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        avg_pool3d_op.AddTensorParam(
+            OpPoolAvg3d.param_pad_amount,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(padding_shape),
+            padding_shape,
+            np.array(
+                shape_pad,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        avg_pool3d_op.AddScalarParam(
+            OpPoolAvg3d.param_count_pad_for_edges,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+            {QCOM_DATA: count_pad_for_edges},
+        )
+
+        avg_pool3d_op.AddScalarParam(
+            OpPoolAvg3d.param_rounding_mode,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(mode)},
+        )
+
+        return avg_pool3d_op
+
+
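+# QNN has no adaptive average pool, so the visitor below re-expresses
+# aten._adaptive_avg_pool3d as a fixed-window PoolAvg3d. Per spatial dim:
+#   stride = input // output
+#   filter = input - (output - 1) * stride
+# e.g. input_d = 16, output_d = 2 gives stride_d = 8 and filter_d = 8.
+# The mapping matches PyTorch's adaptive bins only when each input dim is
+# evenly divisible by its output dim; other shapes are rejected so the node
+# falls back instead of producing wrong values.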
+@register_node_visitor
+class AdaptiveAvgPool3d(NodeVisitor):
+    target = ["aten._adaptive_avg_pool3d.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        # NOTE: This operator is layout sensitive, so the input tensor shape
+        # is always N, D, H, W, C here.
+        input_depth = input_tensor.shape[1]
+        input_height = input_tensor.shape[2]
+        input_width = input_tensor.shape[3]
+        # a None entry in output_size means "keep the input size" for that dim
+        output_depth = node.args[1][0]
+        output_height = node.args[1][1]
+        output_width = node.args[1][2]
+        if output_depth is None:
+            output_depth = input_depth
+        if output_height is None:
+            output_height = input_height
+        if output_width is None:
+            output_width = input_width
+
+        # derive an equivalent fixed kernel and stride per dim
+        stride_height = input_height // output_height
+        filter_height = input_height - (output_height - 1) * stride_height
+        stride_width = input_width // output_width
+        filter_width = input_width - (output_width - 1) * stride_width
+        stride_depth = input_depth // output_depth
+        filter_depth = input_depth - (output_depth - 1) * stride_depth
+
+        filter_size = [filter_depth, filter_height, filter_width]
+        filter_shape = [len(filter_size)]
+        stride = [stride_depth, stride_height, stride_width]
+        stride_shape = [len(stride)]
+
+        # the fixed kernel/stride pair reproduces adaptive pooling only when
+        # each input dim is evenly divisible by its output dim
+        depth = input_depth % output_depth
+        height = input_height % output_height
+        width = input_width % output_width
+
+        if any(x != 0 for x in (depth, height, width)):
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Output size does not evenly divide input depth/height/width, falling back",
+                stacklevel=1,
+            )
+            return
+
+        count_pad_for_edges = False
+        # This operator uses avg_pool3d's default rounding mode, floor.
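+        # (PyTorch's adaptive_avg_pool3d exposes no ceil_mode argument, so
+        # there is nothing else to honor here.)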
+        mode = OpPoolAvg3d.RoundingMode.FLOOR
+
+        # pad before / pad after; the exact mapping needs no padding, so keep 0
+        depth_pad_b = 0
+        depth_pad_a = 0
+        height_pad_b = 0
+        height_pad_a = 0
+        width_pad_b = 0
+        width_pad_a = 0
+
+        shape_pad = [
+            [depth_pad_b, depth_pad_a],
+            [height_pad_b, height_pad_a],
+            [width_pad_b, width_pad_a],
+        ]
+        padding_shape = [len(shape_pad), len(shape_pad[0])]
+
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        adaptive_avg_pool3d_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpPoolAvg3d.op_name,
+        )
+
+        adaptive_avg_pool3d_op.AddInputTensors([input_tensor_wrapper])
+        adaptive_avg_pool3d_op.AddOutputTensors([output_tensor_wrapper])
+
+        adaptive_avg_pool3d_op.AddTensorParam(
+            OpPoolAvg3d.param_filter_size,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(filter_shape),
+            filter_shape,
+            np.array(
+                filter_size,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        adaptive_avg_pool3d_op.AddTensorParam(
+            OpPoolAvg3d.param_stride,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(stride_shape),
+            stride_shape,
+            np.array(
+                stride,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        adaptive_avg_pool3d_op.AddTensorParam(
+            OpPoolAvg3d.param_pad_amount,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(padding_shape),
+            padding_shape,
+            np.array(
+                shape_pad,
+                dtype=np.uint32,
+            ),
+            True,
+        )
+
+        adaptive_avg_pool3d_op.AddScalarParam(
+            OpPoolAvg3d.param_count_pad_for_edges,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+            {QCOM_DATA: count_pad_for_edges},
+        )
+
+        adaptive_avg_pool3d_op.AddScalarParam(
+            OpPoolAvg3d.param_rounding_mode,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(mode)},
+        )
+
+        return adaptive_avg_pool3d_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 79a1c93d50c..19c63015f64 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -398,6 +398,21 @@ class RoundingMode(IntEnum):
         CEIL = 1
 
 
+@dataclass(init=False, frozen=True)
+class OpPoolAvg3d:
+    op_name: str = "PoolAvg3d"
+    param_filter_size: str = "filter_size"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_count_pad_for_edges: str = "count_pad_for_edges"
+    param_rounding_mode: str = "rounding_mode"
+
+    @unique
+    class RoundingMode(IntEnum):
+        FLOOR = 0
+        CEIL = 1
+
+
 @dataclass(init=False, frozen=True)
 class OpPoolMax2d:
     op_name: str = "PoolMax2d"
diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index 4abbcc3145c..76f22552c8d 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -19,10 +19,8 @@
 ]
 
 to_be_implemented_operator = [
-    exir_ops.edge.aten._adaptive_avg_pool3d.default,
     exir_ops.edge.aten.adaptive_max_pool2d.default,
     exir_ops.edge.aten.adaptive_max_pool3d.default,
-    exir_ops.edge.aten.avg_pool3d.default,
     exir_ops.edge.aten.div.Tensor_mode,
     exir_ops.edge.aten.log10.default,
     exir_ops.edge.aten.log1p.default,
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index ecceba24a89..8b59de3bd4e 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -578,6 +578,18 @@ def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.avg_pool3d.default])
+def annotate_avgpool3d(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
+@register_annotator([torch.ops.aten.adaptive_avg_pool3d.default])
+def annotate_adaptive_avgpool3d(
+    node: Node, quantization_config: QuantizationConfig
+) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.permute.default])
 def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 58647441210..1674c99175a 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -59,6 +59,16 @@ def forward(self, x):
         return adaptive_avg_pool(x)
 
 
+class AdaptiveAvgPool3D(torch.nn.Module):
+    def __init__(self, output_size):
+        super().__init__()
+        self.output_size = output_size
+
+    def forward(self, x):
+        adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d(self.output_size)
+        return adaptive_avg_pool3d(x)
+
+
 class Add(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -224,6 +234,21 @@ def forward(self, x):
         return torch.atan(x)
 
 
+class AvgPool3d(torch.nn.Module):
+    def __init__(self, kernel_size, stride, padding, ceil_mode, count_include_pad):
+        super().__init__()
+        self.avg_pool3d = torch.nn.AvgPool3d(
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            count_include_pad=count_include_pad,
+        )
+
+    def forward(self, x):
+        return self.avg_pool3d(x)
+
+
 class AvgPoolModule(torch.nn.Module):
     def __init__(self, kernel_size, stride, padding, ceil_mode):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 24a71af9001..93215218526 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -119,6 +119,21 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         sample_input = (torch.randn(1, 512, 7, 7),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_adaptive_avg_pool3d(self):
+        # NOTE: Supports only the cases where mod(input_dhw, output_dhw) == 0
+        modules = [
+            AdaptiveAvgPool3D((2, 2, 2)),  # noqa: F405
+            AdaptiveAvgPool3D((8)),  # noqa: F405
+            AdaptiveAvgPool3D((2, None, None)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(1, 512, 16, 8, 16),),
+        ]
+        for j in range(len(sample_inputs)):
+            for i, module in enumerate(modules):
+                with self.subTest(i=i):
+                    self.lower_module_and_test_output(module, sample_inputs[j])
+
     def test_qnn_backend_alias(self):
         module = Alias()  # noqa: F405
         sample_input = (torch.randn(1, 10),)
@@ -255,6 +270,25 @@ def test_qnn_backend_avg_pool2d(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_inputs[i])
 
+    def test_qnn_backend_avg_pool3d(self):
+        # NOTE: Supports only the cases where mod(input_dhw, filter_dhw) == 0
+        # NOTE: Padding must be at most half of the effective kernel size.
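+        # Each case below is (kernel_size, stride, padding, ceil_mode,
+        # count_include_pad); scalar args broadcast to all three dims.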
+        modules = [
+            AvgPool3d((8), (2), (1), True, True),  # noqa: F405
+            AvgPool3d((8), (2), (1), True, False),  # noqa: F405
+            AvgPool3d((8), (2), (1), False, False),  # noqa: F405
+            AvgPool3d((16, 16, 16), (4, 4, 4), (1, 1, 1), False, True),  # noqa: F405
+            AvgPool3d((8, 8, 8), (2, 2, 2), (1, 1, 1), True, True),  # noqa: F405
+            AvgPool3d((12, 12, 12), (4, 6, 2), (0, 0, 0), True, True),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(1, 3, 64, 48, 32),),
+        ]
+        for j in range(len(sample_inputs)):
+            for i, module in enumerate(modules):
+                with self.subTest(i=i):
+                    self.lower_module_and_test_output(module, sample_inputs[j])
+
     def test_qnn_backend_batch_norm(self):
         modules = [BatchNorm(32), BatchNorm(32, False)]  # noqa: F405
         sample_input = (torch.randn([4, 32, 16, 16]),)
@@ -2041,6 +2075,22 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_adaptive_avg_pool3d(self):
+        # NOTE: Supports only the cases where mod(input_dhw, output_dhw) == 0
+        modules = [
+            AdaptiveAvgPool3D((2, 2, 2)),  # noqa: F405
+            AdaptiveAvgPool3D((8)),  # noqa: F405
+            AdaptiveAvgPool3D((2, None, None)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(1, 512, 16, 8, 16),),
+        ]
+        for j in range(len(sample_inputs)):
+            for i, module in enumerate(modules):
+                with self.subTest(i=i):
+                    module = self.get_qdq_module(module, sample_inputs[j])
+                    self.lower_module_and_test_output(module, sample_inputs[j])
+
     def test_qnn_backend_alias(self):
         module = Alias()  # noqa: F405
         sample_input = (torch.randn(1, 10),)
@@ -2187,6 +2237,26 @@ def test_qnn_backend_avg_pool2d(self):
             module = self.get_qdq_module(module, sample_inputs[i])
             self.lower_module_and_test_output(module, sample_inputs[i])
 
+    def test_qnn_backend_avg_pool3d(self):
+        # NOTE: Supports only the cases where mod(input_dhw, filter_dhw) == 0
+        # NOTE: Padding must be at most half of the effective kernel size.
+        modules = [
+            AvgPool3d((8), (2), (1), True, True),  # noqa: F405
+            AvgPool3d((8), (2), (1), True, False),  # noqa: F405
+            AvgPool3d((8), (2), (1), False, False),  # noqa: F405
+            AvgPool3d((16, 16, 16), (4, 4, 4), (1, 1, 1), False, True),  # noqa: F405
+            AvgPool3d((8, 8, 8), (2, 2, 2), (1, 1, 1), True, True),  # noqa: F405
+            AvgPool3d((12, 12, 12), (4, 6, 2), (0, 0, 0), True, True),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(1, 3, 64, 48, 32),),
+        ]
+        for j in range(len(sample_inputs)):
+            for i, module in enumerate(modules):
+                with self.subTest(i=i):
+                    module = self.get_qdq_module(module, sample_inputs[j])
+                    self.lower_module_and_test_output(module, sample_inputs[j])
+
     def test_qnn_backend_batch_norm(self):
         modules = [BatchNorm(32), BatchNorm(32, False)]  # noqa: F405
         sample_input = (torch.randn([4, 32, 16, 16]),)