
Commit 12d17ef

Cortex-M backend: Fuse Relu, Hardtanh and Hardsigmoid (#15917)
Implements a new pass which fuses activation ops with preceding Cortex-M ops where possible. Removes quantization of conv1d and conv3d as they are not tested, and moves the Conv+relu test to test_activations. Propagates qmin, qmax to the conv kernel. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 3f92668 commit 12d17ef
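
For orientation, below is a minimal sketch of the kind of eager-mode pattern this commit targets; the module, names, and shapes are illustrative assumptions, not part of the commit. After quantization and export it lowers to a q -> conv2d -> relu -> dq chain, which the new pass folds into the preceding Cortex-M conv.

import torch


# Hypothetical example: conv2d followed by relu, the pattern that
# ActivationFusionPass fuses after quantization (shapes are illustrative).
class ConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))


model = ConvRelu().eval()
example_inputs = (torch.randn(1, 3, 32, 32),)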

File tree

8 files changed: +607 -42 lines changed


backends/cortex_m/passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from .activation_fusion_pass import ActivationFusionPass  # noqa
 from .convert_to_cortex_m_pass import ConvertToCortexMPass  # noqa
 from .quantized_op_fusion_pass import QuantizedOpFusionPass  # noqa
 from .replace_quant_nodes_pass import ReplaceQuantNodesPass  # noqa
backends/cortex_m/passes/activation_fusion_pass.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging

import executorch.backends.cortex_m.ops.operators  # noqa: F401
from executorch.backends.arm._passes.quant_args import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_manager import PassResult

logger = logging.getLogger(__name__)


class ActivationFusionPass(ExportPass):
    """Fuse activations into preceding Cortex-M quantized operators.

    Supported activation patterns:
    q -> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq

    Fusing works by clamping the quantized output range (and zero-point when
    required) of the preceding Cortex-M operator, then removing the activation
    node from the graph.
    """

    TARGETS = {
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.hardsigmoid.default,
    }

    FUSE_OPS = {
        exir_ops.edge.aten.linear.default,
        exir_ops.edge.aten.convolution.default,
    }

    def _quantize(self, val, scale, zp, qmin, qmax):
        return min(max(round(val / scale + zp), qmin), qmax)

    def _get_validated_qparams(self, node, input_node):

        if "input_qparams" not in input_node.meta or "output_qparams" not in node.meta:
            logger.warning(
                f"Cannot fuse activation for {input_node.name}->{node.name} as the pattern wasn't quantized properly."
            )
            return None

        qparams_dict = node.meta["output_qparams"][0]._asdict()
        zp = qparams_dict["zp"]
        scale = qparams_dict["scale"]
        qmin = qparams_dict["qmin"]
        qmax = qparams_dict["qmax"]

        if not isinstance(scale, float) or not isinstance(zp, int):
            logger.warning(
                f"Cannot fuse activation {node.name} as quantization parameters are not per tensor."
            )
            return None

        match node.target:
            case exir_ops.edge.aten.relu.default:
                quantized_min_val = self._quantize(0, scale, zp, qmin, qmax)
                quantized_max_val = qmax
            case exir_ops.edge.aten.hardtanh.default:
                quantized_min_val = self._quantize(node.args[1], scale, zp, qmin, qmax)
                quantized_max_val = self._quantize(node.args[2], scale, zp, qmin, qmax)
            case exir_ops.edge.aten.hardsigmoid.default:
                quantized_min_val = self._quantize(0, scale, zp, qmin, qmax)
                quantized_max_val = self._quantize(1, scale, zp, qmin, qmax)
            case _:
                raise RuntimeError(f"Unexpected target {node.target}.")

        # If the minimal quantized value is larger than qmin, the quantized range contains
        # invalid values [qmin, ..., quantized_min_val - 1], indicating bad quantization parameters.
        if qparams_dict["qmin"] != quantized_min_val:
            logger.warning(
                f"Cannot fuse activation {node.name} as qmin is out of range."
            )
            return None

        # If the maximal quantized value is smaller than qmax, the quantized range contains
        # invalid values [quantized_max_val + 1, ..., qmax], indicating bad quantization parameters.
        if quantized_max_val != qparams_dict["qmax"]:
            logger.warning(
                f"Cannot fuse activation {node.name} as qmax is out of range."
            )
            return None

        return qparams_dict

    def _update_qparams_hardsigmoid(self, quant_dict):
        """
        Updates quant_dict in place so that scale and zp match the hardsigmoid activation.

        The quantized output of the hardsigmoid is defined by
            Q(y) = clamp(round(y/scale + zp), qmin, qmax)
            y = clamp(x/6 + 1/2, 0, 1)
        where x is the output of the op being fused into, conv or linear.

        Q(y) can be rewritten as a function of only x:
            Q(y) = clamp(round(clamp(x/6 + 1/2, 0, 1)/scale + zp), qmin, qmax)
            Q(y) = clamp(round(clamp(x/(6*scale) + 1/(2*scale) + zp, zp, 1/scale + zp)), qmin, qmax)

        From the definition of the qparams, which map the output range [0, 1] to the
        quantized range [qmin, qmax], we have:
            zp = Q(0) <= qmin
            1/scale + zp = Q(1) >= qmax
        which makes the inner clamp redundant.

        Therefore, hardsigmoid is equivalent to a quantization with modified parameters
            new_scale := 6*scale
            new_zp := zp + 1/(2*scale) ~= zp + round(1/(2*scale))
        """

        new_scale = quant_dict["scale"] * 6

        new_zp = quant_dict["zp"] + round(1 / (2 * quant_dict["scale"]))
        clamped_new_zp = max(quant_dict["qmin"], min(quant_dict["qmax"], new_zp))

        quant_dict["scale"] = new_scale
        quant_dict["zp"] = clamped_new_zp

    def call(self, graph_module: GraphModule) -> PassResult:
        modified = False
        nodes_to_erase: list[Node] = []

        for node in list(graph_module.graph.nodes):
            if node.op != "call_function" or node.target not in self.TARGETS:
                continue

            input_node = node.args[0]
            if (
                input_node.op != "call_function"
                or input_node.target not in self.FUSE_OPS
            ):
                logger.warning(
                    f"Cannot fuse activation {node.name} as input node {input_node.name} is not a supported fused activation op."
                )
                continue
            if len(input_node.users.values()) > 1:
                logger.warning(
                    f"Cannot fuse activation {node.name} as input node {input_node.name} has multiple users."
                )
                continue

            if (qparams_dict := self._get_validated_qparams(node, input_node)) is None:
                continue

            if node.target == exir_ops.edge.aten.hardsigmoid.default:
                self._update_qparams_hardsigmoid(qparams_dict)

            input_node.meta["output_qparams"][0] = QuantArgs(**qparams_dict)

            node.replace_all_uses_with(input_node)
            nodes_to_erase.append(node)
            modified = True

        for node in nodes_to_erase:
            graph_module.graph.erase_node(node)

        if modified:
            graph_module.recompile()

        return PassResult(graph_module, modified)
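
To make the fusion arithmetic above concrete, here is a small standalone sketch of the same calculations using illustrative int8 quantization parameters; the numeric values are assumptions for the example, not taken from the commit.

# Standalone sketch of the fusion arithmetic with illustrative int8 qparams
# (the concrete numbers are assumptions for the example, not from the commit).
scale, zp, qmin, qmax = 1 / 255, -128, -128, 127


def quantize(val, scale, zp, qmin, qmax):
    # Mirrors ActivationFusionPass._quantize.
    return min(max(round(val / scale + zp), qmin), qmax)


# ReLU: the real output range starts at 0, so fusion is only valid when Q(0)
# already equals qmin, i.e. no quantized codes lie below the clamp point.
assert quantize(0, scale, zp, qmin, qmax) == qmin

# Hardsigmoid: y = clamp(x/6 + 1/2, 0, 1) is folded into new output qparams,
# following the derivation in _update_qparams_hardsigmoid.
new_scale = scale * 6                          # 6/255
new_zp = zp + round(1 / (2 * scale))           # -128 + 128 = 0
new_zp = max(qmin, min(qmax, new_zp))          # clamp into [qmin, qmax]
print(new_scale, new_zp)                       # ~0.0235, 0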

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 7 additions & 4 deletions
@@ -139,8 +139,11 @@ def _get_convolution_replacement(self, node) -> int:
         if not isinstance(weight_scales, list):
             weight_scales = [weight_scales] * weight.data.shape[0]

-        output_scale = node.meta["output_qparams"][0].scale
-        output_zero_point = node.meta["output_qparams"][0].zp
+        output_qparams = node.meta["output_qparams"][0]
+        output_scale = output_qparams.scale
+        output_zero_point = output_qparams.zp
+        output_qmin = output_qparams.qmin
+        output_qmax = output_qparams.qmax

         quantized_multipliers = []
         quantized_shifts = []
@@ -177,8 +180,8 @@ def _get_convolution_replacement(self, node) -> int:
             output_zero_point,
             torch.tensor(quantized_multipliers, dtype=torch.int32),
             torch.tensor(quantized_shifts, dtype=torch.int32),
-            -128,
-            127,
+            output_qmin,
+            output_qmax,
         )
         return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
     ScalarsToAttributePass,
 )
 from executorch.backends.cortex_m.passes import (
+    ActivationFusionPass,
     ConvertToCortexMPass,
     QuantizedOpFusionPass,
     ReplaceQuantNodesPass,
@@ -31,6 +32,7 @@ class CortexMPassManager(PassManager):
         ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
         QuantizedOpFusionPass,
+        ActivationFusionPass,
         ConvertToCortexMPass,
     ]

backends/cortex_m/quantizer/operator_configs.py

Lines changed: 11 additions & 2 deletions
@@ -24,12 +24,21 @@
 LINEAR_OP_PATTERNS = [
     [torch.ops.aten.linear.default],
     [torch.ops.aten.linear.default, torch.ops.aten.relu.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.relu_.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.hardtanh.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.hardtanh_.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid_.default],
 ]

 CONV_OP_PATTERNS = [
-    [torch.ops.aten.conv1d.default],
     [torch.ops.aten.conv2d.default],
-    [torch.ops.aten.conv3d.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh_.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid_.default],
 ]

 # ----------------- OPERATOR CONFIG PRESETS -----------------
