From b60972b14eabe82b4ca34fc782df2fb6e07d2e0d Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Tue, 26 Nov 2024 10:23:10 +0000 Subject: [PATCH] Quantize compatible node + activation patterns as one block Annotate conv1d/conv2d/linear followed by relu/relu6 patterns as one block and fuse the activation into its parent. The activation will then be implicitly done in the tosa.rescale node that will have a -128 zero-point. Change-Id: I5bf1e2c91be21ab842012fbc20d159af7fe2222d --- backends/arm/_passes/arm_pass_manager.py | 4 ++ .../_passes/fuse_quantized_activation_pass.py | 60 +++++++++++++++++ .../arm/quantizer/quantization_annotator.py | 66 ++++++++++++++++++- backends/arm/test/ops/test_conv_combos.py | 7 +- 4 files changed, 133 insertions(+), 4 deletions(-) create mode 100644 backends/arm/_passes/fuse_quantized_activation_pass.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 6d747d8129c..b52e0879f92 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -37,6 +37,9 @@ QuantizeFullArgument, RetraceFoldedDtypesPass, ) +from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( + FuseQuantizedActivationPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, @@ -72,6 +75,7 @@ def transform_to_backend_pipeline( self, exported_program: ExportedProgram, compile_spec: list[CompileSpec] ): """Apply passes before transforming program to backend""" + self.add_pass(FuseQuantizedActivationPass()) self.add_pass(DecomposeLinearPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(DecomposeLayerNormPass()) diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py new file mode 100644 index 00000000000..86836842bb1 --- /dev/null +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -0,0 +1,60 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa_quant_utils import q_op +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import Node + + +class FuseQuantizedActivationPass(ExportPass): + def _is_fuseable_quantized_activation(self, node: Node): + """Fuse activations that have a 0 lower bound and quantized with a qmin zero-point""" + is_fuseable = node.target == exir_ops.edge.aten.relu.default + if node.target == exir_ops.edge.aten.hardtanh.default: + min_val = node.args[1] + is_fuseable = min_val == 0 + + is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op + if is_quantized: + quant_node = next(iter(node.users)) + zp = quant_node.args[2] + qmin = quant_node.args[3] + + return is_fuseable and is_quantized and zp == qmin + + def _is_fuseable_input(self, node: Node): + return ( + node.target + in ( + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.linear.default, + ) + and len(node.users) == 1 + ) + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if not self._is_fuseable_quantized_activation(node): + continue + + input_node = node.args[0] + if not self._is_fuseable_input(input_node): + continue + + node.replace_all_uses_with(input_node) + graph_module.graph.erase_node(node) + modified = True + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 0b29570f36f..9ddeb61c30a 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -89,6 +89,41 @@ def _annotate_output(node: Node, quant_property: _QuantProperty): _annotate_output_qspec(node, quant_property.qspec) +def _match_pattern( + node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None +) -> bool: + """ + Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the + chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the + chain pass the filtering. + + Each 'pattern' element is composed of a list of disjunctive nodes types. + """ + assert len(pattern) == 2, "Only two-nodes patterns supported currently" + + if node.target in pattern[0]: + assert len(node.users) != 0 + parent = node + child = next(iter(node.users)) + elif node.target in pattern[1]: + assert len(node.args) != 0 + parent = node.args[0] + child = node + else: + return False + + if len(parent.users) != 1: + return False + + if parent.target not in pattern[0] or child.target not in pattern[1]: + return False + + if filter_fn is not None: + return filter_fn(parent) and filter_fn(child) + + return True + + _one_to_one = [ torch.ops.aten.exp.default, torch.ops.aten.log.default, @@ -164,7 +199,36 @@ def get_quant_properties( # noqa: C901 bias_qspec = quantization_config.get_bias_qspec() quant_properties = _OpQuantProperties() - if node.target in ( + + def any_or_hardtanh_min_zero(n: Node): + # Check that if the node is a hardtanh, its min_val is zero + return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0 + + if _match_pattern( + node, + [ + [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ], + [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], + ], + any_or_hardtanh_min_zero, + ): + if node.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ): + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), + ] + else: + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in ( torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 17083b129f0..4a5615f97c6 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -137,10 +137,11 @@ class ComboConvRelu6(torch.nn.Module): ] test_data = [ - (20 * torch.randn(1, 3, 256, 256),), - (5 * torch.randn(1, 3, 256, 256),), + (2 * torch.randn(1, 3, 256, 256),), + (0.5 * torch.randn(1, 3, 256, 256),), (torch.randn(1, 3, 256, 256),), - (-5 * torch.randn(1, 3, 256, 256),), + (-0.5 * torch.randn(1, 3, 256, 256),), + (-2 * torch.randn(1, 3, 256, 256),), ] def __init__(self):