diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 967ae7afd2b..b65d8759bbe 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -47,6 +47,7 @@ class LayoutTransform(ExportPass):
     layout_agnostic_ops = {
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.ceil.default,
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index a16d4fb5057..c5352a7fbee 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -52,6 +52,7 @@
     op_mul,
     op_ne,
     op_neg,
+    op_or,
     op_pad,
     op_pow,
     op_prelu,
@@ -131,6 +132,7 @@
     op_mul,
     op_neg,
     op_ne,
+    op_or,
     op_pad,
     op_pow,
     op_prelu,
diff --git a/backends/qualcomm/builders/op_or.py b/backends/qualcomm/builders/op_or.py
new file mode 100644
index 00000000000..c2751744788
--- /dev/null
+++ b/backends/qualcomm/builders/op_or.py
@@ -0,0 +1,59 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseOr, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class OpOr(NodeVisitor):
+    target = ["aten.bitwise_or.Tensor"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        or_output_tensors = [output_tensor_wrapper]
+
+        or_input_tensors = []
+        for index in range(2):
+            input_node = node.args[index]
+            input_tensor = self.get_tensor(input_node, node)
+            tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
+
+            input_tensor_wrapper = self.define_tensor(
+                input_node,
+                node,
+                input_tensor,
+                tensor_type,
+                nodes_to_wrappers,
+            )
+            or_input_tensors.append(input_tensor_wrapper)
+        or_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseOr.op_name,
+        )
+        or_op.AddInputTensors(or_input_tensors)
+        or_op.AddOutputTensors(or_output_tensors)
+        return or_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 5e0b63d6d19..1d55d56de0f 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -168,6 +168,11 @@ class OpElementWiseNotEqual:
     op_name: str = "ElementWiseNotEqual"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseOr:
+    op_name: str = "ElementWiseOr"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWisePower:
     op_name: str = "ElementWisePower"
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index a232d231c27..8867f92b54a 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -680,6 +680,11 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non
     )
 
 
+@register_annotator([torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.__or__.Tensor])
+def annotate_bitwise_or(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.pow.Tensor_Tensor])
 def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index bdb5541353b..85491c62dcb 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1025,6 +1025,28 @@ def forward(self, x):
         return x != self.constant
 
 
+class OrBitWise(torch.nn.Module):
+    def __init__(self, pos, neg):
+        super().__init__()
+        self.pos = pos
+        self.neg = neg
+
+    def forward(self, x, y):
+        bitwise_or = torch.bitwise_or(x, y).bool()
+        return torch.where(bitwise_or, self.pos, self.neg)
+
+
+class OrOperator(torch.nn.Module):
+    def __init__(self, pos, neg):
+        super().__init__()
+        self.pos = pos
+        self.neg = neg
+
+    def forward(self, x, y):
+        operator_or = x.to(torch.bool) | y.to(torch.bool)
+        return torch.where(operator_or, self.pos, self.neg)
+
+
 class Pad(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 986243d7a9c..c7b68baa9eb 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -310,6 +310,33 @@ def test_qnn_backend_element_wise_mul(self):
                 self.lower_module_and_test_output(module, sample_input)
                 index += 1
 
+    def test_qnn_backend_element_wise_or(self):
+        test_comb = [
+            {
+                QCOM_MODULE: OrBitWise(  # noqa: F405
+                    torch.tensor(1.7), torch.tensor(0.2)
+                ),
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([1, 0, 1, 0], dtype=torch.bool),
+                    torch.tensor([1, 1, 0, 0], dtype=torch.bool),
+                ),
+            },
+            {
+                QCOM_MODULE: OrOperator(  # noqa: F405
+                    torch.tensor(1.5), torch.tensor(-1.2)
+                ),
+                QCOM_SAMPLE_INPUTS: (
+                    torch.full((3, 3), 1).triu(),
+                    torch.full((3, 3), 1).tril(diagonal=0),
+                ),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+
     def test_qnn_backend_element_wise_sqrt(self):
         modules = [Sqrt(), SqrtConstant()]  # noqa: F405
         for i, module in enumerate(modules):
@@ -1246,6 +1273,34 @@ def test_qnn_backend_element_wise_mul(self):
                 self.lower_module_and_test_output(module, sample_input)
                 index += 1
 
+    def test_qnn_backend_element_wise_or(self):
+        test_comb = [
+            {
+                QCOM_MODULE: OrBitWise(  # noqa: F405
+                    torch.tensor(1.7), torch.tensor(0.2)
+                ),
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([1, 0, 1, 0], dtype=torch.bool),
+                    torch.tensor([1, 1, 0, 0], dtype=torch.bool),
+                ),
+            },
+            {
+                QCOM_MODULE: OrOperator(  # noqa: F405
+                    torch.tensor(1.5), torch.tensor(-1.2)
+                ),
+                QCOM_SAMPLE_INPUTS: (
+                    torch.full((3, 3), 1).triu(),
+                    torch.full((3, 3), 1).tril(diagonal=0),
+                ),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])
+
     def test_qnn_backend_element_wise_sqrt(self):
         modules = [Sqrt(), SqrtConstant()]  # noqa: F405
         for i, module in enumerate(modules):
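
For context, here is a minimal standalone sketch (not part of the patch) of the pattern the new OrBitWise test module exercises: it runs the module eagerly and then calls torch.export.export to show the aten.bitwise_or.Tensor node that the new OpOr visitor and annotate_bitwise_or annotator target. The module and sample inputs mirror the first test case above; the export step is only illustrative, since actual QNN lowering goes through the backend's own lowering utilities.

# Standalone sketch, not part of the patch: mirrors the OrBitWise test module
# from backends/qualcomm/tests/models.py and inspects the exported graph.
import torch


class OrBitWise(torch.nn.Module):
    def __init__(self, pos, neg):
        super().__init__()
        self.pos = pos
        self.neg = neg

    def forward(self, x, y):
        bitwise_or = torch.bitwise_or(x, y).bool()
        return torch.where(bitwise_or, self.pos, self.neg)


module = OrBitWise(torch.tensor(1.7), torch.tensor(0.2))
sample_inputs = (
    torch.tensor([1, 0, 1, 0], dtype=torch.bool),
    torch.tensor([1, 1, 0, 0], dtype=torch.bool),
)

# Eager reference output: 1.7 where either mask is True, 0.2 otherwise.
print(module(*sample_inputs))  # tensor([1.7000, 1.7000, 1.7000, 0.2000])

# The exported graph contains an aten.bitwise_or.Tensor call, which is the op
# the new builder (OpOr) and quantizer annotator handle during QNN lowering.
exported = torch.export.export(module, sample_inputs)
print(exported.graph_module.graph)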