2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -13,6 +13,7 @@
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
+from .decompose_binary_alpha import DecomposeBinaryAlpha
 from .decompose_cdist import DecomposeCDist
 from .decompose_col_im import DecomposeColIm
 from .decompose_einsum import DecomposeEinsum
@@ -53,6 +54,7 @@
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
+    DecomposeBinaryAlpha,
     DecomposeCDist,
     DecomposeColIm,
     DecomposeEinsum,
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/canonicalize_conv.py
@@ -34,6 +34,7 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
         self.transpose_conv_set = {
             torch.ops.aten.conv_transpose1d.default,
             torch.ops.aten.conv_transpose2d.input,
+            torch.ops.aten.conv_transpose3d.input,
         }

     def dilate(self, tensor, dilation):
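Note: QNN's TransposeConv ops expose no dilation parameter, which is why the conv builder later in this PR asserts that CanonicalizeConv has already dilated the kernel for transposed convolutions. A minimal sketch of that standard zero-insertion trick (illustrative only; the pass's actual `dilate` implementation is not shown in this diff):

```python
import torch

def dilate_kernel(weight: torch.Tensor, dilation: tuple[int, int]) -> torch.Tensor:
    # Insert (d - 1) zeros between kernel taps so a dilated conv can be
    # expressed as an undilated one over a larger kernel.
    kh, kw = weight.shape[-2:]
    dh, dw = dilation
    out = weight.new_zeros(*weight.shape[:-2], (kh - 1) * dh + 1, (kw - 1) * dw + 1)
    out[..., ::dh, ::dw] = weight
    return out

# A 3x3 kernel with dilation 2 becomes an equivalent 5x5 kernel:
k = torch.arange(9.0).reshape(1, 1, 3, 3)
assert dilate_kernel(k, (2, 2)).shape == (1, 1, 5, 5)
```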
61 changes: 61 additions & 0 deletions backends/qualcomm/_passes/decompose_binary_alpha.py
@@ -0,0 +1,61 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_meta
+
+decomp_set = {torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor}
+
+
+class DecomposeBinaryAlpha(ExportPass):
+    """
+    QNN does not support the alpha parameter for add/sub.
+    Decompose to mul + add / mul + sub.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if (
+                node.target in decomp_set
+                and "alpha" in node.kwargs
+                and node.kwargs["alpha"] != 1
+            ):
+                alpha = node.kwargs["alpha"]
+                # Remove alpha from the immutable kwargs dict
+                node.kwargs = {k: v for k, v in node.kwargs.items() if k != "alpha"}
+                input2_node = node.args[1]
+                # If input2 is a constant, fold alpha into it as an optimization
+                if isinstance(input2_node, (int, float)):
+                    arg_list = list(node.args)
+                    arg_list[1] = input2_node * alpha
+                    node.args = tuple(arg_list)
+                    continue
+                with graph.inserting_before(node):
+                    mul_op = torch.ops.aten.mul.Scalar
+                    mul_node = graph.create_node(
+                        "call_function",
+                        mul_op,
+                        (
+                            input2_node,
+                            alpha,
+                        ),
+                    )
+                    mul_node.meta = copy_meta(node.meta)
+                    node.replace_input_with(input2_node, mul_node)
+                    node.args = (
+                        node.args[0],
+                        mul_node,
+                    )
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
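For reference, a quick sanity check in plain PyTorch of the identity the pass relies on: alpha scales the second operand, so an explicit `mul` followed by a plain `add`/`sub` is an exact rewrite, and for scalar operands alpha can be folded into the constant (the two branches of the pass above):

```python
import torch

x, y = torch.randn(3), torch.randn(3)
alpha = 2.5

# Tensor operand: add/sub with alpha == mul + plain add/sub.
assert torch.allclose(torch.add(x, y, alpha=alpha), x + alpha * y)
assert torch.allclose(torch.sub(x, y, alpha=alpha), x - alpha * y)

# Scalar operand: alpha folds into the constant, no extra mul node needed.
assert torch.allclose(torch.sub(x, 3.0, alpha=alpha), x - 3.0 * alpha)
```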
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -18,6 +18,7 @@
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
+    DecomposeBinaryAlpha,
     DecomposeCDist,
     DecomposeColIm,
     DecomposeEinsum,
@@ -193,6 +194,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
         self.add_pass(RecomposeRmsNorm(quantization_capture=True))
         self.add_pass(ReplaceArangeArgs())
+        self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
@@ -208,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
     def transform_for_export_pipeline(
         self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False
     ):
+        self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
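With the pass registered in both pipelines, it can also be exercised on its own. A sketch, assuming the import path added by this PR and that, like other `ExportPass` subclasses, an instance can be invoked directly on a `GraphModule`:

```python
import torch
from executorch.backends.qualcomm._passes import DecomposeBinaryAlpha

class AlphaSub(torch.nn.Module):
    def forward(self, x, y):
        return torch.sub(x, y, alpha=2)

ep = torch.export.export(AlphaSub(), (torch.randn(4), torch.randn(4)))
result = DecomposeBinaryAlpha()(ep.graph_module)
# Expect an aten.mul.Scalar feeding a plain aten.sub.Tensor, no alpha kwarg.
print(result.graph_module.graph)
```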
6 changes: 3 additions & 3 deletions backends/qualcomm/builders/README.md
@@ -365,7 +365,7 @@ Please help update following table if you are contributing new operators:
 + 🚫 = Deprecated, supported with other QNN Ops
 
 
-| Operators | HTP - 90/116 Enabled |
+| Operators | HTP - 92/116 Enabled |
 |-----------|---------|
 | Argmax | ✓ |
 | Argmin | ✓ |
@@ -375,7 +375,7 @@ Please help update following table if you are contributing new operators:
 | ChannelShuffle | ✗ |
 | Concat | ✓ |
 | Conv2d | ✓ |
-| Conv3d | ✗ |
+| Conv3d | ✓ |
 | Convert | ✓ |
 | CreateSparse | ✗ |
 | CumulativeSum | ✓ |
@@ -481,7 +481,7 @@ Please help update following table if you are contributing new operators:
 | TopK | ✓ |
 | TransPose | ✓ |
 | TransPoseConv2d | ✓ |
-| TransPoseConv3d | ✗ |
+| TransPoseConv3d | ✓ |
 | Unpack | ✓ |
 
 ## Issues
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/__init__.py
@@ -24,7 +24,7 @@
     op_cat,
     op_ceil,
     op_clamp,
-    op_conv2d,
+    op_conv,
     op_copy,
     op_cos,
     op_cum_sum,
@@ -129,7 +129,7 @@
     op_cat,
     op_ceil,
     op_clamp,
-    op_conv2d,
+    op_conv,
     op_copy,
     op_cos,
     op_cum_sum,
backends/qualcomm/builders/{op_conv2d.py → op_conv.py}
@@ -7,7 +7,6 @@
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_DATA
@@ -16,8 +15,10 @@
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import (
     OpConv2d,
+    OpConv3d,
     OpDepthWiseConv2d,
     OpTransposeConv2d,
+    OpTransposeConv3d,
     QNN_OP_PACKAGE_NAME_QTI_AISW,
 )
 from .utils import get_parameter
@@ -66,7 +67,7 @@ def _add_conv_op_parameter(
             len(padding_shape),
             padding_shape,
             np.array(
-                [[padding[0], padding[0]], [padding[1], padding[1]]],
+                padding,
                 dtype=np.uint32,
             ),
             True,
@@ -108,8 +109,14 @@ def define_node(
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         assert (
-            input_tensor.dim() == 4
+            input_tensor.dim() != 3
         ), "All Conv1D should be converted to Conv2D in CanonicalizeConv,"
+        assert input_tensor.dim() in {
+            4,
+            5,
+        }, "Only Conv2d and Conv3d is supported in conv builder,"
+
+        is_conv2d = input_tensor.dim() == 4
         input_tensor_wrapper = self.define_tensor(
             input_node,
             node,
@@ -120,9 +127,15 @@
 
         filter_node = self.get_node(node.args[1])
         filter_tensor = get_parameter(filter_node, self.edge_program)
-        # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
+        # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
+        # yet QNN is HWIO or DHWIO
         is_transpose_conv = cast(bool, node.args[6])
-        filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        if is_conv2d:
+            filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
+        else:
+            filter_axis_order = (
+                (2, 3, 4, 0, 1) if is_transpose_conv else (2, 3, 4, 1, 0)
+            )
         filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
         filter_tensor_wrapper = self.define_tensor(
             filter_node,
@@ -132,7 +145,6 @@ def define_node(
             nodes_to_wrappers,
         )
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
-
         if node.args[2] is not None:
             bias_node = self.get_node(node.args[2])
             bias_tensor = get_parameter(bias_node, self.edge_program)
@@ -159,11 +171,10 @@ def define_node(
         padding = cast(List[int], node.args[4])
         dilation = cast(List[int], node.args[5])
         output_padding = cast(List[int], node.args[7])
-
         groups = cast(int, node.args[8])
-        # Qnn filter tensor is (H, W, Cin, Cout)
-        group_input_channels = filter_tensor.shape[2]
-        group_output_channels = int(filter_tensor.shape[3] / groups)
+        # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
+        group_input_channels = filter_tensor.shape[-2]
+        group_output_channels = int(filter_tensor.shape[-1] / groups)
         # 1) groups = input_channels (i.e. group_input_channels = 1)
         # 2) output_channels is a positive integer multiple of input channels
         # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1
@@ -175,18 +186,23 @@ def define_node(
         )
         if len(padding) == 1:
             padding = padding + padding
+        padding = [[x, x] for x in padding]
 
         stride_shape = [len(stride)]
-        padding_shape = [2, 2]
+        padding_shape = [len(padding), len(padding[0])]
         dilation_shape = [len(dilation)]
         output_padding_shape = [len(output_padding)]
 
-        if is_depthwise_conv:
+        if is_transpose_conv:
+            assert all(
+                val == 1 for val in dilation
+            ), "CanonicalizeConv pass should perform dilate for transpose_conv."
+            op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
+        elif is_depthwise_conv:
+            assert is_conv2d, "DepthWise only supports Conv2d"
             op_class = OpDepthWiseConv2d
-        elif is_transpose_conv:
-            op_class = OpTransposeConv2d
         else:
-            op_class = OpConv2d
+            op_class = OpConv2d if is_conv2d else OpConv3d
 
         conv_op = PyQnnWrapper.PyQnnOpWrapper(
             node.name,
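The layout handling above is easy to check in isolation: PyTorch stores conv weights as OIHW (2D) / OIDHW (3D) and transposed-conv weights as IOHW / IODHW, while QNN expects HWIO / DHWIO; `pad_amount` is now a [before, after] pair per spatial dimension, which also covers the 3D case the old hard-coded 2x2 array could not. Shapes below are illustrative:

```python
import torch

w2d = torch.randn(8, 4, 3, 3)                               # conv2d: (O, I, H, W)
assert w2d.permute(2, 3, 1, 0).shape == (3, 3, 4, 8)        # -> (H, W, I, O)

wt2d = torch.randn(4, 8, 3, 3)                              # conv_transpose2d: (I, O, H, W)
assert wt2d.permute(2, 3, 0, 1).shape == (3, 3, 4, 8)       # -> (H, W, I, O)

w3d = torch.randn(8, 4, 2, 3, 3)                            # conv3d: (O, I, D, H, W)
assert w3d.permute(2, 3, 4, 1, 0).shape == (2, 3, 3, 4, 8)  # -> (D, H, W, I, O)

padding = [1, 2, 2]                                         # symmetric (D, H, W) padding
pad_amount = [[p, p] for p in padding]                      # [[1, 1], [2, 2], [2, 2]]
```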
18 changes: 18 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -59,6 +59,15 @@ class OpConv2d:
     param_dilation: str = "dilation"
 
 
+@dataclass(init=False, frozen=True)
+class OpConv3d:
+    op_name: str = "Conv3d"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_group: str = "group"
+    param_dilation: str = "dilation"
+
+
 @dataclass(init=False, frozen=True)
 class OpConvert:
     op_name: str = "Convert"
@@ -573,6 +582,15 @@ class OpTransposeConv2d:
     param_output_padding: str = "output_padding"
 
 
+@dataclass(init=False, frozen=True)
+class OpTransposeConv3d:
+    op_name: str = "TransposeConv3d"
+    param_stride: str = "stride"
+    param_pad_amount: str = "pad_amount"
+    param_group: str = "group"
+    param_output_padding: str = "output_padding"
+
+
 @dataclass(init=False, frozen=True)
 class OpUnpack:
     op_name: str = "UnPack"
6 changes: 4 additions & 2 deletions backends/qualcomm/quantizer/annotators.py
@@ -1094,11 +1094,13 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator(
     [
+        torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
         torch.ops.aten.conv2d.padding,
-        torch.ops.aten.conv1d.default,
-        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv3d.default,
         torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d.input,
+        torch.ops.aten.conv_transpose3d.input,
         torch.ops.aten.convolution.default,
     ]
 )