Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.prelu.default,
exir_ops.edge.aten.repeat.default,
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
op_mean_dim,
op_min,
op_mul,
op_neg,
op_pad,
op_pow,
op_prelu,
Expand Down Expand Up @@ -120,6 +121,7 @@
op_mean_dim,
op_min,
op_mul,
op_neg,
op_pad,
op_pow,
op_prelu,
Expand Down
53 changes: 53 additions & 0 deletions backends/qualcomm/builders/op_neg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseNeg, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Neg(NodeVisitor):
target = ["aten.neg.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
neg_inp_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
neg_input_tensors = [neg_inp_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
neg_output_tensors = [output_tensor_wrapper]
neg_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseNeg.op_name,
)
neg_op.AddInputTensors(neg_input_tensors)
neg_op.AddOutputTensors(neg_output_tensors)
return neg_op
5 changes: 5 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ class OpElementWiseMultiply:
op_name: str = "ElementWiseMultiply"


@dataclass(init=False, frozen=True)
class OpElementWiseNeg:
op_name: str = "ElementWiseNeg"


@dataclass(init=False, frozen=True)
class OpElementWiseNeuron:
op_name: str = "ElementWiseNeuron"
Expand Down
5 changes: 5 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,11 @@ def annotate_max_pool2d_with_indices(
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.neg.default])
def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
def annotate_adaptive_avgpool2d(
node: Node, quantization_config: QuantizationConfig
Expand Down
8 changes: 8 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,14 @@ def forward(self, x):
return attn_output


class Neg(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.neg(x)


class Pad(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
11 changes: 11 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,11 @@ def test_qnn_backend_minimum(self):
sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_neg(self):
module = Neg() # noqa: F405
sample_input = (torch.randn(1, 4, 16, 16),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_pad(self):
module = Pad() # noqa: F405
sample_input = (torch.randn([1, 8, 128]),)
Expand Down Expand Up @@ -1429,6 +1434,12 @@ def test_qnn_backend_minimum(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_neg(self):
module = Neg() # noqa: F405
sample_input = (torch.randn(1, 4, 16, 16),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_pad(self):
module = Pad() # noqa: F405
sample_input = (torch.randn([1, 8, 128]),)
Expand Down
Loading