Skip to content

Commit c00d726

Browse files
NXP Backend: Add pass to remove unnecessary Quantize/Dequantize nodes. (#15148)
### Summary This PR adds an edge dialect pre-processing pass to remove some Q/DQ nodes. This enables some non-delegated nodes (which run on the CPU) to run in directly in int8 and avoid the QDQ compute overhead. This improves the inference speed (by eliminating the need to artificially quantize and de-quantize input and output values. ### Test plan Unit tests provided. cc @robert-kalmar
1 parent d9f2c46 commit c00d726

File tree

6 files changed

+250
-11
lines changed

6 files changed

+250
-11
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,14 @@ def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
125125
current_node = current_node.args[0]
126126
else:
127127
return current_node
128+
129+
130+
Scale = list[float] | float
131+
ZeroPoint = list[int] | int
132+
133+
134+
def get_quantization_parameters_for(node: Node) -> tuple[Scale, ZeroPoint] | None:
135+
if "quantize" not in node.target.__name__ or len(node.args) < 3:
136+
return None
137+
138+
return node.args[1], node.args[2] # Scale and zero_point
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import numpy as np
7+
import torch
8+
9+
from executorch.backends.nxp.backend.edge_helper import get_quantization_parameters_for
10+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
11+
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.fx.passes.infra.pass_base import PassResult
14+
15+
16+
class RemoveAdditionalQDQClustersPass(NeutronEdgePass):
17+
"""
18+
After delegation of partitions, there may be additional dequantize quantize nodes for QDQ clusters that were
19+
not delegated. If dequantize quantize nodes are quantized per tensor and quantization parameters of dequantize
20+
and quantize nodes in a QDQ cluster are equal, the nodes can be removed and thus the inner nodes computed in int8.
21+
22+
23+
┌────────────▼──────────┐
24+
│ dequantize_per_tensor │
25+
└────────────┬──────────┘
26+
│ │
27+
┌───▼──┐ replace with ┌───▼──┐
28+
│ node │ ──────────────► │ node │
29+
└───┬──┘ └───┬──┘
30+
│ ▼
31+
┌───────────▼─────────┐
32+
│ quantize_per_tensor │
33+
└───────────┬─────────┘
34+
35+
36+
"""
37+
38+
qdq_per_channel_nodes = (
39+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
40+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
41+
)
42+
43+
qdq_per_tensor_nodes = (
44+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
45+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
46+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
47+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
48+
)
49+
50+
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
51+
nodes = list(graph_module.graph.nodes)
52+
qdq_clusterer = QDQClusterRecognizer()
53+
qdq_clusterer.tag_qdq_clusters(nodes)
54+
55+
for cluster in qdq_clusterer.cluster_map.values():
56+
# For now, enable only permute_copy and cat.
57+
if cluster.compute_node.target not in [
58+
exir_ops.edge.aten.permute_copy.default,
59+
exir_ops.edge.aten.cat.default,
60+
]:
61+
continue
62+
63+
# Ensure cluster doesn't contain dequantize/quantize per channel nodes.
64+
if any(
65+
node
66+
for node in cluster.ops
67+
if node.target in self.qdq_per_channel_nodes
68+
):
69+
continue
70+
71+
qdq_nodes = [
72+
node for node in cluster.ops if node.target in self.qdq_per_tensor_nodes
73+
]
74+
75+
qdq_nodes_quant_params = [
76+
get_quantization_parameters_for(node) for node in qdq_nodes
77+
]
78+
79+
equal_quant_scales = [
80+
np.allclose(
81+
qdq_nodes_quant_params[idx][0], qdq_nodes_quant_params[idx + 1][0]
82+
)
83+
for idx in range(len(qdq_nodes_quant_params[:-1]))
84+
]
85+
86+
equal_quant_zero_points = [
87+
np.allclose(
88+
qdq_nodes_quant_params[idx][1], qdq_nodes_quant_params[idx + 1][1]
89+
)
90+
for idx in range(len(qdq_nodes_quant_params[:-1]))
91+
]
92+
93+
# Check if all quantization params are equal to ensure that QDQ cluster can be removed.
94+
if not all(equal_quant_scales + equal_quant_zero_points):
95+
continue
96+
97+
# Replace the uses of each dequantize/quantize node with its arg node.
98+
for qdq_node in qdq_nodes:
99+
qdq_node.replace_all_uses_with(qdq_node.args[0])
100+
graph_module.graph.erase_node(qdq_node)
101+
102+
# Remove compute node cluster info from node meta.
103+
cluster.compute_node.meta.pop("cluster")
104+
105+
graph_module = self.recompile_module(graph_module)
106+
107+
# The graph has now changed, and we cannot keep iterating through it. Return the new graph and the parent
108+
# class will call this pass again.
109+
return PassResult(graph_module, True)
110+
111+
return PassResult(graph_module, False)

backends/nxp/tests/executorch_pipeline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
1818
NeutronEdgePassManager,
1919
)
20+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
21+
RemoveAdditionalQDQClustersPass,
22+
)
2023
from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
2124
RemoveIOQuantOpsPass,
2225
)
@@ -35,7 +38,6 @@
3538
from torch.export import export
3639
from torchao.quantization.pt2e.quantizer import Quantizer
3740

