1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -47,6 +47,7 @@ class LayoutTransform(ExportPass):
layout_agnostic_ops = {
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.bitwise_and.Tensor,
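Marking aten.amax.default as layout agnostic lets the pass leave the op in place when surrounding tensors are permuted to QNN's channel-last layout; only the reduction axes need remapping, which the builder below performs. A minimal sketch of why that is safe, assuming the usual NCHW -> NHWC axis order (illustrative, not part of the diff):

import torch

x = torch.randn(1, 3, 4, 5)  # NCHW
axis_order = (0, 2, 3, 1)  # NCHW -> NHWC permutation
nhwc = x.permute(axis_order)
dim = 1  # channel axis in NCHW
remapped = axis_order.index(dim)  # -> 3, the channel axis in NHWC
# amax commutes with the permutation once its axis is remapped
assert torch.equal(
    torch.amax(x, dim=dim, keepdim=True).permute(axis_order),
    torch.amax(nhwc, dim=remapped, keepdim=True),
)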
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -9,6 +9,7 @@
op_abs,
op_adaptive_avg_pool2d,
op_add,
op_amax,
op_and,
op_arange,
op_argmin,
@@ -95,6 +96,7 @@
op_abs,
op_adaptive_avg_pool2d,
op_add,
op_amax,
op_and,
op_arange,
op_argmin,
84 changes: 84 additions & 0 deletions backends/qualcomm/builders/op_amax.py
@@ -0,0 +1,84 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpAmax, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class AMax(NodeVisitor):
target = ["aten.amax.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

# normalize the reduction dims, then remap them to the QNN axis order
reduce_dims = cast(List[int], node.args[1])
reduce_dims = [
reduce_dim % len(input_node.meta["val"].shape) for reduce_dim in reduce_dims
]
if QCOM_AXIS_ORDER in node.meta:
reduce_dims = [
node.meta[QCOM_AXIS_ORDER].index(reduce_dim) for reduce_dim in reduce_dims
]
reduce_dims_shape = [len(reduce_dims)]

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

reduce_max_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpAmax.op_name,
)
reduce_max_op.AddInputTensors([input_tensor_wrapper])
reduce_max_op.AddOutputTensors([output_tensor_wrapper])
reduce_max_op.AddTensorParam(
OpAmax.param_axes,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(reduce_dims_shape),
reduce_dims_shape,
np.array(reduce_dims, dtype=np.uint32),
True,
)
if len(node.args) > 2:
keep_dims = cast(bool, node.args[2])
reduce_max_op.AddScalarParam(
OpAmax.param_keep_dims,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: keep_dims},
)

return reduce_max_op
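A hypothetical walkthrough of the axes computation above, for a rank-4 input whose activation was permuted to channel-last (QCOM_AXIS_ORDER == (0, 2, 3, 1)) and a call like torch.amax(x, dim=-1):

rank = 4
dims = [-1]  # node.args[1]
dims = [d % rank for d in dims]  # normalize negatives -> [3] (W in NCHW)
axis_order = (0, 2, 3, 1)  # node.meta[QCOM_AXIS_ORDER]
dims = [axis_order.index(d) for d in dims]  # remap -> [2] (W in NHWC)
# QNN's ReduceMax then receives axes=[2] as a uint32 tensor parameter.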
7 changes: 7 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -14,6 +14,13 @@
# instead of replicating them here.


@dataclass(init=False, frozen=True)
class OpAmax:
op_name: str = "ReduceMax"
param_axes: str = "axes"
param_keep_dims: str = "keep_dims"


@dataclass(init=False, frozen=True)
class OpBatchnorm:
op_name: str = "Batchnorm"
5 changes: 5 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
@@ -182,6 +182,11 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.amax.default])
def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.argmin.default])
def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
10 changes: 10 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -72,6 +72,16 @@ def forward(self, x):
return torch.any(x, dim=self.dim, keepdim=self.keepdim)


class AMax(torch.nn.Module):
def __init__(self, dim=None, keepdim=False):
super().__init__()
self.dim = dim
self.keepdim = keepdim

def forward(self, x):
return torch.amax(x, dim=self.dim, keepdim=self.keepdim)


class Arange(torch.nn.Module):
def __init__(self, start, end, step, dtype):
super().__init__()
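A quick sanity check of the new AMax test module (illustrative; shapes match the (4, 4) sample input used in the tests below):

import torch

m = AMax(dim=1, keepdim=True)
y = m(torch.randn(4, 4))
print(y.shape)  # torch.Size([4, 1])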
15 changes: 15 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -113,6 +113,13 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
sample_input = (torch.randn(1, 512, 7, 7),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_amax(self):
modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(4, 4),)
for i, module in enumerate(modules):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_any(self):
modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -1111,6 +1118,14 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_amax(self):
modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(4, 4),)
for i, module in enumerate(modules):
with self.subTest(i=i):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_any(self):
modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)] # noqa: F405
sample_input = (torch.randn(3, 3, 3) > 0,)