diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 13175fe41bd..8b8c9d15bf9 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -64,8 +64,10 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.amax.default, exir_ops.edge.aten.amin.default, + exir_ops.edge.aten.asin.default, exir_ops.edge.aten.atan.default, exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.cat.default, @@ -78,6 +80,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.floor.default, + exir_ops.edge.aten.floor_divide.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.ge.Tensor, @@ -107,6 +110,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.relu.default, exir_ops.edge.aten.round.default, exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.sign.default, exir_ops.edge.aten.split_with_sizes.default, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.sqrt.default, diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index dc9592e415b..f5c5915cab2 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -54,6 +54,7 @@ class TensorOpInfo: aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True), aten.where.Scalar: TensorOpInfo(aten.where.self, False, True), aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False), + aten.bitwise_xor.Scalar: TensorOpInfo(aten.bitwise_xor.Tensor, False, False), } @@ -64,6 +65,7 @@ class TensorOpInfo: aten.arange.default, aten.scalar_tensor.default, aten.elu.default, + aten.hardtanh.default, } diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 9c62e1080fe..6ba4eafb01f 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -360,7 +360,12 @@ The operator now should be functional for Qualcomm backends. 
For operator to wor ## Operator Support Status Please help update following table if you are contributing new operators: -| Operators | HTP - 82/116 Enabled | ++ ✓ = Supported ++ ✗ = Not Supported ++ 🚫 = Deprecated, supported with other QNN Ops + + +| Operators | HTP - 90/116 Enabled | |-----------|---------| | Argmax | ✓ | | Argmin | ✓ | @@ -381,16 +386,16 @@ Please help update following table if you are contributing new operators: | ElementWiseAbs | ✓ | | ElementWiseAdd | ✓ | | ElementWiseAnd | ✓ | -| ElementWiseAsin | ✗ | +| ElementWiseAsin | ✓ | | ElementWiseAtan | ✓ | -| ElementWiseBinary | ✗ | +| ElementWiseBinary | ✓ | | ElementWiseCeil | ✓ | | ElementWiseCos | ✓ | | ElementWiseDivide | ✓ | | ElementWiseEqual | ✓ | | ElementWiseExp | ✓ | | ElementWiseFloor | ✓ | -| ElementWiseFloorDiv | ✗ | +| ElementWiseFloorDiv | ✓ | | ElementWiseGreater | ✓ | | ElementWiseGreaterEqual | ✓ | | ElementWiseLess | ✓ | @@ -408,13 +413,13 @@ Please help update following table if you are contributing new operators: | ElementWiseRound | ✓ | | ElementWiseRsqrt | ✓ | | ElementWiseSelect | ✓ | -| ElementWiseSign | ✗ | +| ElementWiseSign | ✓ | | ElementWiseSin | ✓ | | ElementWiseSquaredDifference | ✗ | | ElementWiseSquareRoot | ✓ | | ElementWiseSubtract | ✓ | | ElementWiseUnary | ✗ | -| ElementWiseXor | ✗ | +| ElementWiseXor | ✓ | | Elu | ✓ | | ExpandDims | ✓ | | ExtractGlimpse | ✗ | @@ -452,11 +457,11 @@ Please help update following table if you are contributing new operators: | ReduceMin | ✓ | | ReduceSum | ✓ | | Relu | ✓ | -| Relu1 | ✗ | -| Relu6 | ✗ | +| Relu1 | 🚫 | +| Relu6 | 🚫 | | ReluMinMax | ✓ | | Reshape | ✓ | -| Resize | ✗ | +| Resize | ✓ | | ResizeBilinear | ✓ | | ResizeNearestNeighbor | ✓ | | RoiAlign | ✗ | diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 62e8e476257..68873d15b3e 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -15,9 +15,11 @@ op_arange, op_argmax, op_argmin, + op_asin, op_atan, op_avg_pool2d, op_batch_norm, + op_binary, op_bmm, op_cat, op_ceil, @@ -79,6 +81,7 @@ op_scalar_tensor, op_select_copy, op_sigmoid, + op_sign, op_sin, op_skip_ops, op_slice_copy, @@ -99,6 +102,7 @@ op_upsample_bilinear2d, op_upsample_nearest2d, op_where, + op_xor, ) __all__ = [ @@ -112,9 +116,11 @@ op_arange, op_argmax, op_argmin, + op_asin, op_atan, op_avg_pool2d, op_batch_norm, + op_binary, op_bmm, op_cat, op_ceil, @@ -176,6 +182,7 @@ op_scalar_tensor, op_select_copy, op_sigmoid, + op_sign, op_sin, op_skip_ops, op_slice_copy, @@ -196,4 +203,5 @@ op_upsample_bilinear2d, op_upsample_nearest2d, op_where, + op_xor, ] diff --git a/backends/qualcomm/builders/op_asin.py b/backends/qualcomm/builders/op_asin.py new file mode 100644 index 00000000000..ff50380e62c --- /dev/null +++ b/backends/qualcomm/builders/op_asin.py @@ -0,0 +1,56 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. 
+# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import torch + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor + +from .qnn_constants import OpElementWiseAsin, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class asin(NodeVisitor): + target = ["aten.asin.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + asin_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseAsin.op_name, + ) + asin_op.AddInputTensors([input_tensor_wrapper]) + asin_op.AddOutputTensors([output_tensor_wrapper]) + + return asin_op diff --git a/backends/qualcomm/builders/op_binary.py b/backends/qualcomm/builders/op_binary.py new file mode 100644 index 00000000000..4f4d8b9b560 --- /dev/null +++ b/backends/qualcomm/builders/op_binary.py @@ -0,0 +1,84 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA +from executorch.exir.dialects._ops import ops as exir_ops + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpElementWiseBinary, QNN_OP_PACKAGE_NAME_QTI_AISW + + +# Refer to QnnOpDef.h for the value.
+QNN_BINARY_OPERATOR = { + exir_ops.edge.aten.floor_divide.default: 4, +} + + +@register_node_visitor +class Binary(NodeVisitor): + target = ["aten.floor_divide.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + binary_output_tensors = [output_tensor_wrapper] + + binary_input_tensors = [] + for index in range(2): + input_node = self.get_node(node.args[index]) + input_tensor = self.get_tensor(input_node, node) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + tensor_type, + nodes_to_wrappers, + ) + binary_input_tensors.append(input_tensor_wrapper) + + binary_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseBinary.op_name, + ) + binary_op.AddInputTensors(binary_input_tensors) + binary_op.AddOutputTensors(binary_output_tensors) + + if node.target not in QNN_BINARY_OPERATOR: + warnings.warn( + "[QNN Delegate Op Builder]: This binary operator is not yet supported.", + stacklevel=1, + ) + return None + + binary_op.AddScalarParam( + OpElementWiseBinary.param_operation, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(QNN_BINARY_OPERATOR[node.target])}, + ) + + return binary_op diff --git a/backends/qualcomm/builders/op_ne.py b/backends/qualcomm/builders/op_ne.py index 1c7f87d4f4f..660c78e3e14 100644 --- a/backends/qualcomm/builders/op_ne.py +++ b/backends/qualcomm/builders/op_ne.py @@ -16,7 +16,7 @@ @register_node_visitor class NotEqual(NodeVisitor): - target = ["aten.ne.Tensor", "aten.ne.Scalar"] + target = ["aten.ne.Tensor"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_sign.py b/backends/qualcomm/builders/op_sign.py new file mode 100644 index 00000000000..faf2f2e0066 --- /dev/null +++ b/backends/qualcomm/builders/op_sign.py @@ -0,0 +1,56 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpElementWiseSign, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Sign(NodeVisitor): + target = ["aten.sign.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + sign_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseSign.op_name, + ) + sign_op.AddInputTensors([input_tensor_wrapper]) + sign_op.AddOutputTensors([output_tensor_wrapper]) + + return sign_op diff --git a/backends/qualcomm/builders/op_xor.py b/backends/qualcomm/builders/op_xor.py new file mode 100644 index 00000000000..d4462d9c707 --- /dev/null +++ b/backends/qualcomm/builders/op_xor.py @@ -0,0 +1,60 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpElementWiseXor, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class OpXor(NodeVisitor): + target = ["aten.bitwise_xor.Tensor"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + xor_output_tensors = [output_tensor_wrapper] + + xor_input_tensors = [] + for index in range(2): + input_node = self.get_node(node.args[index]) + input_tensor = self.get_tensor(input_node, node) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + tensor_type, + nodes_to_wrappers, + ) + xor_input_tensors.append(input_tensor_wrapper) + xor_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseXor.op_name, + ) + xor_op.AddInputTensors(xor_input_tensors) + xor_op.AddOutputTensors(xor_output_tensors) + return xor_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 74ffe24e3c4..b0c44dcae80 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -112,11 +112,22 @@ class OpElementWiseAnd: op_name: str = "ElementWiseAnd" 
+@dataclass(init=False, frozen=True) +class OpElementWiseAsin: + op_name: str = "ElementWiseAsin" + + @dataclass(init=False, frozen=True) class OpElementWiseAtan: op_name: str = "ElementWiseAtan" +@dataclass(init=False, frozen=True) +class OpElementWiseBinary: + op_name: str = "ElementWiseBinary" + param_operation: str = "operation" + + @dataclass(init=False, frozen=True) class OpElementWiseCeil: op_name = "ElementWiseCeil" @@ -240,6 +251,11 @@ class OpElementWiseSelect: op_name = "ElementWiseSelect" +@dataclass(init=False, frozen=True) +class OpElementWiseSign: + op_name: str = "ElementWiseSign" + + @dataclass(init=False, frozen=True) class OpElementWiseSquareRoot: op_name = "ElementWiseSquareRoot" @@ -250,6 +266,11 @@ class OpElementWiseSubtract: op_name = "ElementWiseSubtract" +@dataclass(init=False, frozen=True) +class OpElementWiseXor: + op_name: str = "ElementWiseXor" + + @dataclass(init=False, frozen=True) class OpElu: op_name: str = "Elu" diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 4eee818efe5..97e0b4bd109 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -44,6 +44,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.col2im.default, torch.ops.aten.elu.default, + torch.ops.aten.floor_divide.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.im2col.default, diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 38a8bc6ebe6..721fc85362f 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -463,6 +463,11 @@ def annotate_sin(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.floor_divide.default]) +def annotate_floor_divide(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + @register_annotator([torch.ops.aten.scalar_tensor.default]) def annotate_scalar_tensor(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): @@ -626,6 +631,11 @@ def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> Non annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.asin.default]) +def annotate_asin(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.linalg_vector_norm.default]) def annotate_linalg_vector_norm( node: Node, quantization_config: QuantizationConfig @@ -658,6 +668,11 @@ def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.sign.default]) +def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.slice.Tensor]) def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -802,6 +817,11 @@ def annotate_bitwise_or(node: Node, quantization_config: QuantizationConfig) -> annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.bitwise_xor.Tensor, torch.ops.aten.__xor__.Tensor]) +def 
annotate_xor(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + @register_annotator([torch.ops.aten.pow.Tensor_Tensor]) def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 01ed37f80a3..51dd1b2c950 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -165,6 +165,14 @@ def forward(self, x, y): return squeeze_out, conv_out +class Asin(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.asin(x) + + class Atan(torch.nn.Module): def __init__(self): super().__init__() @@ -776,6 +784,23 @@ def forward(self, x): return torch.floor(x) +class FloorDiv(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.floor_divide(x, y) + + +class FloorDivConstantFloat(torch.nn.Module): + def __init__(self, constant=2.0): + super().__init__() + self.constant = constant + + def forward(self, x): + return torch.floor(x / self.constant) + + class Fold(torch.nn.Module): def __init__(self): super().__init__() @@ -1380,6 +1405,15 @@ def forward(self, x): return self.relu(x) +class Relu6(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(x) + + class Repeat(torch.nn.Module): def __init__(self): super().__init__() @@ -1564,6 +1598,14 @@ def forward(self, x): return torch.sigmoid(x) +class Sign(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sign(x) + + class Sin(torch.nn.Module): def __init__(self): super().__init__() @@ -1853,6 +1895,28 @@ def forward(self, x): ) +class XorBitWise(torch.nn.Module): + def __init__(self, pos, neg): + super().__init__() + self.pos = pos + self.neg = neg + + def forward(self, x, y): + bitwise_xor = torch.bitwise_xor(x, y).bool() + return torch.where(bitwise_xor, self.pos, self.neg) + + +class XorOperator(torch.nn.Module): + def __init__(self, pos, neg): + super().__init__() + self.pos = pos + self.neg = neg + + def forward(self, x, y): + operator_xor = x.to(torch.bool) ^ y.to(torch.bool) + return torch.where(operator_xor, self.pos, self.neg) + + # Mimi Decoder has 0D tensor which QNN cannot handle. 
class ZeroDimTensor(torch.nn.Module): def __init__(self): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index b4577946cc3..868843c8682 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -164,7 +164,13 @@ def test_qnn_backend_argmax(self): def test_qnn_backend_argmin(self): module = Argmin() # noqa: F405 - sample_input = (torch.randn(16, 3, 4, 4),) + sample_input = (torch.rand(3, 4),) + self.lower_module_and_test_output(module, sample_input) + + @unittest.expectedFailure + def test_qnn_backend_asin(self): + sample_input = (torch.rand(3, 4) * 2 - 1,) + module = Asin() # noqa: F405 self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_atan(self): @@ -535,6 +541,30 @@ def test_qnn_backend_floor(self): module = Floor() # noqa: F405 self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_floor_divide(self): + eps = 1e-03 + test_comb = [ + { + QCOM_MODULE: [FloorDiv()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ + (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)), + (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])), + ], + }, + { + QCOM_MODULE: [FloorDivConstantFloat()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + self.lower_module_and_test_output(module, sample_input) + index += 1 + def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) module = Fold() # noqa: F405 @@ -817,6 +847,11 @@ def test_qnn_backend_less_equal(self): test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] ) + def test_qnn_backend_sign(self): + module = Sign() # noqa: F405 + sample_input = (torch.randn(3, 4),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_less_than(self): test_comb = [ { @@ -971,6 +1006,11 @@ def test_qnn_backend_relu(self): sample_input = (torch.randn([2, 5, 1, 3]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_relu6(self): + module = Relu6() # noqa: F405 + sample_input = (torch.randn([2, 5, 1, 3]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_repeat(self): module = Repeat() # noqa: F405 sample_input = (torch.randn([2, 2, 2, 2]),) @@ -1129,6 +1169,33 @@ def test_qnn_backend_where(self): for i, module in enumerate(modules): self.lower_module_and_test_output(module, sample_inputs[i]) + def test_qnn_backend_element_wise_xor(self): + test_comb = [ + { + QCOM_MODULE: XorBitWise( # noqa: F405 + torch.tensor(1.7), torch.tensor(0.2) + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([1, 0, 1, 0], dtype=torch.bool), + torch.tensor([1, 1, 0, 0], dtype=torch.bool), + ), + }, + { + QCOM_MODULE: XorOperator( # noqa: F405 + torch.tensor(1.5), torch.tensor(-1.2) + ), + QCOM_SAMPLE_INPUTS: ( + torch.full((3, 3), 1).triu(), + torch.full((3, 3), 1).tril(diagonal=0), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + self.lower_module_and_test_output( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + def test_qnn_backend_masked_fill(self): module = MaskedFill() # noqa: F405 attn_mask = torch.ones((64, 49, 49), dtype=torch.float32) @@ -1473,6 +1540,12 @@ def test_qnn_backend_argmin(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_asin(self): + module = 
Asin() # noqa: F405 + sample_input = (torch.rand([3, 4]) * 2 - 1,) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_atan(self): sample_input = (torch.randn(3, 4),) module = Atan() # noqa: F405 @@ -1883,6 +1956,31 @@ def test_qnn_backend_floor(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_floor_divide(self): + eps = 1e-03 + test_comb = [ + { + QCOM_MODULE: [FloorDiv()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ + (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)), + (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])), + ], + }, + { + QCOM_MODULE: [FloorDivConstantFloat()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + index += 1 + def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) module = Fold() # noqa: F405 @@ -2375,6 +2473,12 @@ def test_qnn_backend_relu(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_relu6(self): + module = Relu6() # noqa: F405 + sample_input = (torch.randn([2, 5, 1, 3]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_repeat(self): module = Repeat() # noqa: F405 sample_input = (torch.randn([2, 2, 2, 2]),) @@ -2453,6 +2557,12 @@ def test_qnn_backend_sigmoid(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_sign(self): + module = Sign() # noqa: F405 + sample_input = (torch.randn(3, 4),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_sin(self): module = Sin() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -2572,6 +2682,34 @@ def test_qnn_backend_masked_fill(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_element_wise_xor(self): + test_comb = [ + { + QCOM_MODULE: XorBitWise( # noqa: F405 + torch.tensor(1.7), torch.tensor(0.2) + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([1, 0, 1, 0], dtype=torch.bool), + torch.tensor([1, 1, 0, 0], dtype=torch.bool), + ), + }, + { + QCOM_MODULE: XorOperator( # noqa: F405 + torch.tensor(1.5), torch.tensor(-1.2) + ), + QCOM_SAMPLE_INPUTS: ( + torch.full((3, 3), 1).triu(), + torch.full((3, 3), 1).tril(diagonal=0), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS]) + class TestQNNQuantizedModel(TestQNN): # TODO: refactor to support different backends
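
Note: as a quick eager-mode sanity check of the semantics the new `floor_divide` and xor test cases rely on, here is a minimal sketch using only plain PyTorch (no QNN lowering). The shapes and values mirror the fixtures above but are otherwise illustrative.

```python
import torch

# floor_divide on float tensors is expected to match floor(x / y); the eps
# offset mirrors the test inputs above and keeps the divisor away from zero.
eps = 1e-3
x = torch.randn(2, 5, 1, 3)
y = eps + torch.randn(2, 5, 1, 3)
print(torch.equal(torch.floor_divide(x, y), torch.floor(x / y)))  # expected: True

# XorBitWise routes the element-wise xor through torch.where, so the module
# returns a float tensor rather than bool; same boolean inputs as the FP test.
a = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
b = torch.tensor([1, 1, 0, 0], dtype=torch.bool)
pos, neg = torch.tensor(1.7), torch.tensor(0.2)
print(torch.where(torch.bitwise_xor(a, b), pos, neg))
# tensor([0.2000, 1.7000, 1.7000, 0.2000])
```

On device, the same modules go through `lower_module_and_test_output` (FP) and `get_qdq_module` plus `lower_module_and_test_output` (quantized), exactly as in the test cases above; the sketch only documents the reference numerics.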