diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 67f33873d44..6ca0b512643 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -75,6 +75,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.mean.dim,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.neg.default,
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.repeat.default,
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index 3267cef7b8a..ce19b6dbc73 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -47,6 +47,7 @@
     op_mean_dim,
     op_min,
     op_mul,
+    op_neg,
     op_pad,
     op_pow,
     op_prelu,
@@ -120,6 +121,7 @@
     op_mean_dim,
     op_min,
     op_mul,
+    op_neg,
     op_pad,
     op_pow,
     op_prelu,
diff --git a/backends/qualcomm/builders/op_neg.py b/backends/qualcomm/builders/op_neg.py
new file mode 100644
index 00000000000..a950a1887ab
--- /dev/null
+++ b/backends/qualcomm/builders/op_neg.py
@@ -0,0 +1,53 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseNeg, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Neg(NodeVisitor):
+    target = ["aten.neg.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        neg_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        neg_input_tensors = [neg_inp_tensor_wrapper]
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        neg_output_tensors = [output_tensor_wrapper]
+        neg_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseNeg.op_name,
+        )
+        neg_op.AddInputTensors(neg_input_tensors)
+        neg_op.AddOutputTensors(neg_output_tensors)
+        return neg_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 233f3118783..d53e6792869 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -145,6 +145,11 @@ class OpElementWiseMultiply:
     op_name: str = "ElementWiseMultiply"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseNeg:
+    op_name: str = "ElementWiseNeg"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseNeuron:
     op_name: str = "ElementWiseNeuron"
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index c9f28ae760b..fe1729d19b8 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -403,6 +403,11 @@ def annotate_max_pool2d_with_indices(
     annotate_single_in_single_out(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.neg.default])
+def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
 def annotate_adaptive_avgpool2d(
     node: Node, quantization_config: QuantizationConfig
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 6e758a5c45f..4e733087808 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -894,6 +894,14 @@ def forward(self, x):
         return attn_output
 
 
+class Neg(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.neg(x)
+
+
 class Pad(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 9ceaf60c93d..0d1e80904fe 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -546,6 +546,11 @@ def test_qnn_backend_minimum(self):
         sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_neg(self):
+        module = Neg()  # noqa: F405
+        sample_input = (torch.randn(1, 4, 16, 16),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pad(self):
         module = Pad()  # noqa: F405
         sample_input = (torch.randn([1, 8, 128]),)
@@ -1429,6 +1434,12 @@ def test_qnn_backend_minimum(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_neg(self):
+        module = Neg()  # noqa: F405
+        sample_input = (torch.randn(1, 4, 16, 16),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pad(self):
         module = Pad()  # noqa: F405
         sample_input = (torch.randn([1, 8, 128]),)
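
Reviewer note: the new builder keys on the string "aten.neg.default" and the new annotator on torch.ops.aten.neg.default. Below is a minimal standalone sketch, assuming only a PyTorch build with torch.export, showing that torch.neg really traces to that overload; the inline module mirrors the Neg model this diff adds to tests/models.py and does not touch the QNN delegate itself.

import torch


# Mirrors the Neg module added to backends/qualcomm/tests/models.py.
class Neg(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)


# torch.export traces torch.neg to the aten.neg.default overload that both
# the op builder (via its target list) and the quantizer annotator register
# against; the printed graph contains a call to torch.ops.aten.neg.default.
ep = torch.export.export(Neg(), (torch.randn(1, 4, 16, 16),))
print(ep.graph_module.code)

On the quantized test path, get_qdq_module runs the quantizer first, so the new annotate_neg hook attaches quantization specs to that same overload before lower_module_and_test_output hands the graph to the QNN delegate.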