35 changes: 31 additions & 4 deletions backends/qualcomm/_passes/layout_transform.py
@@ -12,6 +12,7 @@
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS_ORDER,
QCOM_INSERTED_PERMUTE,
QCOM_LAYOUT_CHANGE,
QCOM_QUANT_ATTRS,
QCOM_REQUANTIZE,
)
@@ -34,6 +35,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_group_norm.default,
exir_ops.edge.aten.pixel_shuffle.default,
exir_ops.edge.aten.pixel_unshuffle.default,
exir_ops.edge.aten.upsample_bilinear2d.default,
@@ -95,6 +97,7 @@ def __init__(
self.edge_program = edge_program
self.insert_permute = insert_permute
self.qdq_opset = {*q_ops, *dq_ops}
self.transformed_tag = QCOM_AXIS_ORDER

def mark_as_transformed(self, node: torch.fx.Node) -> None:
if isinstance(node.meta["val"], (tuple, list)):
@@ -105,18 +108,18 @@ def mark_as_transformed(self, node: torch.fx.Node) -> None:
f"got {getitem_node.target.__name__}"
)
index = getitem_node.args[1]
node.meta[QCOM_AXIS_ORDER] = self.get_axis_order(
node.meta[self.transformed_tag] = self.get_axis_order(
eval_shape(node.meta["val"][index].shape)
)
else:
node.meta[QCOM_AXIS_ORDER] = self.get_axis_order(
node.meta[self.transformed_tag] = self.get_axis_order(
eval_shape(node.meta["val"].shape)
)

def is_transformed_node(self, node: torch.fx.Node) -> bool:
if not hasattr(node, "meta"):
return False
return QCOM_AXIS_ORDER in node.meta
return self.transformed_tag in node.meta

def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
return node.target in self.layout_sensitive_ops
@@ -186,8 +189,23 @@ def insert_node(self, graph_module, node, revert_layout: bool) -> None:
# we need this to check the annotation boundary
permute.meta[QCOM_INSERTED_PERMUTE] = True

# this handles the case where a residual connection exists:
# e.g. consider the following graph
# x --> permute --> layer_norm --> permute --> conv2d --> add
# └-------------------------------------┙
# the permute node should be inserted as:
# x --> permute --> layer_norm --> permute --> qnn_permute --> conv2d --> add
# └--------------------------------------> qnn_permute -┙
# i.e. insert a permute conditionally on each edge between the current
# node and its users when there are multiple users
is_node_transformed = self.is_transformed_node(node)
for user in users:
user.replace_input_with(node, permute)
is_user_transformed = (
self.is_transformed_node(user) or QCOM_LAYOUT_CHANGE in user.meta
)
# insert the permute only when node and user disagree on layout (exclusive or)
if is_node_transformed != is_user_transformed:
user.replace_input_with(node, permute)

def create_call_function_node(
self,
@@ -243,6 +261,15 @@ def call(self, graph_module: torch.fx.GraphModule):
sensitive_nodes = [
node for node in graph.nodes if self.is_layout_sensitive(node)
]
# first pass: traverse once to tag nodes subject to layout change (no permutes inserted)
if self.insert_permute:
self.insert_permute, self.transformed_tag = False, QCOM_LAYOUT_CHANGE
for node in sensitive_nodes:
if not self.is_transformed_node(node):
self.mark_as_transformed(node)
self.traverse(node, graph_module)
self.insert_permute, self.transformed_tag = True, QCOM_AXIS_ORDER

for node in sensitive_nodes:
if not self.is_transformed_node(node):
self.mark_as_transformed(node)
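
As intuition for the change above, here is a standalone toy sketch (hypothetical names, not the pass itself) of the per-edge rule insert_node now applies: after the tagging pass, a qnn_permute is placed on an edge only when exactly one endpoint changes layout.

def edges_needing_permute(edges, tagged):
    # edges: (producer, user) pairs; tagged: nodes the first traversal
    # marked as layout-changing (QCOM_LAYOUT_CHANGE)
    return [(src, dst) for src, dst in edges if (src in tagged) != (dst in tagged)]

# residual example from the comment in insert_node():
#   x --> layer_norm --> conv2d --> add, plus a skip edge x --> add
edges = [("x", "layer_norm"), ("layer_norm", "conv2d"), ("conv2d", "add"), ("x", "add")]
tagged = {"conv2d", "add"}
# only the layout-boundary edges receive a permute; conv2d --> add does not
assert edges_needing_permute(edges, tagged) == [("layer_norm", "conv2d"), ("x", "add")]
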
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -20,6 +20,7 @@
op_embedding,
op_expand,
op_gelu,
op_group_norm,
op_hardsigmoid,
op_hardswish,
op_hardtanh,
@@ -76,6 +77,7 @@
op_embedding,
op_expand,
op_gelu,
op_group_norm,
op_hardswish,
op_hardtanh,
op_hardsigmoid,
92 changes: 92 additions & 0 deletions backends/qualcomm/builders/op_group_norm.py
@@ -0,0 +1,92 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGroupNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class GroupNormVisitor(NodeVisitor):
target = ["aten.native_group_norm.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)

weight_node = node.args[1]
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)

bias_node = node.args[2]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)
group = node.args[6]
epsilon = node.args[7]

output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)

group_norm_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpGroupNorm.op_name,
)
group_norm_op.AddInputTensors(
[input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
)
group_norm_op.AddOutputTensors([output_tensor_wrapper])
group_norm_op.AddScalarParam(
OpGroupNorm.param_epsilon,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{"data": np.float32(epsilon)},
)
group_norm_op.AddScalarParam(
OpGroupNorm.param_group,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(group)},
)

return group_norm_op
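
For reference, the argument indices read above follow the aten.native_group_norm schema, native_group_norm(input, weight, bias, N, C, HxW, group, eps) -> (out, mean, rstd), so args[6] is the group count and args[7] the epsilon; a quick eager-mode check (illustrative only):

import torch

x = torch.randn(3, 32, 56, 56)
# N=3, C=32, HxW=56*56, group=8, eps=1e-5
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, torch.ones(32), torch.zeros(32), 3, 32, 56 * 56, 8, 1e-5
)
assert out.shape == x.shape
assert mean.shape == rstd.shape == (3, 8)  # per-sample, per-group statistics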
6 changes: 6 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -147,6 +147,12 @@ class OpGelu:
op_name: str = "Gelu"


@dataclass(init=False, frozen=True)
class OpGroupNorm:
op_name: str = "GroupNorm"
param_epsilon: str = "epsilon"
param_group: str = "group"


@dataclass(init=False, frozen=True)
class OpHardSwish:
op_name: str = "HardSwish"
33 changes: 33 additions & 0 deletions backends/qualcomm/quantizer/utils.py
@@ -1010,6 +1010,39 @@ def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.group_norm.default])
def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
act_node = node.args[0]
weight_node = node.args[2]
bias_node = None
if len(node.args) > 3:
bias_node = node.args[3]

if _is_annotated([node]):
return

_annotate_input_qspec_map(
node,
act_node,
quantization_config.input_activation,
)
_annotate_input_qspec_map(
node,
weight_node,
quantization_config.weight,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
_annotate_input_qspec_map(
node,
bias_node,
quantization_config.bias,
)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)
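
The indices used here follow the aten.group_norm schema, group_norm(input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True): args[0] is the activation, args[2] the weight, and args[3] the optional bias. A short illustrative check:

import torch

x = torch.randn(3, 32, 56, 56)
out = torch.ops.aten.group_norm(x, 8, torch.ones(32), torch.zeros(32))
assert out.shape == x.shape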


@register_annotator([torch.ops.aten.flatten.using_ints])
def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_in_out_obs_sharing_op(node, quantization_config)
18 changes: 18 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -501,6 +501,24 @@ def forward(self, x):
return self.gelu(x)


class GroupNorm(torch.nn.Module):
def __init__(self, bias=True):
super().__init__()
self.conv = torch.nn.Conv2d(
32,
256,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
)
self.norm = torch.nn.GroupNorm(32, 256)

def forward(self, x):
y = self.conv(x)
return y, self.norm(y)
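
A quick shape check for this module (illustrative; torch is already imported in models.py). Returning the conv output alongside its normalized version gives the norm input a second user, which plausibly exercises the multi-user permute handling added in layout_transform.py:

m = GroupNorm()
y, y_norm = m(torch.randn(3, 32, 56, 56))
assert y.shape == y_norm.shape == (3, 256, 56, 56)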


class HardSigmoid(torch.nn.Module):
def __init__(self):
super().__init__()
50 changes: 50 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -289,6 +289,13 @@ def test_qnn_backend_gelu(self):
sample_input = (torch.randn(2, 5, 1, 3),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_group_norm(self):
modules = [GroupNorm(), GroupNorm(bias=False)] # noqa: F405
sample_input = (torch.randn(3, 32, 56, 56),)
for i, module in enumerate(modules):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_hardsigmoid(self):
module = HardSigmoid() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
@@ -964,6 +971,14 @@ def test_qnn_backend_gelu(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_group_norm(self):
modules = [GroupNorm(), GroupNorm(bias=False)] # noqa: F405
sample_input = (torch.randn(3, 32, 56, 56),)
for i, module in enumerate(modules):
with self.subTest(i=i):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_hardsigmoid(self):
module = HardSigmoid() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
@@ -2147,6 +2162,41 @@ def test_regnet(self):
self.assertGreaterEqual(msg["top_1"], 60)
self.assertGreaterEqual(msg["top_5"], 85)

def test_retinanet(self):
if not self.required_envs([self.image_dataset]):
self.skipTest("missing required envs")

cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/retinanet.py",
"--artifact",
self.artifact_dir,
"--build_folder",
self.build_folder,
"--device",
self.device,
"--model",
self.model,
"--dataset",
self.image_dataset,
"--ip",
self.ip,
"--port",
str(self.port),
]
if self.host:
cmds.extend(["--host", self.host])

p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
p.communicate()
msg = json.loads(conn.recv())
if "Error" in msg:
self.fail(msg["Error"])
else:
self.assertGreaterEqual(msg["mAP"], 0.6)

def test_squeezenet(self):
if not self.required_envs([self.image_dataset]):
self.skipTest("missing required envs")
1 change: 1 addition & 0 deletions backends/qualcomm/utils/constants.py
@@ -14,6 +14,7 @@
QCOM_DTYPE = "dtype"
QCOM_ENCODING = "encoding"
QCOM_INSERTED_PERMUTE = "qnn_permute"
QCOM_LAYOUT_CHANGE = "layout_change"
QCOM_OFFSET = "offset"
QCOM_QUANTIZED_IO = "q_tensor_io"
QCOM_QUANT_ATTRS = "quant_attrs"