Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class LayoutTransform(ExportPass):
layout_agnostic_ops = {
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.ceil.default,
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
op_mul,
op_ne,
op_neg,
op_or,
op_pad,
op_pow,
op_prelu,
Expand Down Expand Up @@ -131,6 +132,7 @@
op_mul,
op_neg,
op_ne,
op_or,
op_pad,
op_pow,
op_prelu,
Expand Down
59 changes: 59 additions & 0 deletions backends/qualcomm/builders/op_or.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseOr, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class OpOr(NodeVisitor):
target = ["aten.bitwise_or.Tensor"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
or_output_tensors = [output_tensor_wrapper]

or_input_tensors = []
for index in range(2):
input_node = node.args[index]
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
tensor_type,
nodes_to_wrappers,
)
or_input_tensors.append(input_tensor_wrapper)
or_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseOr.op_name,
)
or_op.AddInputTensors(or_input_tensors)
or_op.AddOutputTensors(or_output_tensors)
return or_op
5 changes: 5 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ class OpElementWiseNotEqual:
op_name: str = "ElementWiseNotEqual"


@dataclass(init=False, frozen=True)
class OpElementWiseOr:
op_name: str = "ElementWiseOr"


@dataclass(init=False, frozen=True)
class OpElementWisePower:
op_name: str = "ElementWisePower"
Expand Down
5 changes: 5 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,11 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non
)


@register_annotator([torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.__or__.Tensor])
def annotate_bitwise_or(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.pow.Tensor_Tensor])
def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
Expand Down
22 changes: 22 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,28 @@ def forward(self, x):
return x != self.constant


class OrBitWise(torch.nn.Module):
def __init__(self, pos, neg):
super().__init__()
self.pos = pos
self.neg = neg

def forward(self, x, y):
bitwise_or = torch.bitwise_or(x, y).bool()
return torch.where(bitwise_or, self.pos, self.neg)


class OrOperator(torch.nn.Module):
def __init__(self, pos, neg):
super().__init__()
self.pos = pos
self.neg = neg

def forward(self, x, y):
operator_or = x.to(torch.bool) | y.to(torch.bool)
return torch.where(operator_or, self.pos, self.neg)


class Pad(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
55 changes: 55 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,33 @@ def test_qnn_backend_element_wise_mul(self):
self.lower_module_and_test_output(module, sample_input)
index += 1

def test_qnn_backend_element_wise_or(self):
test_comb = [
{
QCOM_MODULE: OrBitWise( # noqa: F405
torch.tensor(1.7), torch.tensor(0.2)
),
QCOM_SAMPLE_INPUTS: (
torch.tensor([1, 0, 1, 0], dtype=torch.bool),
torch.tensor([1, 1, 0, 0], dtype=torch.bool),
),
},
{
QCOM_MODULE: OrOperator( # noqa: F405
torch.tensor(1.5), torch.tensor(-1.2)
),
QCOM_SAMPLE_INPUTS: (
torch.full((3, 3), 1).triu(),
torch.full((3, 3), 1).tril(diagonal=0),
),
},
]
for i, test in enumerate(test_comb):
with self.subTest(i=i):
self.lower_module_and_test_output(
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
)

def test_qnn_backend_element_wise_sqrt(self):
modules = [Sqrt(), SqrtConstant()] # noqa: F405
for i, module in enumerate(modules):
Expand Down Expand Up @@ -1246,6 +1273,34 @@ def test_qnn_backend_element_wise_mul(self):
self.lower_module_and_test_output(module, sample_input)
index += 1

def test_qnn_backend_element_wise_or(self):
test_comb = [
{
QCOM_MODULE: OrBitWise( # noqa: F405
torch.tensor(1.7), torch.tensor(0.2)
),
QCOM_SAMPLE_INPUTS: (
torch.tensor([1, 0, 1, 0], dtype=torch.bool),
torch.tensor([1, 1, 0, 0], dtype=torch.bool),
),
},
{
QCOM_MODULE: OrOperator( # noqa: F405
torch.tensor(1.5), torch.tensor(-1.2)
),
QCOM_SAMPLE_INPUTS: (
torch.full((3, 3), 1).triu(),
torch.full((3, 3), 1).tril(diagonal=0),
),
},
]
for i, test in enumerate(test_comb):
with self.subTest(i=i):
module = self.get_qdq_module(
test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
)
self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])

def test_qnn_backend_element_wise_sqrt(self):
modules = [Sqrt(), SqrtConstant()] # noqa: F405
for i, module in enumerate(modules):
Expand Down
Loading