From 65dc3d07e798a48bccb7c55ceaa13645f832c2f0 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 1 Dec 2025 14:42:45 +0100 Subject: [PATCH] Arm backend: Move FuseConstantArgsPass Move FuseConstantArgsPass to run after ComputeConstantOpsAOTPass to make sure that ops without input nodes are fused first. Signed-off-by: Oscar Andersson Change-Id: I3a671352fe8eb6a44d946f994bb097f9b2fe8638 --- backends/arm/_passes/arm_pass_manager.py | 2 +- backends/arm/_passes/fuse_constant_ops_pass.py | 5 ++++- backends/arm/test/models/test_torch_functions.py | 2 -- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index a8ca9eb1544..88b7d42fcda 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -243,7 +243,6 @@ def _tosa_pipeline( # passes. Ticket: MLETORCH-1540 DecomposeNotEqualPass(), MatchArgRanksPass(exported_program), - FuseConstantArgsPass(exported_program), ] ) @@ -265,6 +264,7 @@ def _tosa_pipeline( DecomposeAvgPool2dPass(), DecorateFp32toInt32CastingPass(), ComputeConstantOpsAOTPass(exported_program), + FuseConstantArgsPass(exported_program), ConvertExpandCopyToRepeatPass(), UnsqueezeBeforeRepeatPass(), DecomposeCumsumPass(exported_program), diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index a574ef554ad..0ca3dc38f75 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -178,7 +178,10 @@ def f(node_name_pre_computed): return node_name_pre_computed """ - _passes_required_after: Set[Type[ExportPass]] = {FuseEqualPlaceholdersPass} + _passes_required_after: Set[Type[ExportPass]] = { + FuseEqualPlaceholdersPass, + FuseConstantArgsPass, + } targeted_ops = [ exir_ops.edge.aten.full.default, diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 54a9a6ae676..3632a9dd141 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -126,8 +126,6 @@ def test_torch_fns_FP(test_data): xfails={ "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). " "Requires dynamic output shape.", - "eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. " - "Is the original torch function supported?", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", },