From aecde6b5b5d15165166256b3b4c2443b229260b9 Mon Sep 17 00:00:00 2001 From: thchenqti Date: Tue, 1 Jul 2025 13:49:03 +0800 Subject: [PATCH] op_enablement_round_floor_atan --- backends/qualcomm/_passes/layout_transform.py | 3 + backends/qualcomm/builders/README.md | 8 +-- backends/qualcomm/builders/__init__.py | 6 ++ backends/qualcomm/builders/op_atan.py | 55 ++++++++++++++++++ backends/qualcomm/builders/op_floor.py | 56 ++++++++++++++++++ backends/qualcomm/builders/op_round.py | 58 +++++++++++++++++++ backends/qualcomm/builders/qnn_constants.py | 15 +++++ backends/qualcomm/quantizer/annotators.py | 15 +++++ backends/qualcomm/tests/models.py | 24 ++++++++ backends/qualcomm/tests/test_qnn_delegate.py | 33 +++++++++++ 10 files changed, 269 insertions(+), 4 deletions(-) create mode 100644 backends/qualcomm/builders/op_atan.py create mode 100644 backends/qualcomm/builders/op_floor.py create mode 100644 backends/qualcomm/builders/op_round.py diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 9b21c0d33d9..0c6b3152561 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -63,6 +63,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.amax.default, + exir_ops.edge.aten.atan.default, exir_ops.edge.aten.bitwise_or.Tensor, exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.bitwise_and.Tensor, @@ -75,6 +76,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.elu.default, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.floor.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.ge.Tensor, @@ -99,6 +101,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.prelu.default, exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.round.default, exir_ops.edge.aten.relu.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.split_with_sizes.default, diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 77944a8bfc2..4e150f1eeaa 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -360,7 +360,7 @@ The operator now should be functional for Qualcomm backends. 
For operator to wor ## Operator Support Status Please help update following table if you are contributing new operators: -| Operators | HTP - 77/116 Enabled | +| Operators | HTP - 80/116 Enabled | |-----------|---------| | Argmax | ✗ | | Argmin | ✓ | @@ -382,14 +382,14 @@ Please help update following table if you are contributing new operators: | ElementWiseAdd | ✓ | | ElementWiseAnd | ✓ | | ElementWiseAsin | ✗ | -| ElementWiseAtan | ✗ | +| ElementWiseAtan | ✓ | | ElementWiseBinary | ✗ | | ElementWiseCeil | ✓ | | ElementWiseCos | ✓ | | ElementWiseDivide | ✓ | | ElementWiseEqual | ✓ | | ElementWiseExp | ✓ | -| ElementWiseFloor | ✗ | +| ElementWiseFloor | ✓ | | ElementWiseFloorDiv | ✗ | | ElementWiseGreater | ✓ | | ElementWiseGreaterEqual | ✓ | @@ -405,7 +405,7 @@ Please help update following table if you are contributing new operators: | ElementWiseNotEqual | ✓ | | ElementWiseOr | ✓ | | ElementWisePower | ✓ | -| ElementWiseRound | ✗ | +| ElementWiseRound | ✓ | | ElementWiseRsqrt | ✓ | | ElementWiseSelect | ✓ | | ElementWiseSign | ✗ | diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index fff2a3b4a53..f8b2f11ff4c 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -13,6 +13,7 @@ op_and, op_arange, op_argmin, + op_atan, op_avg_pool2d, op_batch_norm, op_bmm, @@ -30,6 +31,7 @@ op_eq, op_exp, op_expand, + op_floor, op_full, op_full_like, op_gather, @@ -68,6 +70,7 @@ op_reshape, op_resize, op_rms_norm, + op_round, op_rsqrt, op_scalar_tensor, op_select_copy, @@ -103,6 +106,7 @@ op_and, op_arange, op_argmin, + op_atan, op_avg_pool2d, op_batch_norm, op_bmm, @@ -120,6 +124,7 @@ op_eq, op_exp, op_expand, + op_floor, op_full, op_full_like, op_gather, @@ -158,6 +163,7 @@ op_reshape, op_resize, op_rms_norm, + op_round, op_rsqrt, op_scalar_tensor, op_select_copy, diff --git a/backends/qualcomm/builders/op_atan.py b/backends/qualcomm/builders/op_atan.py new file mode 100644 index 00000000000..83c47b9103d --- /dev/null +++ b/backends/qualcomm/builders/op_atan.py @@ -0,0 +1,55 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import torch + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpElementWiseAtan, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Atan(NodeVisitor): + target = ["aten.atan.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + atan_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseAtan.op_name, + ) + atan_op.AddInputTensors([input_tensor_wrapper]) + atan_op.AddOutputTensors([output_tensor_wrapper]) + + return atan_op diff --git a/backends/qualcomm/builders/op_floor.py b/backends/qualcomm/builders/op_floor.py new file mode 100644 index 00000000000..3d69389686e --- /dev/null +++ b/backends/qualcomm/builders/op_floor.py @@ -0,0 +1,56 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import torch + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpElementWiseFloor, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Floor(NodeVisitor): + target = ["aten.floor.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + floor_inp_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + floor_input_tensors = [floor_inp_tensor_wrapper] + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + floor_output_tensors = [output_tensor_wrapper] + + floor_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseFloor.op_name, + ) + floor_op.AddInputTensors(floor_input_tensors) + floor_op.AddOutputTensors(floor_output_tensors) + return floor_op diff --git a/backends/qualcomm/builders/op_round.py b/backends/qualcomm/builders/op_round.py new file mode 100644 index 00000000000..08aa83b5811 --- /dev/null +++ b/backends/qualcomm/builders/op_round.py @@ -0,0 +1,58 @@ +import warnings +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import torch + +from .node_visitor import 
NodeVisitor
+from .node_visitor_manager import register_node_visitor
+
+from .qnn_constants import OpElementWiseRound, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Round(NodeVisitor):
+    target = ["aten.round.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        if len(node.args) > 1:
+            warnings.warn(
+                "[QNN Delegate Op Builder]: QNN does not support decimals",
+                stacklevel=1,
+            )
+            return None
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        round_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseRound.op_name,
+        )
+        round_op.AddInputTensors([input_tensor_wrapper])
+        round_op.AddOutputTensors([output_tensor_wrapper])
+        return round_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 7b545e5ab2d..aa245442f67 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -105,6 +105,11 @@ class OpElementWiseAnd:
     op_name: str = "ElementWiseAnd"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseAtan:
+    op_name: str = "ElementWiseAtan"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseCeil:
     op_name = "ElementWiseCeil"
@@ -130,6 +135,11 @@ class OpElementWiseEqual:
     op_name: str = "ElementWiseEqual"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseFloor:
+    op_name: str = "ElementWiseFloor"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseGreater:
     op_name: str = "ElementWiseGreater"
@@ -203,6 +213,11 @@ class OpElementWisePower:
     op_name: str = "ElementWisePower"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseRound:
+    op_name: str = "ElementWiseRound"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseRsqrt:
     op_name: str = "ElementWiseRsqrt"
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 6233abb01e1..cc7e0054ebe 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -163,6 +163,11 @@ def annotate_single_in_single_out(
     )
 
 
+@register_annotator([torch.ops.aten.atan.default])
+def annotate_atan(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.topk.default])
 def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
@@ -404,6 +409,11 @@ def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.floor.default])
+def annotate_floor(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
 def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config) @@ -414,6 +424,11 @@ def annotate_repeat(node: Node, quantization_config: QuantizationConfig) -> None annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.round.default]) +def annotate_round(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.cos.default]) def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 8be05d46688..fc613575f9f 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -146,6 +146,14 @@ def forward(self, x, y): return squeeze_out, conv_out +class Atan(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.atan(x) + + class AvgPoolModule(torch.nn.Module): def __init__(self, kernel_size, stride, padding, ceil_mode): super().__init__() @@ -741,6 +749,14 @@ def forward(self, x): return torch.special.expm1(x) +class Floor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.floor(x) + + class Fold(torch.nn.Module): def __init__(self): super().__init__() @@ -1448,6 +1464,14 @@ def forward(self, x): return torch.roll(x, shifts=self.shifts, dims=self.dims) +class Round(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.round(x) + + class Rsqrt(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4a0edaf471d..bbee5e3d6ed 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -157,6 +157,11 @@ def test_qnn_backend_argmin(self): sample_input = (torch.randn(16, 3, 4, 4),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_atan(self): + sample_input = (torch.randn(3, 4),) + module = Atan() # noqa: F405 + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_avg_pool2d(self): modules = [ AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405 @@ -515,6 +520,11 @@ def test_qnn_backend_expm1(self): module = ExpM1() # noqa: F405 self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_floor(self): + sample_input = (torch.randn(3, 4),) + module = Floor() # noqa: F405 + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) module = Fold() # noqa: F405 @@ -928,6 +938,11 @@ def test_qnn_backend_roll(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_round(self): + module = Round() # noqa: F405 + sample_input = (torch.randn([3, 4]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),) @@ -1375,6 +1390,12 @@ def test_qnn_backend_argmin(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_atan(self): + sample_input = (torch.randn(3, 4),) + module = Atan() # noqa: F405 + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, 
sample_input) + def test_qnn_backend_avg_pool2d(self): modules = [ AvgPoolModule((2, 2), (1, 1), (1, 1), False), # noqa: F405 @@ -1788,6 +1809,12 @@ def test_qnn_backend_expm1(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_floor(self): + sample_input = (torch.randn(3, 4),) + module = Floor() # noqa: F405 + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) module = Fold() # noqa: F405 @@ -2261,6 +2288,12 @@ def test_qnn_backend_roll(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_round(self): + module = Round() # noqa: F405 + sample_input = (torch.randn([3, 4]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),)
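
Reviewer note (not part of the patch): all three new builders follow the same single-input, single-output elementwise pattern, so they can be exercised together. Below is a minimal sketch of a combined module; the AtanFloorRound name and the eager-mode check are illustrative assumptions, and in the QNN test suite such a module would instead be driven through the same get_qdq_module / lower_module_and_test_output helpers used by the new tests above.

# Illustrative only -- this module is NOT part of the patch. It chains the three
# newly enabled ops (aten.atan, aten.floor, aten.round) into one graph so a
# reviewer can sanity-check them together.
import torch


class AtanFloorRound(torch.nn.Module):  # hypothetical name, mirrors models.py style
    def forward(self, x):
        # Scale atan's (-pi/2, pi/2) output so floor/round produce several
        # distinct integer values rather than mostly zeros.
        y = torch.atan(x) * 4.0
        return torch.floor(y) + torch.round(y)


if __name__ == "__main__":
    sample_input = (torch.randn(3, 4),)
    # Eager-mode sanity check; delegate testing would go through
    # get_qdq_module(...) and lower_module_and_test_output(...).
    print(AtanFloorRound()(*sample_input))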