38-
3941
neutron_converter_flavor = "SDK_25_09"
4042
neutron_target_spec = NeutronTargetSpec(
4143
target="imxrt700", neutron_converter_flavor=neutron_converter_flavor
@@ -64,7 +66,6 @@ def _get_default_quantizer(target_spec: NeutronTargetSpec) -> Quantizer:
6466
def to_model_input_spec(
6567
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
6668
) -> tuple[ModelInputSpec, ...]:
67-
6869
if isinstance(input_spec, tuple) and all(
6970
isinstance(spec, ModelInputSpec) for spec in input_spec
7071
):
@@ -139,6 +140,10 @@ def to_quantized_edge_program(
139140
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
140141
)
141142

143+
edge_program_manager = edge_program_manager.transform(
144+
NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
145+
)
146+
142147
return edge_program_manager
143148

144149

backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def forward(self, x):
104104
return torch.permute(x, self.perm)
105105

106106

107-
class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase):
107+
class TestPermuteCopyConversion(unittest.TestCase):
108108
@classmethod
109109
def setUpClass(cls):
110110
torch.manual_seed(23)
@@ -302,9 +302,9 @@ def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized(
302302
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
303303

304304
nodes = list(edge_program.graph.nodes)
305-
assert len(nodes) == 10
305+
assert len(nodes) == 8
306306
assert (
307-
nodes[6].target == exir_ops.edge.aten.permute_copy.default
307+
nodes[5].target == exir_ops.edge.aten.permute_copy.default
308308
) # PermuteCopy not delegated.
309309

310310
@parameterized.expand(
@@ -320,7 +320,7 @@ def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized(
320320
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
321321

322322
nodes = list(edge_program.graph.nodes)
323-
assert len(nodes) == 10
323+
assert len(nodes) == 8
324324
assert (
325-
nodes[6].target == exir_ops.edge.aten.permute_copy.default
325+
nodes[5].target == exir_ops.edge.aten.permute_copy.default
326326
) # PermuteCopy not delegated.

backends/nxp/tests/test_edge_passes.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,54 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import copy
67
import unittest
78

89
import kgb
910
import numpy as np
1011
import torch
1112

13+
from executorch.backends.nxp.backend.custom_delegation_options import (
14+
CustomDelegationOptions,
15+
)
1216
from executorch.backends.nxp.backend.edge_helper import _is_dequantize, _is_quantize
1317
from executorch.backends.nxp.backend.edge_program_converter import (
1418
EdgeProgramToIRConverter,
1519
)
1620
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import (
1721
ViewCopyConverter,
1822
)
23+
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
24+
NeutronEdgePassManager,
25+
)
26+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
27+
RemoveAdditionalQDQClustersPass,
28+
)
29+
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
30+
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
31+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
32+
from executorch.backends.nxp.quantizer.utils import post_training_quantize
1933
from executorch.backends.nxp.tests.executorch_pipeline import (
34+
get_random_calibration_inputs,
2035
neutron_target_spec,
36+
to_model_input_spec,
2137
to_quantized_edge_program,
2238
)
2339
from executorch.backends.nxp.tests.executors import (
40+
compare_output_arrays,
2441
EdgeProgramExecutor,
2542
OverrideTargetSupportCheck,
2643
)
44+
from executorch.backends.nxp.tests.ir.converter.node_converter.test_permute_copy_converter import (
45+
Conv2dPermuteModule,
46+
)
2747
from executorch.backends.nxp.tests.models import (
2848
ConvActivationModule,
2949
ConvFCFCSoftmaxModuleWithoutReshape,
3050
LinearActivationModule,
3151
)
3252
from executorch.exir.dialects._ops import ops as exir_ops
53+
from executorch.extension.export_util.utils import export_to_edge
3354
from parameterized import parameterized
3455
from torch.export import ExportedProgram
3556
from torch.fx import Graph, Node
@@ -117,7 +138,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__addmm(
117138
call_original=True,
118139
owner=EdgeProgramToIRConverter,
119140
) as converter_spy:
120-
121141
input_shape = (1, 4)
122142
model = LinearActivationModule(
123143
activation=activation,
@@ -161,7 +181,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__mm(
161181
call_original=True,
162182
owner=EdgeProgramToIRConverter,
163183
) as converter_spy:
164-
165184
input_shape = (1, 4)
166185
model = LinearActivationModule(
167186
activation=activation,
@@ -205,7 +224,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__linear(
205224
call_original=True,
206225
owner=EdgeProgramToIRConverter,
207226
) as converter_spy:
208-
209227
input_shape = (1, 4)
210228
model = LinearActivationModule(
211229
activation=activation,
@@ -249,7 +267,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__conv(
249267
call_original=True,
250268
owner=EdgeProgramToIRConverter,
251269
) as converter_spy:
252-
253270
input_shape = (1, 4, 8, 8)
254271
model = ConvActivationModule(
255272
activation=activation, inplace=True, in_channels=input_shape[1]
@@ -273,3 +290,91 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__conv(
273290
nodes[13]
274291
)
275292
assert _is_quantize(nodes[14])
293+
294+
def test_remove_additional_quantize_dequantize_nodes_pass(self):
295+
input_shape = (1, 3, 8, 16)
296+
new_dims = (3, 2, 1, 0)
297+
model = Conv2dPermuteModule(input_shape[1], new_dims)
298+
target = "imxrt700"
299+
custom_delegation_options = CustomDelegationOptions()
300+
301+
calibration_inputs = get_random_calibration_inputs(
302+
to_model_input_spec(input_shape)
303+
)
304+
305+
example_input = calibration_inputs[0]
306+
exir_program_aten = torch.export.export(model, example_input, strict=True)
307+
308+
exir_program_aten_quant = post_training_quantize(
309+
exir_program_aten,
310+
calibration_inputs,
311+
NeutronQuantizer(neutron_target_spec),
312+
)
313+
edge_program_manager = export_to_edge(
314+
exir_program_aten_quant,
315+
example_input,
316+
)
317+
318+
edge_program_manager = edge_program_manager.transform(NeutronEdgePassManager())
319+
320+
compile_spec = generate_neutron_compile_spec(target, "SDK_25_09")
321+
partitioner = NeutronPartitioner(
322+
compile_spec, neutron_target_spec, custom_delegation_options
323+
)
324+
325+
edge_program_manager = edge_program_manager.to_backend(partitioner)
326+
327+
# Make sure QDQ cluster for permute_copy is present.
328+
edge_program_with_qdq_cluster = copy.deepcopy(
329+
edge_program_manager.exported_program()
330+
)
331+
nodes = list(edge_program_with_qdq_cluster.graph.nodes)
332+
assert len(nodes) == 10
333+
assert (
334+
nodes[5].target
335+
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
336+
)
337+
assert nodes[6].target == exir_ops.edge.aten.permute_copy.default
338+
assert "cluster" in nodes[6].meta
339+
assert (
340+
nodes[7].target
341+
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
342+
)
343+
344+
# Run pass for removal of additional QDQ nodes and compute in non-float types where possible
345+
edge_program_manager = edge_program_manager.transform(
346+
NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
347+
)
348+
349+
# Make sure QDQ cluster for permute_copy is removed.
350+
edge_program_without_qdq_cluster = edge_program_manager.exported_program()
351+
nodes = list(edge_program_without_qdq_cluster.graph.nodes)
352+
assert len(nodes) == 8
353+
assert nodes[4].name == "getitem"
354+
assert nodes[5].target == exir_ops.edge.aten.permute_copy.default
355+
assert "cluster" not in nodes[5].meta
356+
assert (
357+
nodes[6].target
358+
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
359+
)
360+
361+
edge_program_executor_without_qdq_cluster = EdgeProgramExecutor(
362+
edge_program_without_qdq_cluster
363+
)
364+
edge_program_executor_with_qdq_cluster = EdgeProgramExecutor(
365+
edge_program_with_qdq_cluster
366+
)
367+
368+
input_data = np.random.random(input_shape).astype(np.float32)
369+
edge_program_output_without_qdq_cluster = (
370+
edge_program_executor_without_qdq_cluster.inference(input_data)
371+
)
372+
edge_program_output_with_qdq_cluster = (
373+
edge_program_executor_with_qdq_cluster.inference(input_data)
374+
)
375+
376+
compare_output_arrays(
377+
edge_program_output_without_qdq_cluster,
378+
edge_program_output_with_qdq_cluster,
379+
"main output",
380+
)

examples/nxp/aot_neutron_compile.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
1919
NeutronEdgePassManager,
2020
)
21+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
22+
RemoveAdditionalQDQClustersPass,
23+
)
2124
from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
2225
RemoveIOQuantOpsPass,
2326
)
@@ -258,6 +261,10 @@ def get_model_and_inputs_from_name(model_name: str):
258261
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
259262
)
260263

264+
edge_program_manager = edge_program_manager.transform(
265+
NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
266+
)
267+
261268
logging.debug(f"Lowered graph:\n{edge_program_manager.exported_program().graph}")
262269

263270
# 5. Export to ExecuTorch program

0 commit comments

Comments
 (0)