4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/__init__.py
@@ -8,8 +8,8 @@
from .annotate_quant_attrs import AnnotateQuantAttrs
from .annotate_stack import AnnotateStack
from .annotate_unbind import AnnotateUnbind
from .canonicalize_conv import CanonicalizeConv
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_linear_to_conv2d import ConvertLinearToConv2d
from .convert_square_to_pow import ConvertSquareToPow
from .decompose_any import DecomposeAny
@@ -47,8 +47,8 @@
    AnnotateQuantAttrs,
    AnnotateStack,
    AnnotateUnbind,
    CanonicalizeConv,
    ConvertBmmToMatmul,
    ConvertConv1dToConv2d,
    ConvertLinearToConv2d,
    ConvertSquareToPow,
    DecomposeAny,
backends/qualcomm/_passes/convert_conv1d_to_conv2d.py → backends/qualcomm/_passes/canonicalize_conv.py
@@ -4,32 +4,96 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Tuple

import torch

from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
from executorch.exir.pass_base import ExportPass, PassResult
from torch._guards import detect_fake_mode

from .utils import append_qdq, copy_meta


class ConvertConv1dToConv2d(ExportPass):
class CanonicalizeConv(ExportPass):
"""
Conv1d is not supported by QNN.
Change it to input -> unsqueeze -> conv2d -> squeeze -> output
1. QNN does not support dilation on TransposeConvND
Dilate the kernel manually for math-equivalent operation
2. Conv1d is not supported by QNN.
Change it to input -> unsqueeze -> conv2d -> squeeze -> output
"""

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super(ConvertConv1dToConv2d, self).__init__()
        super(CanonicalizeConv, self).__init__()
        self.edge_program = edge_program
        self.conv_op_map = {
        self.conv1d_op_map = {
            torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
            torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
        }
        self.transpose_conv_set = {
            torch.ops.aten.conv_transpose1d.default,
            torch.ops.aten.conv_transpose2d.input,
        }

    def dilate(self, tensor, dilation):
        # e.g.
        # for 3x3 kernel with dilation == (2, 3)
        #             1, 0, 0, 2, 0, 0, 3
        # 1, 2, 3     0, 0, 0, 0, 0, 0, 0
        # 4, 5, 6 --> 4, 0, 0, 5, 0, 0, 6
        # 7, 8, 9     0, 0, 0, 0, 0, 0, 0
        #             7, 0, 0, 8, 0, 0, 9
        i, o, *k = tensor.shape
        new_k = [dim + (dim - 1) * (s - 1) for s, dim in zip(dilation, k)]
        new_tensor = torch.zeros((i, o, *new_k), dtype=tensor.dtype)
        indexing = (...,) + tuple([slice(None, None, d) for d in dilation])
        new_tensor[indexing] = tensor
        return new_tensor

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        # condition 1
        for node in graph.nodes:
            # arg order (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html)
            # > input, weight, bias, stride, padding, output_padding, groups, dilation
            if node.target in self.transpose_conv_set and len(node.args) > 7:
                dilation = cast(Tuple[int], node.args[7])
                # dilate kernel in advance
                filter_arg = node.args[1]
                filter_node = (
                    # fp graph
                    filter_arg
                    if filter_arg.op == "placeholder"
                    # qdq graph
                    else node.args[1].args[0]
                )
                filter_tensor = self.dilate(
                    get_parameter(filter_node, self.edge_program),
                    dilation,
                )
                # update tensor meta for kernel node
                fake_mode = detect_fake_mode(filter_node.meta["val"])
                converter = fake_mode.fake_tensor_converter
                filter_node.meta["val"] = converter.from_real_tensor(
                    fake_mode, filter_tensor
                )
                # update kernel
                set_parameter(
                    (
                        torch.nn.Parameter(filter_tensor)
                        if filter_tensor.dtype == torch.float
                        else filter_tensor
                    ),
                    filter_node,
                    self.edge_program,
                )
                # pop dilation for graph in cpu
                node.args = node.args[0:-1]

        # condition 2
        for node in graph.nodes:
            if node.target in self.conv_op_map:
            if node.target in self.conv1d_op_map:
                input_node = node.args[0]
                with graph_module.graph.inserting_after(input_node):
                    unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
@@ -108,7 +172,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                    )
                    conv2d_node = graph.create_node(
                        "call_function",
                        self.conv_op_map[node.target],
                        self.conv1d_op_map[node.target],
                        conv_args,
                    )
                    conv2d_node.meta = copy_meta(
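A quick standalone sanity check of the kernel-dilation trick above (illustrative only, not part of this PR; the shapes, stride, padding, and tolerance are arbitrary choices): it re-implements the same zero-insertion as CanonicalizeConv.dilate and confirms that conv_transpose2d with dilation > 1 matches the pre-dilated kernel run with dilation == 1, which is the equivalence the pass relies on.

import torch
import torch.nn.functional as F

def dilate(tensor, dilation):
    # zero-insertion identical to CanonicalizeConv.dilate
    i, o, *k = tensor.shape
    new_k = [dim + (dim - 1) * (s - 1) for s, dim in zip(dilation, k)]
    new_tensor = torch.zeros((i, o, *new_k), dtype=tensor.dtype)
    indexing = (...,) + tuple(slice(None, None, d) for d in dilation)
    new_tensor[indexing] = tensor
    return new_tensor

x = torch.randn(1, 4, 8, 8)
w = torch.randn(4, 3, 3, 3)  # ConvTranspose2d weight layout: (in_channels, out_channels, kH, kW)
dilation = (2, 3)

expected = F.conv_transpose2d(x, w, stride=2, padding=1, dilation=dilation)
rewritten = F.conv_transpose2d(x, dilate(w, dilation), stride=2, padding=1, dilation=1)
assert torch.allclose(expected, rewritten, atol=1e-5)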
5 changes: 3 additions & 2 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -13,8 +13,8 @@
    AnnotateQuantAttrs,
    AnnotateStack,
    AnnotateUnbind,
    CanonicalizeConv,
    ConvertBmmToMatmul,
    ConvertConv1dToConv2d,
    ConvertLinearToConv2d,
    ConvertSquareToPow,
    DecomposeAny,
@@ -82,6 +82,7 @@ def get_capture_program_passes():
        (AnnotateQuantAttrs, True),
        (AnnotateStack, True),
        (AnnotateUnbind, True),
        (CanonicalizeConv, True),
        (ConvertBmmToMatmul, False),
        (DecomposeAny, True),
        (DecomposeColIm, True),
@@ -215,7 +216,7 @@ def transform_for_export_pipeline(
        self.add_pass(DecomposeWrapWithAutocast())
        # this pass will rewrite state_dict, it needs to be accomplished before
        # to_edge_transform_and_lower
        self.add_pass(ConvertConv1dToConv2d(exported_program))
        self.add_pass(CanonicalizeConv(exported_program))
        if convert_linear_to_conv2d:
            self.add_pass(ConvertLinearToConv2d(exported_program))
        self.add_pass(ConvertSquareToPow())
7 changes: 1 addition & 6 deletions backends/qualcomm/_passes/remove_redundancy.py
@@ -41,12 +41,7 @@ def __init__(self, quantization_capture=False):
        )

    def _dim_order_op_condition(self, node):
        dim_order = node.kwargs.get("dim_order")
        # skip if there contains layout hint
        # e.g. (0, 2, 3, 1) != (0, 1, 2, 3)
        if node.meta["val"].dtype != node.args[0].meta["val"].dtype:
            return False
        return dim_order != list(range(len(dim_order)))
        return node.meta["val"].dtype == node.args[0].meta["val"].dtype

    def _to_copy_op_condition(self, node):
        return "memory_format" in node.kwargs
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/utils.py
@@ -64,8 +64,8 @@ def get_passes_dependency_for_capture_program():
        AnnotateQuantAttrs,
        AnnotateStack,
        AnnotateUnbind,
        CanonicalizeConv,
        ConvertBmmToMatmul,
        ConvertConv1dToConv2d,
        DecomposeAny,
        DecomposeColIm,
        DecomposeLinalgVectorNorm,
@@ -99,7 +99,7 @@ def get_passes_dependency_for_capture_program():
        I64toI32: [RemoveRedundancy],
        LayoutTransform: [
            AnnotateQuantAttrs,
            ConvertConv1dToConv2d,
            CanonicalizeConv,
            ExpandBroadcastTensorShape,
            FixedLinearKeepDim,
        ],
17 changes: 11 additions & 6 deletions backends/qualcomm/builders/op_amax.py
@@ -40,14 +40,19 @@ def define_node(
        )

        # mean dims and keep dims
        mean_dims = cast(List[int], node.args[1])
        mean_dims = [
            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
        ]
        if QCOM_AXIS_ORDER in node.meta:
        if len(node.args) > 1:
            mean_dims = cast(List[int], node.args[1])
            mean_dims = [
                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
                mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
            ]
            if QCOM_AXIS_ORDER in node.meta:
                mean_dims = [
                    node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
                ]
        else:
            # reduce all dimensions
            mean_dims = list(range(input_node.meta["val"].dim()))

        mean_dims_shape = [len(mean_dims)]

        output_tensor = self.get_tensor(node, node)
17 changes: 11 additions & 6 deletions backends/qualcomm/builders/op_amin.py
@@ -40,14 +40,19 @@ def define_node(
        )

        # mean dims and keep dims
        mean_dims = cast(List[int], node.args[1])
        mean_dims = [
            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
        ]
        if QCOM_AXIS_ORDER in node.meta:
        if len(node.args) > 1:
            mean_dims = cast(List[int], node.args[1])
            mean_dims = [
                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
                mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
            ]
            if QCOM_AXIS_ORDER in node.meta:
                mean_dims = [
                    node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
                ]
        else:
            # reduce all dimensions
            mean_dims = list(range(input_node.meta["val"].dim()))

        mean_dims_shape = [len(mean_dims)]

        output_tensor = self.get_tensor(node, node)
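Why the new len(node.args) > 1 branch in both op_amax.py and op_amin.py exists: when amax/amin is captured without an explicit dim argument, the op reduces over every dimension, which is exactly what the list(range(...)) fallback encodes. A small illustration (values and shapes are made up, not taken from the test suite):

import torch

x = torch.arange(12.0).reshape(3, 4)
# dim omitted -> full reduction over all dimensions
assert torch.amax(x) == x.max()
assert torch.amin(x) == x.min()
# explicit dim -> reduce only along that dimension
assert torch.equal(torch.amax(x, dim=1), x.max(dim=1).values)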
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_conv2d.py
@@ -109,7 +109,7 @@ def define_node(
        input_tensor = self.get_tensor(input_node, node)
        assert (
            input_tensor.dim() == 4
        ), "All Conv should be converted to Conv2D in ConvertConv1dToConv2d"
        ), "All Conv1D should be converted to Conv2D in CanonicalizeConv"
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
4 changes: 2 additions & 2 deletions backends/qualcomm/quantizer/annotators.py
@@ -214,7 +214,7 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:

@register_annotator([torch.ops.aten.amax.default])
def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.argmax.default])
@@ -224,7 +224,7 @@ def annotate_argmax(node: Node, quantization_config: QuantizationConfig) -> None

@register_annotator([torch.ops.aten.amin.default])
def annotate_amin(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.argmin.default])
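Context for the annotator change above: aten.amax/aten.amin consume a single tensor input, while dim and keepdim arrive as non-tensor arguments, so there is no second tensor operand for annotate_binary to annotate. A throwaway trace showing the single-tensor signature (module and shapes are illustrative assumptions, not from this PR):

import torch

class AMax(torch.nn.Module):
    def forward(self, x):
        return torch.amax(x, dim=1)

ep = torch.export.export(AMax(), (torch.randn(2, 3),))
print(ep.graph)  # the amax node takes exactly one tensor value; dim/keepdim are plain args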
13 changes: 10 additions & 3 deletions backends/qualcomm/tests/models.py
@@ -589,25 +589,32 @@ def forward(self, x):


class ConvTranspose1dSingle(torch.nn.Module):
    def __init__(self, bias=True):
    def __init__(self, bias=True, dilation=1):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose1d(
            in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, bias=bias
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            stride=2,
            padding=1,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        return self.conv_transpose(x)


class ConvTranspose2dSingle(torch.nn.Module):
    def __init__(self, bias=True):
    def __init__(self, bias=True, dilation=1):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            stride=2,
            padding=1,
            dilation=dilation,
            bias=bias,
        )

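One possible way the updated test models get exercised (illustrative only; the import path is inferred from the file location, and the sample shape is an assumption rather than the PR's actual test input):

import torch

from executorch.backends.qualcomm.tests.models import ConvTranspose1dSingle

module = ConvTranspose1dSingle(bias=True, dilation=2)
sample_input = (torch.randn(1, 1, 10),)  # (batch, in_channels=1, length)
out = module(*sample_input)
print(out.shape)  # torch.Size([1, 3, 21]) with kernel_size=3, stride=2, padding=1, dilation=2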