From ca01b3e6386156d376102e3cb8a3bab65f726e90 Mon Sep 17 00:00:00 2001 From: Per Held Date: Fri, 17 Apr 2026 13:09:49 +0200 Subject: [PATCH] Arm backend: Fix quantized constant-folding for aten.cat lists (#18971) FuseConstantArgsPass resolved input_qparams by flattened input-node index, while FoldAndAnnotateQParamsPass stores them by top-level argument index. For aten.cat with a list-valued tensor argument, this caused only the first tensor to be dequantized before folding, which corrupted the fused constant. Resolve qparams by top-level argument index and propagate that qparam through nested list and tuple arguments. Add a regression test for quantized aten.cat constant folding with list-valued tensor inputs. Signed-off-by: Per Held Change-Id: I6e1a012d82a5dbeecb403c440a2944953dd5cba7 --- .../arm/_passes/fuse_constant_ops_pass.py | 19 ++-- .../passes/test_fuse_constant_ops_pass.py | 96 +++++++++++++++++++ 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 6fd9b145988..d6fd4b18b53 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -83,21 +83,24 @@ def _fuse_nodes(self, node) -> bool: input_nodes = list(node.all_input_nodes) qparams = node.meta.get("input_qparams", None) - def resolve_arg(arg): + def resolve_arg(arg, arg_index=None): + qparam = ( + qparams.get(arg_index) if qparams and arg_index is not None else None + ) if isinstance(arg, torch.fx.Node) and arg in input_nodes: - idx = input_nodes.index(arg) t = get_param_tensor(self.exported_program, arg) - # Check if qparams exist for this arg - if qparams and idx in qparams.keys(): - t = qparams[idx].dequantize_value(t) + if qparam is not None: + t = qparam.dequantize_value(t) return t if isinstance(arg, tuple): - return tuple(resolve_arg(x) for x in arg) + return tuple(resolve_arg(x, arg_index) for x in arg) if isinstance(arg, list): - return [resolve_arg(x) for x in arg] + return [resolve_arg(x, arg_index) for x in arg] return arg - new_args = tuple(resolve_arg(a) for a in node.args) + new_args = tuple( + resolve_arg(arg, arg_index) for arg_index, arg in enumerate(node.args) + ) new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()} data = node.target(*new_args, **new_kwargs) diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 0f281dba24b..785744c1b37 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -11,8 +11,12 @@ ComputeConstantOpsAOTPass, FuseConstantArgsPass, ) +from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.test.harness.stages import StageType input_t = Tuple[torch.Tensor] # Input x input_t2 = Tuple[torch.Tensor, torch.Tensor] @@ -116,6 +120,52 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.cat((a, b), dim=0) +class QuantizedCatConstantBuffers(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer( + "horizontal_ramp", + torch.tensor( + [ + [ + [ + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + ] + ] + ], + dtype=torch.int8, + ), + ) + self.register_buffer( + "vertical_ramp", + torch.tensor( + [ + [ + [ + [-95, -95, -95, -95, -95], + [-32, -32, -32, -32, -32], + [32, 32, 32, 32, 32], + [95, 95, 95, 95, 95], + ] + ] + ], + dtype=torch.int8, + ), + ) + + def forward(self) -> torch.Tensor: + return torch.cat( + ( + cast(torch.Tensor, self.horizontal_ramp), + cast(torch.Tensor, self.vertical_ramp), + ), + dim=1, + ) + + modules: Dict[str, ModuleWithFuseAttrs] = { "fuse_parameter": cast(ModuleWithFuseAttrs, FuseParameter()), "fuse_buffer": cast(ModuleWithFuseAttrs, FuseBuffer()), @@ -174,3 +224,49 @@ def test_fuse_constant_args_tosa_INT_cat(module: ModuleWithFuseAttrs) -> None: ], ) pipeline.run() + + +def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None: + qargs = QuantArgs( + scale=1.0 / 127.0, + zp=0, + qmin=-127, + qmax=127, + dtype=torch.int8, + ) + module = QuantizedCatConstantBuffers() + compile_spec = common.get_tosa_compile_spec( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ) + tester = ArmTester(module, example_inputs=(), compile_spec=compile_spec) + tester.export().to_edge() + exported_program = tester.get_artifact(StageType.TO_EDGE).exported_program() + + cat_node = next( + node + for node in exported_program.graph_module.graph.nodes + if node.op == "call_function" + ) + cat_node.meta["input_qparams"] = {0: qargs} + cat_node.meta["output_qparams"] = {0: qargs} + + pass_result = FuseConstantArgsPass(exported_program).call( + exported_program.graph_module + ) + + assert list(exported_program.state_dict) == ["aten_cat_default_fused_const"] + torch.testing.assert_close( + exported_program.state_dict["aten_cat_default_fused_const"], + torch.cat( + ( + cast(torch.Tensor, module.horizontal_ramp), + cast(torch.Tensor, module.vertical_ramp), + ), + dim=1, + ), + ) + assert [ + node.name + for node in pass_result.graph_module.graph.nodes + if node.op == "placeholder" + ] == ["aten_cat_default_fused_const"]