From 553aa3a95fd6334a4b8414a666804e754a0be235 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Tue, 8 Apr 2025 12:32:16 -0700
Subject: [PATCH] Add op_amax support (#9955)

Summary: Add op_amax (lowered to QNN ReduceMax) to support an internal model, and add unit tests in test_qnn_delegate.py.

Reviewed By: kirklandsign

Differential Revision: D72613814
---
 backends/qualcomm/_passes/layout_transform.py |  1 +
 backends/qualcomm/builders/__init__.py        |  2 +
 backends/qualcomm/builders/op_amax.py         | 84 +++++++++++++++++++
 backends/qualcomm/builders/qnn_constants.py   |  7 ++
 backends/qualcomm/quantizer/annotators.py     |  5 ++
 backends/qualcomm/tests/models.py             | 10 +++
 backends/qualcomm/tests/test_qnn_delegate.py  | 15 ++++
 7 files changed, 124 insertions(+)
 create mode 100644 backends/qualcomm/builders/op_amax.py

diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 17960a6029b..4d47c38bc03 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -47,6 +47,7 @@ class LayoutTransform(ExportPass):
     layout_agnostic_ops = {
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.amax.default,
         exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.bitwise_and.Tensor,
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index cc85333f26b..645b823d0e5 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -9,6 +9,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,
@@ -95,6 +96,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,
diff --git a/backends/qualcomm/builders/op_amax.py b/backends/qualcomm/builders/op_amax.py
new file mode 100644
index 00000000000..06f7c998502
--- /dev/null
+++ b/backends/qualcomm/builders/op_amax.py
@@ -0,0 +1,84 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpAmax, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AMax(NodeVisitor):
+    target = ["aten.amax.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        # normalize reduction dims; remap them if a QNN axis order is set
+        max_dims = cast(List[int], node.args[1])
+        max_dims = [
+            max_dim % len(input_node.meta["val"].shape) for max_dim in max_dims
+        ]
+        if QCOM_AXIS_ORDER in node.meta:
+            max_dims = [
+                node.meta[QCOM_AXIS_ORDER].index(max_dim) for max_dim in max_dims
+            ]
+        max_dims_shape = [len(max_dims)]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        reduce_max_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpAmax.op_name,
+        )
+        reduce_max_op.AddInputTensors([input_tensor_wrapper])
+        reduce_max_op.AddOutputTensors([output_tensor_wrapper])
+        reduce_max_op.AddTensorParam(
+            OpAmax.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(max_dims_shape),
+            max_dims_shape,
+            np.array(max_dims, dtype=np.uint32),
+            True,
+        )
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            reduce_max_op.AddScalarParam(
+                OpAmax.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return reduce_max_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 9613c755c7c..31822a174b9 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -14,6 +14,13 @@
 # instead of replicating them here.
 
 
+@dataclass(init=False, frozen=True)
+class OpAmax:
+    op_name: str = "ReduceMax"
+    param_axes: str = "axes"
+    param_keep_dims: str = "keep_dims"
+
+
 @dataclass(init=False, frozen=True)
 class OpBatchnorm:
     op_name: str = "Batchnorm"
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 93af5e86c97..52662202795 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -182,6 +182,11 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.amax.default])
+def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.argmin.default])
 def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 0857a597d88..f56aed6f76c 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -72,6 +72,16 @@ def forward(self, x):
         return torch.any(x, dim=self.dim, keepdim=self.keepdim)
 
 
+class AMax(torch.nn.Module):
+    def __init__(self, dim=None, keepdim=False):
+        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
+
+
 class Arange(torch.nn.Module):
     def __init__(self, start, end, step, dtype):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 795459a9f77..e33eb16d1e5 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -113,6 +113,13 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         sample_input = (torch.randn(1, 512, 7, 7),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_amax(self):
+        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -1111,6 +1118,14 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_amax(self):
+        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
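
For reference, a minimal sketch (not part of the patch) of the eager-mode semantics the new builder lowers to QNN ReduceMax. It runs with plain PyTorch and no QNN device; the import of the AMax test module assumes the models.py addition above is importable, so treat it as illustrative rather than part of the test harness.

    import torch

    # aten.amax.default: reduce-max over the given dims; keepdim controls
    # whether the reduced dimension is retained with size 1.
    x = torch.randn(4, 4)
    print(torch.amax(x, dim=1, keepdim=False).shape)  # torch.Size([4])
    print(torch.amax(x, dim=1, keepdim=True).shape)   # torch.Size([4, 1])

    # The same modules the new tests push through the QNN delegate
    # (import path follows this patch's backends/qualcomm/tests/models.py):
    from executorch.backends.qualcomm.tests.models import AMax

    for module in (AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)):
        print(module(x).shape)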