Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ class LayoutTransform(ExportPass):
layout_sensitive_ops = {
exir_ops.edge.aten.adaptive_avg_pool2d.default,
exir_ops.edge.aten._adaptive_avg_pool3d.default,
exir_ops.edge.aten.adaptive_max_pool2d.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.avg_pool3d.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.grid_sampler_2d.default,
exir_ops.edge.aten.grid_sampler_3d.default,
exir_ops.edge.aten.instance_norm.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
Expand Down
26 changes: 15 additions & 11 deletions backends/qualcomm/builders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
Thank you for contributing to Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of implementing operator builder to unblock yourself and land pull requests more efficiently.

## Sections
* [References](#references)
* [Getting Started](#getting-started)
* [Identify Unsupported Operator](#identify-unsupported-operator)
* [Check Operator Spec](#check-operator-spec)
* [Implementation](#implementation)
* [Quantizer Annotation](#quantizer-annotation)
* [Operator Support Status](#operator-support-status)
* [Issues](#issues)
* [Pull Requests](#pull-requests)
- [Contribution for More Operators](#contribution-for-more-operators)
- [Sections](#sections)
- [References](#references)
- [Qualcomm AI Engine Direct](#qualcomm-ai-engine-direct)
- [PyTorch](#pytorch)
- [Getting Started](#getting-started)
- [Identify Unsupported Operator](#identify-unsupported-operator)
- [Check Operator Spec](#check-operator-spec)
- [Implementation](#implementation)
- [Quantizer Annotation](#quantizer-annotation)
- [Operator Support Status](#operator-support-status)
- [Issues](#issues)
- [Pull Requests](#pull-requests)

## References
### Qualcomm AI Engine Direct
Expand Down Expand Up @@ -365,7 +369,7 @@ Please help update following table if you are contributing new operators:
+ 🚫 = Deprecated, supported with other QNN Ops


| Operators | HTP - 92/116 Enabled |
| Operators | HTP - 94/116 Enabled |
|-----------|---------|
| Argmax | ✓ |
| Argmin | ✓ |
Expand Down Expand Up @@ -431,7 +435,7 @@ Please help update following table if you are contributing new operators:
| Gelu | ✓ |
| GetSparseIndices | ✗ |
| GetSparseValues | ✗ |
| GridSample | ✗ |
| GridSample | ✓ |
| GroupNorm | ✓ |
| HardSwish | ✓ |
| InstanceNorm | ✓ |
Expand Down
4 changes: 4 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
node_visitor,
op_abs,
op_adaptive_avg_pool2d,
op_adaptive_max_pool2d,
op_add,
op_amax,
op_amin,
Expand Down Expand Up @@ -44,6 +45,7 @@
op_gather,
op_ge,
op_gelu,
op_grid_sampler_2d,
op_group_norm,
op_gt,
op_hardsigmoid,
Expand Down Expand Up @@ -114,6 +116,7 @@
node_visitor,
op_abs,
op_adaptive_avg_pool2d,
op_adaptive_max_pool2d,
op_add,
op_amax,
op_amin,
Expand Down Expand Up @@ -150,6 +153,7 @@
op_gather,
op_ge,
op_gelu,
op_grid_sampler_2d,
op_group_norm,
op_gt,
op_hardswish,
Expand Down
151 changes: 151 additions & 0 deletions backends/qualcomm/builders/op_adaptive_max_pool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class AdaptiveMaxPool2D(NodeVisitor):
target = ["aten.adaptive_max_pool2d.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
users = list(node.users.keys())
for user in users:
if user.target.__name__ == "getitem":
getitem_index = user.args[1]
if getitem_index != 0:
warnings.warn(
f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}",
stacklevel=1,
)
return

if len(node.args) > 2:
warnings.warn(
"[QNN Delegate Op Builder]: The return_indices is not supported, fallback op",
stacklevel=1,
)
return

input_height = input_tensor.shape[1]
input_width = input_tensor.shape[2]
# output cases
out_wh = cast(List[int], node.args[1])
if len(out_wh) == 1:
output_height = node.args[1][0]
output_width = node.args[1][0]
else:
output_height = node.args[1][0]
output_width = node.args[1][1]
if output_height is None:
output_height = input_height
if output_width is None:
output_width = input_width
# NOTE: Here we need not to emphasize on mode, cuz the output shape is decided by user.
mode = OpPoolMax2d.RoundingMode.FLOOR

# floor division
stride_height = input_height // output_height
filter_height = input_height - (output_height - 1) * stride_height
stride_width = input_width // output_width
filter_width = input_width - (output_width - 1) * stride_width

filter = [filter_height, filter_width]
filter_shape = [len(filter)]

stride = [stride_height, stride_width]
stride_shape = [len(stride)]

padding = [0, 0]
padding_shape = [len(padding), len(padding)]

out_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

adaptive_max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpPoolMax2d.op_name,
)

adaptive_max_pool2d_op.AddInputTensors([input_tensor_wrapper])
adaptive_max_pool2d_op.AddOutputTensors([output_tensor_wrapper])

adaptive_max_pool2d_op.AddTensorParam(
OpPoolMax2d.param_filter_size,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(filter_shape),
filter_shape,
np.array(
filter,
dtype=np.uint32,
),
True,
)

adaptive_max_pool2d_op.AddTensorParam(
OpPoolMax2d.param_stride,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(stride_shape),
stride_shape,
np.array(
stride,
dtype=np.uint32,
),
True,
)

adaptive_max_pool2d_op.AddTensorParam(
OpPoolMax2d.param_pad_amount,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(padding_shape),
padding_shape,
np.array(
[[padding[0], padding[0]], [padding[1], padding[1]]],
dtype=np.uint32,
),
True,
)

adaptive_max_pool2d_op.AddScalarParam(
OpPoolMax2d.param_rounding_mode,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{QCOM_DATA: np.uint32(mode)},
)

return adaptive_max_pool2d_op
Loading
Loading