diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index eeb5084ef60..bbfb18b1851 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -8,8 +8,8 @@
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .annotate_stack import AnnotateStack
 from .annotate_unbind import AnnotateUnbind
+from .canonicalize_conv import CanonicalizeConv
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
-from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -47,8 +47,8 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    CanonicalizeConv,
     ConvertBmmToMatmul,
-    ConvertConv1dToConv2d,
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/canonicalize_conv.py
similarity index 67%
rename from backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
rename to backends/qualcomm/_passes/canonicalize_conv.py
index d09113ad42a..3804fb05da0 100644
--- a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
+++ b/backends/qualcomm/_passes/canonicalize_conv.py
@@ -4,32 +4,96 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast, Tuple
+
 import torch
+
 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
 from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch._guards import detect_fake_mode
 
 from .utils import append_qdq, copy_meta
 
 
-class ConvertConv1dToConv2d(ExportPass):
+class CanonicalizeConv(ExportPass):
     """
-    Conv1d is not supported by QNN.
-    Change it to input -> unsqueeze -> conv2d -> squeeze -> output
+    1. QNN does not support dilation on TransposeConvND.
+       Dilate the kernel manually for a math-equivalent operation.
+    2. Conv1d is not supported by QNN.
+       Change it to input -> unsqueeze -> conv2d -> squeeze -> output.
     """
 
     def __init__(self, edge_program: torch.export.ExportedProgram):
-        super(ConvertConv1dToConv2d, self).__init__()
+        super(CanonicalizeConv, self).__init__()
         self.edge_program = edge_program
-        self.conv_op_map = {
+        self.conv1d_op_map = {
             torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
             torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
         }
+        self.transpose_conv_set = {
+            torch.ops.aten.conv_transpose1d.default,
+            torch.ops.aten.conv_transpose2d.input,
+        }
+
+    def dilate(self, tensor, dilation):
+        # e.g. for a 3x3 kernel with dilation == (2, 3):
+        #              1, 0, 0, 2, 0, 0, 3
+        # 1, 2, 3      0, 0, 0, 0, 0, 0, 0
+        # 4, 5, 6  --> 4, 0, 0, 5, 0, 0, 6
+        # 7, 8, 9      0, 0, 0, 0, 0, 0, 0
+        #              7, 0, 0, 8, 0, 0, 9
+        i, o, *k = tensor.shape
+        new_k = [dim + (dim - 1) * (s - 1) for s, dim in zip(dilation, k)]
+        new_tensor = torch.zeros((i, o, *new_k), dtype=tensor.dtype)
+        indexing = (...,) + tuple([slice(None, None, d) for d in dilation])
+        new_tensor[indexing] = tensor
+        return new_tensor
 
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
+        # condition 1: dilated transposed convolution
+        for node in graph.nodes:
+            # arg order (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html)
+            # > input, weight, bias, stride, padding, output_padding, groups, dilation
+            if node.target in self.transpose_conv_set and len(node.args) > 7:
+                dilation = cast(Tuple[int], node.args[7])
+                # dilate kernel in advance
+                filter_arg = node.args[1]
+                filter_node = (
+                    # fp graph
+                    filter_arg
+                    if filter_arg.op == "placeholder"
+                    # qdq graph
+                    else node.args[1].args[0]
+                )
+                filter_tensor = self.dilate(
+                    get_parameter(filter_node, self.edge_program),
+                    dilation,
+                )
+                # update tensor meta for the kernel node
+                fake_mode = detect_fake_mode(filter_node.meta["val"])
+                converter = fake_mode.fake_tensor_converter
+                filter_node.meta["val"] = converter.from_real_tensor(
+                    fake_mode, filter_tensor
+                )
+                # update kernel
+                set_parameter(
+                    (
+                        torch.nn.Parameter(filter_tensor)
+                        if filter_tensor.dtype == torch.float
+                        else filter_tensor
+                    ),
+                    filter_node,
+                    self.edge_program,
+                )
+                # drop the dilation arg so the graph still runs on CPU
+                node.args = node.args[0:-1]
+
+        # condition 2: conv1d
         for node in graph.nodes:
-            if node.target in self.conv_op_map:
+            if node.target in self.conv1d_op_map:
                 input_node = node.args[0]
                 with graph_module.graph.inserting_after(input_node):
                     unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
@@ -108,7 +172,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 )
                 conv2d_node = graph.create_node(
                     "call_function",
-                    self.conv_op_map[node.target],
+                    self.conv1d_op_map[node.target],
                     conv_args,
                 )
                 conv2d_node.meta = copy_meta(
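The kernel-dilation rewrite in `CanonicalizeConv` is easy to sanity-check outside the pass. Below is a minimal sketch (not part of the patch) that reproduces `dilate` with plain `torch.nn.functional` calls and confirms the math equivalence:

```python
import torch
import torch.nn.functional as F


def dilate(weight, dilation):
    # zero-stuff the kernel the same way CanonicalizeConv.dilate does
    i, o, *k = weight.shape
    new_k = [dim + (dim - 1) * (s - 1) for s, dim in zip(dilation, k)]
    out = torch.zeros((i, o, *new_k), dtype=weight.dtype)
    out[(...,) + tuple(slice(None, None, s) for s in dilation)] = weight
    return out


x = torch.randn(1, 1, 33, 33)
w = torch.randn(1, 3, 3, 3)  # (in_channels, out_channels, kH, kW)
ref = F.conv_transpose2d(x, w, stride=2, padding=1, dilation=(2, 3))
got = F.conv_transpose2d(x, dilate(w, (2, 3)), stride=2, padding=1)
assert torch.allclose(ref, got, atol=1e-5)
```

The output shapes agree as well: `(L_in - 1) * stride - 2 * pad + dilation * (k - 1) + 1` is unchanged when the kernel is pre-dilated to size `dilation * (k - 1) + 1` and the dilation is reset to 1, which is why the pass can simply pop the dilation argument afterwards.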
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 461fb07fb16..ffb9f3221df 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -13,8 +13,8 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    CanonicalizeConv,
     ConvertBmmToMatmul,
-    ConvertConv1dToConv2d,
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -82,6 +82,7 @@ def get_capture_program_passes():
         (AnnotateQuantAttrs, True),
         (AnnotateStack, True),
         (AnnotateUnbind, True),
+        (CanonicalizeConv, True),
         (ConvertBmmToMatmul, False),
         (DecomposeAny, True),
         (DecomposeColIm, True),
@@ -215,7 +216,7 @@ def transform_for_export_pipeline(
         self.add_pass(DecomposeWrapWithAutocast())
         # this pass will rewrite state_dict, it needs to be accomplished before
         # to_edge_transform_and_lower
-        self.add_pass(ConvertConv1dToConv2d(exported_program))
+        self.add_pass(CanonicalizeConv(exported_program))
         if convert_linear_to_conv2d:
             self.add_pass(ConvertLinearToConv2d(exported_program))
         self.add_pass(ConvertSquareToPow())
diff --git a/backends/qualcomm/_passes/remove_redundancy.py b/backends/qualcomm/_passes/remove_redundancy.py
index 2ec8161613b..d5b7a2b0534 100644
--- a/backends/qualcomm/_passes/remove_redundancy.py
+++ b/backends/qualcomm/_passes/remove_redundancy.py
@@ -41,12 +41,7 @@ def __init__(self, quantization_capture=False):
         )
 
     def _dim_order_op_condition(self, node):
-        dim_order = node.kwargs.get("dim_order")
-        # skip if there contains layout hint
-        # e.g. (0, 2, 3, 1) != (0, 1, 2, 3)
-        if node.meta["val"].dtype != node.args[0].meta["val"].dtype:
-            return False
-        return dim_order != list(range(len(dim_order)))
+        return node.meta["val"].dtype == node.args[0].meta["val"].dtype
 
     def _to_copy_op_condition(self, node):
         return "memory_format" in node.kwargs
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index 20492495156..6d908707892 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -64,8 +64,8 @@ def get_passes_dependency_for_capture_program():
         AnnotateQuantAttrs,
         AnnotateStack,
         AnnotateUnbind,
+        CanonicalizeConv,
         ConvertBmmToMatmul,
-        ConvertConv1dToConv2d,
         DecomposeAny,
         DecomposeColIm,
         DecomposeLinalgVectorNorm,
@@ -99,7 +99,7 @@
         I64toI32: [RemoveRedundancy],
         LayoutTransform: [
             AnnotateQuantAttrs,
-            ConvertConv1dToConv2d,
+            CanonicalizeConv,
             ExpandBroadcastTensorShape,
             FixedLinearKeepDim,
         ],
diff --git a/backends/qualcomm/builders/op_amax.py b/backends/qualcomm/builders/op_amax.py
index 051355a8b6b..d0335f95463 100644
--- a/backends/qualcomm/builders/op_amax.py
+++ b/backends/qualcomm/builders/op_amax.py
@@ -40,14 +40,19 @@ def define_node(
         )
 
         # mean dims and keep dims
-        mean_dims = cast(List[int], node.args[1])
-        mean_dims = [
-            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
-        ]
-        if QCOM_AXIS_ORDER in node.meta:
+        if len(node.args) > 1:
+            mean_dims = cast(List[int], node.args[1])
             mean_dims = [
-                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
             ]
+            if QCOM_AXIS_ORDER in node.meta:
+                mean_dims = [
+                    node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                ]
+        else:
+            # reduce all dimensions
+            mean_dims = list(range(input_node.meta["val"].dim()))
+
         mean_dims_shape = [len(mean_dims)]
 
         output_tensor = self.get_tensor(node, node)
diff --git a/backends/qualcomm/builders/op_amin.py b/backends/qualcomm/builders/op_amin.py
index 9f8f17b4e37..142340dbae0 100644
--- a/backends/qualcomm/builders/op_amin.py
+++ b/backends/qualcomm/builders/op_amin.py
@@ -40,14 +40,19 @@ def define_node(
         )
 
         # mean dims and keep dims
-        mean_dims = cast(List[int], node.args[1])
-        mean_dims = [
-            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
-        ]
-        if QCOM_AXIS_ORDER in node.meta:
+        if len(node.args) > 1:
+            mean_dims = cast(List[int], node.args[1])
             mean_dims = [
-                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
             ]
+            if QCOM_AXIS_ORDER in node.meta:
+                mean_dims = [
+                    node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+                ]
+        else:
+            # reduce all dimensions
+            mean_dims = list(range(input_node.meta["val"].dim()))
+
         mean_dims_shape = [len(mean_dims)]
 
         output_tensor = self.get_tensor(node, node)
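Context for the `else` branch added to both reduce builders: `aten.amax`/`aten.amin` arrive with only one argument when the model reduces over every dimension, and `mean_dims = list(range(...))` mirrors that default. A quick illustration (mine, not from the patch):

```python
import torch

x = torch.randn(4, 4)
# with no dims given, amax/amin reduce over all dimensions,
# i.e. the same result as passing every axis explicitly
assert torch.equal(torch.amax(x), torch.amax(x, dim=(0, 1)))
assert torch.equal(torch.amin(x), torch.amin(x, dim=(0, 1)))
```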
diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py
index 0456cd53524..1cfc1e45c9b 100644
--- a/backends/qualcomm/builders/op_conv2d.py
+++ b/backends/qualcomm/builders/op_conv2d.py
@@ -109,7 +109,7 @@ def define_node(
         input_tensor = self.get_tensor(input_node, node)
         assert (
             input_tensor.dim() == 4
-        ), "All Conv should be converted to Conv2D in ConvertConv1dToConv2d"
+        ), "All Conv1D should be converted to Conv2D in CanonicalizeConv"
         input_tensor_wrapper = self.define_tensor(
             input_node,
             node,
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 78fb8d15a4e..cacf0684988 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -214,7 +214,7 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator([torch.ops.aten.amax.default])
 def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_binary(node, quantization_config)
+    annotate_single_in_single_out(node, quantization_config)
 
 
 @register_annotator([torch.ops.aten.argmax.default])
@@ -224,7 +224,7 @@ def annotate_argmax(node: Node, quantization_config: QuantizationConfig) -> None
 
 @register_annotator([torch.ops.aten.amin.default])
 def annotate_amin(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_binary(node, quantization_config)
+    annotate_single_in_single_out(node, quantization_config)
 
 
 @register_annotator([torch.ops.aten.argmin.default])
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 54624471763..275ed5b0374 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -589,10 +589,16 @@ def forward(self, x):
 
 
 class ConvTranspose1dSingle(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, bias=True, dilation=1):
         super().__init__()
         self.conv_transpose = torch.nn.ConvTranspose1d(
-            in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, bias=bias
+            in_channels=1,
+            out_channels=3,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            dilation=dilation,
+            bias=bias,
         )
 
     def forward(self, x):
@@ -600,7 +606,7 @@ def forward(self, x):
 
 
 class ConvTranspose2dSingle(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, bias=True, dilation=1):
         super().__init__()
         self.conv_transpose = torch.nn.ConvTranspose2d(
             in_channels=1,
@@ -608,6 +614,7 @@ def __init__(self, bias=True):
             kernel_size=3,
             stride=2,
             padding=1,
+            dilation=dilation,
             bias=bias,
         )
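The new `dilation` knob changes the output size of these test modules according to the standard transposed-convolution formula. A worked check (illustrative only, not part of the test suite):

```python
import torch

# L_out = (L_in - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1
m = torch.nn.ConvTranspose2d(
    in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, dilation=(2, 3)
)
y = m(torch.randn(1, 1, 33, 33))
# H: (33 - 1) * 2 - 2 + 2 * 2 + 1 = 67,  W: (33 - 1) * 2 - 2 + 3 * 2 + 1 = 69
assert y.shape == (1, 3, 67, 69)
```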
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index f5f8a1d904f..6e3ddffd802 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -128,7 +128,11 @@ def test_qnn_backend_alias(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_amax(self):
-        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
+        modules = [
+            AMax(dim=1, keepdim=False),  # noqa: F405
+            AMax(dim=1, keepdim=True),  # noqa: F405
+            AMax(),  # noqa: F405
+        ]
         sample_input = (torch.randn(4, 4),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -142,7 +146,11 @@ def test_qnn_backend_amax_conv(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_amin(self):
-        modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
+        modules = [
+            AMin(dim=1, keepdim=False),  # noqa: F405
+            AMin(dim=1, keepdim=True),  # noqa: F405
+            AMin(),  # noqa: F405
+        ]
         sample_input = (torch.randn(4, 4),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -278,8 +286,9 @@ def test_qnn_backend_conv_transpose1d(self):
         modules = [
             ConvTranspose1dSingle(),  # noqa: F405
             ConvTranspose1dSingle(bias=False),  # noqa: F405
+            ConvTranspose1dSingle(dilation=2),  # noqa: F405
         ]
-        sample_input = (torch.randn([1, 1, 3]),)
+        sample_input = (torch.randn([1, 1, 33]),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
@@ -288,8 +297,11 @@ def test_qnn_backend_conv_transpose2d(self):
         modules = [
             ConvTranspose2dSingle(),  # noqa: F405
             ConvTranspose2dSingle(bias=False),  # noqa: F405
+            ConvTranspose2dSingle(dilation=2),  # noqa: F405
+            ConvTranspose2dSingle(dilation=(2, 3)),  # noqa: F405
+            ConvTranspose2dSingle(dilation=(2, 1)),  # noqa: F405
         ]
-        sample_input = (torch.randn([1, 1, 3, 3]),)
+        sample_input = (torch.randn([1, 1, 33, 33]),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
@@ -1544,7 +1556,11 @@ def test_qnn_backend_alias(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_amax(self):
-        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
+        modules = [
+            AMax(dim=1, keepdim=False),  # noqa: F405
+            AMax(dim=1, keepdim=True),  # noqa: F405
+            AMax(),  # noqa: F405
+        ]
         sample_input = (torch.randn(4, 4),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -1560,7 +1576,11 @@ def test_qnn_backend_amax_conv(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_amin(self):
-        modules = [AMin(dim=1, keepdim=False), AMin(dim=1, keepdim=True)]  # noqa: F405
+        modules = [
+            AMin(dim=1, keepdim=False),  # noqa: F405
+            AMin(dim=1, keepdim=True),  # noqa: F405
+            AMin(),  # noqa: F405
+        ]
         sample_input = (torch.randn(4, 4),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -1741,6 +1761,7 @@ def test_qnn_backend_conv_transpose1d(self):
         modules = [
             ConvTranspose1dSingle(),  # noqa: F405
             ConvTranspose1dSingle(bias=False),  # noqa: F405
+            ConvTranspose1dSingle(dilation=2),  # noqa: F405
         ]
         sample_input = (torch.randn([1, 1, 3]),)
         for i, module in enumerate(modules):
@@ -1752,7 +1773,9 @@ def test_qnn_backend_conv_transpose2d(self):
         modules = [
             ConvTranspose2dSingle(),  # noqa: F405
             ConvTranspose2dSingle(bias=False),  # noqa: F405
-        ]  # noqa: F405
+            ConvTranspose2dSingle(dilation=(2, 3)),  # noqa: F405
+            ConvTranspose2dSingle(dilation=(2, 1)),  # noqa: F405
+        ]
         sample_input = (torch.randn([1, 1, 3, 3]),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -5380,7 +5403,7 @@ def test_efficientnet(self):
             self.skipTest("missing required envs")
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientnet.py"
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientnet.py",
            "--dataset",
             self.image_dataset,
             "--artifact",
diff --git a/docs/source/backends-qualcomm.md b/docs/source/backends-qualcomm.md
index fb7e9c40931..f427c7c7cea 100644
--- a/docs/source/backends-qualcomm.md
+++ b/docs/source/backends-qualcomm.md
@@ -431,11 +431,12 @@ For practical examples, see [`test_qnn_delegate.py`](https://github.com/pytorch/
 #### Step 3: Configure Compile Specs
 During this step, you will need to specify the target SoC, data type, and other QNN compiler spec.
 ```python
-from executorch.backends.qualcomm.compiler import (
+from executorch.backends.qualcomm.utils.utils import (
     generate_qnn_executorch_compiler_spec,
     generate_htp_compiler_spec,
+    QcomChipset,
+    to_edge_transform_and_lower_to_qnn,
 )
-from executorch.backends.qualcomm.utils.utils import QcomChipset
 
 # HTP Compiler Configuration
 backend_options = generate_htp_compiler_spec(
@@ -450,11 +451,6 @@ compile_spec = generate_qnn_executorch_compiler_spec(
 ```
 #### Step 4: Lower and Export the Model
 ```python
-from executorch.backends.qualcomm.partition.qnn_partitioner import (
-    to_edge_transform_and_lower_to_qnn,
-)
-from executorch.exir import ExecutorchBackendConfig
-
 # Lower to QNN backend
 delegated_program = to_edge_transform_and_lower_to_qnn(
     quantized_model if quantized else model,
@@ -463,9 +459,7 @@ delegated_program = to_edge_transform_and_lower_to_qnn(
 )
 
 # Export to ExecuTorch format
-executorch_program = delegated_program.to_executorch(
-    config=ExecutorchBackendConfig(extract_delegate_segments=False)
-)
+executorch_program = delegated_program.to_executorch()
 
 # Save the compiled model
 model_name = "custom_model_qnn.pte"
diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py
index 1dbff982352..9b5ace12327 100644
--- a/examples/qualcomm/scripts/export_example.py
+++ b/examples/qualcomm/scripts/export_example.py
@@ -11,7 +11,6 @@
 )
 from executorch.examples.models import MODEL_NAME_TO_MODEL
 from executorch.examples.models.model_factory import EagerModelFactory
-from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.extension.export_util.utils import save_pte_program
 
 from torchao.quantization.pt2e.quantize_pt2e import (
@@ -108,9 +107,7 @@ def main() -> None:
         m, example_inputs, compile_spec, generate_etrecord=args.generate_etrecord
     )
 
-    executorch_program = delegated_program.to_executorch(
-        config=ExecutorchBackendConfig(extract_delegate_segments=False)
-    )
+    executorch_program = delegated_program.to_executorch()
 
     if args.generate_etrecord:
         etrecord_path = args.output_folder + "etrecord.bin"
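Putting the documentation and example changes together, the flow now needs a single import site and the default `to_executorch()`. A condensed sketch with a placeholder model; `use_fp16` and the `SM8650` target are illustrative choices, not taken from this patch:

```python
import torch

from executorch.backends.qualcomm.utils.utils import (
    generate_qnn_executorch_compiler_spec,
    generate_htp_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)

# placeholder module exercising the new dilated transposed-conv support
model = torch.nn.ConvTranspose2d(1, 3, 3, stride=2, padding=1, dilation=2).eval()
example_inputs = (torch.randn(1, 1, 33, 33),)

backend_options = generate_htp_compiler_spec(use_fp16=True)
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # illustrative target
    backend_options=backend_options,
)
delegated_program = to_edge_transform_and_lower_to_qnn(
    model, example_inputs, compile_spec
)
executorch_program = delegated_program.to_executorch()  # defaults are now enough

with open("custom_model_qnn.pte", "wb") as f:
    executorch_program.write_to_file(f)
```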