3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -63,6 +63,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.atan.default,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.bitwise_and.Tensor,
@@ -75,6 +76,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.elu.default,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ge.Tensor,
@@ -99,6 +101,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.prelu.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.round.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.split_with_sizes.default,
8 changes: 4 additions & 4 deletions backends/qualcomm/builders/README.md
@@ -360,7 +360,7 @@ The operator now should be functional for Qualcomm backends. For operator to wor
## Operator Support Status
Please help update following table if you are contributing new operators:

| Operators | HTP - 77/116 Enabled |
| Operators | HTP - 80/116 Enabled |
|-----------|---------|
| Argmax | ✗ |
| Argmin | ✓ |
@@ -382,14 +382,14 @@ Please help update following table if you are contributing new operators:
| ElementWiseAdd | ✓ |
| ElementWiseAnd | ✓ |
| ElementWiseAsin | ✗ |
| ElementWiseAtan | ✗ |
| ElementWiseAtan | ✓ |
| ElementWiseBinary | ✗ |
| ElementWiseCeil | ✓ |
| ElementWiseCos | ✓ |
| ElementWiseDivide | ✓ |
| ElementWiseEqual | ✓ |
| ElementWiseExp | ✓ |
| ElementWiseFloor | ✗ |
| ElementWiseFloor | ✓ |
| ElementWiseFloorDiv | ✗ |
| ElementWiseGreater | ✓ |
| ElementWiseGreaterEqual | ✓ |
@@ -405,7 +405,7 @@ Please help update following table if you are contributing new operators:
| ElementWiseNotEqual | ✓ |
| ElementWiseOr | ✓ |
| ElementWisePower | ✓ |
| ElementWiseRound | ✗ |
| ElementWiseRound | ✓ |
| ElementWiseRsqrt | ✓ |
| ElementWiseSelect | ✓ |
| ElementWiseSign | ✗ |
6 changes: 6 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -13,6 +13,7 @@
op_and,
op_arange,
op_argmin,
op_atan,
op_avg_pool2d,
op_batch_norm,
op_bmm,
@@ -30,6 +31,7 @@
op_eq,
op_exp,
op_expand,
op_floor,
op_full,
op_full_like,
op_gather,
@@ -68,6 +70,7 @@
op_reshape,
op_resize,
op_rms_norm,
op_round,
op_rsqrt,
op_scalar_tensor,
op_select_copy,
@@ -103,6 +106,7 @@
op_and,
op_arange,
op_argmin,
op_atan,
op_avg_pool2d,
op_batch_norm,
op_bmm,
@@ -120,6 +124,7 @@
op_eq,
op_exp,
op_expand,
op_floor,
op_full,
op_full_like,
op_gather,
@@ -158,6 +163,7 @@
op_reshape,
op_resize,
op_rms_norm,
op_round,
op_rsqrt,
op_scalar_tensor,
op_select_copy,
55 changes: 55 additions & 0 deletions backends/qualcomm/builders/op_atan.py
@@ -0,0 +1,55 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import torch

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpElementWiseAtan, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Atan(NodeVisitor):
target = ["aten.atan.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

atan_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseAtan.op_name,
)
atan_op.AddInputTensors([input_tensor_wrapper])
atan_op.AddOutputTensors([output_tensor_wrapper])

return atan_op
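
The atan builder follows the standard single-input NodeVisitor pattern: wrap the input tensor, wrap the output tensor, and emit one QNN `ElementWiseAtan` op. As a quick sanity check, the new `Atan` test module can be exported and compared against eager PyTorch; the snippet below is a minimal sketch (input shape and checks are illustrative, not taken from the test suite):

```python
# Minimal, illustrative smoke test for the new Atan test module (not part of this PR).
import torch
from executorch.backends.qualcomm.tests.models import Atan  # added in this change

module = Atan()
sample_inputs = (torch.randn(2, 3, 4),)

exported = torch.export.export(module, sample_inputs)
# Before delegation, the exported program should still match eager torch.atan.
torch.testing.assert_close(exported.module()(*sample_inputs), torch.atan(sample_inputs[0]))
```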
56 changes: 56 additions & 0 deletions backends/qualcomm/builders/op_floor.py
@@ -0,0 +1,56 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import torch

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpElementWiseFloor, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Floor(NodeVisitor):
target = ["aten.floor.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
floor_inp_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
floor_input_tensors = [floor_inp_tensor_wrapper]

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
floor_output_tensors = [output_tensor_wrapper]

floor_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseFloor.op_name,
)
floor_op.AddInputTensors(floor_input_tensors)
floor_op.AddOutputTensors(floor_output_tensors)
return floor_op
58 changes: 58 additions & 0 deletions backends/qualcomm/builders/op_round.py
@@ -0,0 +1,58 @@
import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import torch

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor

from .qnn_constants import OpElementWiseRound, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Round(NodeVisitor):
target = ["aten.round.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

if len(node.args) > 1:
warnings.warn(
"[QNN Delegate Op Builder]: QNN dose not support decimals",
stacklevel=1,
)
return None

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

round_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseRound.op_name,
)
round_op.AddInputTensors([input_tensor_wrapper])
round_op.AddOutputTensors([output_tensor_wrapper])
return round_op
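
Unlike the atan and floor builders, the round builder guards against extra arguments: if the node carries anything beyond the input tensor (for example a `decimals` value), it warns and returns `None` instead of building the op, since QNN's `ElementWiseRound` has no decimals support. A small eager-mode illustration of the two call shapes (hypothetical values, not from the tests):

```python
import torch

x = torch.tensor([1.24, -0.75, 2.5])

torch.round(x)              # plain rounding: the case this builder lowers to ElementWiseRound
torch.round(x, decimals=1)  # carries a decimals argument; this case stays un-delegated
```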
15 changes: 15 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -105,6 +105,11 @@ class OpElementWiseAnd:
op_name: str = "ElementWiseAnd"


@dataclass(init=False, frozen=True)
class OpElementWiseAtan:
op_name: str = "ElementWiseAtan"


@dataclass(init=False, frozen=True)
class OpElementWiseCeil:
op_name = "ElementWiseCeil"
@@ -130,6 +135,11 @@ class OpElementWiseEqual:
op_name: str = "ElementWiseEqual"


@dataclass(init=False, frozen=True)
class OpElementWiseFloor:
op_name: str = "ElementWiseFloor"


@dataclass(init=False, frozen=True)
class OpElementWiseGreater:
op_name: str = "ElementWiseGreater"
@@ -203,6 +213,11 @@ class OpElementWisePower:
op_name: str = "ElementWisePower"


@dataclass(init=False, frozen=True)
class OpElementWiseRound:
op_name: str = "ElementWiseRound"


@dataclass(init=False, frozen=True)
class OpElementWiseRsqrt:
op_name: str = "ElementWiseRsqrt"
15 changes: 15 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
@@ -163,6 +163,11 @@ def annotate_single_in_single_out(
)


@register_annotator([torch.ops.aten.atan.default])
def annotate_atan(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.topk.default])
def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
@@ -404,6 +409,11 @@ def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.floor.default])
def annotate_floor(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
@@ -414,6 +424,11 @@ def annotate_repeat(node: Node, quantization_config: QuantizationConfig) -> None
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.round.default])
def annotate_round(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.cos.default])
def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
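
The three new annotators all reuse `annotate_single_in_single_out`, so atan, floor, and round pick up the same activation observers as the other unary element-wise ops. Below is a rough sketch of where these annotators take effect in the PT2E quantization flow; the import paths and capture step are assumptions based on the existing ExecuTorch Qualcomm quantizer layout and may differ by version:

```python
# Sketch only: module paths and the export/capture step are assumptions, not from this PR.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer  # assumed path
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyModel(torch.nn.Module):
    def forward(self, x):
        # Exercises the three ops whose annotators are added in this change.
        return torch.round(torch.floor(x) + torch.atan(x))


example_inputs = (torch.randn(1, 4),)
graph_module = torch.export.export(TinyModel(), example_inputs).module()

quantizer = QnnQuantizer()
prepared = prepare_pt2e(graph_module, quantizer)  # atan/floor/round nodes get annotated here
prepared(*example_inputs)                         # calibration pass
quantized = convert_pt2e(prepared)
```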
24 changes: 24 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -146,6 +146,14 @@ def forward(self, x, y):
return squeeze_out, conv_out


class Atan(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.atan(x)


class AvgPoolModule(torch.nn.Module):
def __init__(self, kernel_size, stride, padding, ceil_mode):
super().__init__()
@@ -741,6 +749,14 @@ def forward(self, x):
return torch.special.expm1(x)


class Floor(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.floor(x)


class Fold(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1448,6 +1464,14 @@ def forward(self, x):
return torch.roll(x, shifts=self.shifts, dims=self.dims)


class Round(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.round(x)


class Rsqrt(torch.nn.Module):
def __init__(self):
super().__init__()
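
The three new test modules are intentionally bare wrappers around the corresponding torch ops, so they can be exercised like the other unary test models in this file. A quick eager consistency check (illustrative only, not part of the added tests):

```python
# Illustrative eager check that each new test module matches its reference op.
import torch
from executorch.backends.qualcomm.tests.models import Atan, Floor, Round

x = torch.randn(4, 8) * 3
for module, reference in ((Atan(), torch.atan), (Floor(), torch.floor), (Round(), torch.round)):
    torch.testing.assert_close(module(x), reference(x))
```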