From 5b49d86f66eb5c56cdcca71dfa62c2ea778e2ffe Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 11 May 2026 10:19:39 +0200 Subject: [PATCH 1/2] [exir] Materialize alloc shapes in ToOutVarPass Fix a dynamic-shape lowering bug in exir. ConstraintBasedSymShapeEvalPass concretizes TensorSpec metadata, but ToOutVarPass was still building memory.alloc nodes from symbolic FakeTensor/tensor_meta shapes. That let symbolic dims leak into the generated ExecuTorch GraphModule and caused runtime failures when the lowered module was executed in Python. Build memory.alloc specs from concrete upper-bounded integer shapes instead. If an alloc shape is still not concretely bounded, raise a clear error. Add an EXIR regression test that exports a dynamic-shape model, runs ConstraintBasedSymShapeEvalPass + ToOutVarPass, and verifies that memory.alloc shapes are concrete integers. Signed-off-by: Oscar Andersson Change-Id: If9a7b4b9aad93c1d594f9f9178d33d7df944c5e6 --- exir/passes/__init__.py | 28 ++++++++++++++-------- exir/tests/test_passes.py | 49 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 9b1b8efe682..ede866549b2 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -62,6 +62,7 @@ from executorch.exir.passes.to_device_pass import ToDevicePass from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass +from executorch.exir.sym_util import eval_shape_upper_bound from torch import fx from torch._subclasses import FakeTensor from torch.fx.passes.infra.pass_base import PassBase, PassResult @@ -281,31 +282,38 @@ def make_alloc_node( Note: tensor_metadata is only used in the case of a Tensor subclass, since fakifying a tensor subclass is not supported right now """ + + def materialize_alloc_spec( + shape: Union[torch.Size, Tuple[int, ...], List[int]], + dtype: torch.dtype, + ) -> memory.AllocSpec: + concrete_shape = eval_shape_upper_bound(shape) + if any(not isinstance(dim, int) for dim in concrete_shape): + raise RuntimeError( + "Memory allocator node requires concrete upper-bounded dimensions. " + f"Got shape {shape} and evaluated upper bounds {concrete_shape}." + ) + return (tuple(concrete_shape), dtype) + if val is None: if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) - alloc_spec = (tensor_meta.shape, tensor_meta.dtype) + alloc_spec = materialize_alloc_spec(tensor_meta.shape, tensor_meta.dtype) else: raise InternalError( "Memory allocator node needs FakeTensor val or TensorMetadata to proceed" ) elif isinstance(val, FakeTensor): - alloc_spec = (val.shape, val.dtype) + alloc_spec = materialize_alloc_spec(val.shape, val.dtype) else: assert isinstance(val, list) or isinstance(val, tuple) assert isinstance(tensor_meta, list) or isinstance(tensor_meta, tuple) alloc_spec: List[memory.AllocSpec] = [] for v, t in zip(val, tensor_meta): if v is not None: - # pyre-fixme[6]: For 1st argument expected - # `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but - # got `Tuple[Size, dtype]`. - alloc_spec.append((v.shape, v.dtype)) + alloc_spec.append(materialize_alloc_spec(v.shape, v.dtype)) elif t is not None: - # pyre-fixme[6]: For 1st argument expected - # `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but - # got `Tuple[Size, dtype]`. - alloc_spec.append((t.shape, t.dtype)) + alloc_spec.append(materialize_alloc_spec(t.shape, t.dtype)) else: raise InternalError( "Memory allocator node needs FakeTensor val or TensorMetadata to proceed" diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 8a084ba491a..1316dffb828 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -74,6 +75,7 @@ ) from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass from executorch.exir.passes.spec_prop_pass import SpecPropPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass from executorch.exir.program._program import lift_constant_tensor_pass from executorch.exir.schema import TensorShapeDynamism @@ -1036,6 +1038,53 @@ def test_alloc_node_spec(self) -> None: for node in alloc_nodes: self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec)) + def test_to_out_var_dynamic_alloc_uses_concrete_upper_bounds(self) -> None: + class DynamicRelu(nn.Module): + def forward(self, x): + return torch.relu(x) + + eager_model = DynamicRelu() + inputs = (torch.randn(2, 4, 8, 3),) + dynamic_shapes = { + "x": { + 0: torch.export.Dim("batch", min=0, max=2), + 2: torch.export.Dim("height", min=0, max=8), + 3: torch.export.Dim("width", min=0, max=8), + } + } + prog = to_edge( + export( + eager_model, + inputs, + dynamic_shapes=dynamic_shapes, + strict=True, + ), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + new_prog = prog.transform( + [ + SpecPropPass(), + ConstraintBasedSymShapeEvalPass(), + ] + ) + + new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module) + self.assertIsNotNone(new_gm_res) + new_gm = new_gm_res.graph_module + + alloc_nodes = [] + for node in new_gm.graph.nodes: + if node.target == memory.alloc: + alloc_nodes.append(node) + + self.assertTrue(len(alloc_nodes) > 0) + for node in alloc_nodes: + alloc_spec = node.args[0] + self.assertIsInstance(alloc_spec, tuple) + shape, _dtype = alloc_spec + for dim in shape: + self.assertIsInstance(dim, int) + def test_debug_pass_file_log(self) -> None: eager_model = Mul() inputs = eager_model.get_random_inputs() From 2309f0397b7559622540de8f037dc6209e42dfe5 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 28 May 2026 09:17:46 +0200 Subject: [PATCH 2/2] Arm backend: Update torch_functions xfails Remove nonzero test case from xfails in test_torch_functions.py. Signed-off-by: Oscar Andersson Change-Id: I5768429c6e289e114c55a1f77822cc03a619b8ab --- backends/arm/test/models/test_torch_functions.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 0ca8d3ac091..c6a4c5580dc 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -97,8 +97,6 @@ def forward(self, *args): "test_data", test_parameters, xfails={ - "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). " - "Requires dynamic output shape.", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", }, @@ -124,8 +122,6 @@ def test_torch_functions_tosa_FP(test_data): "test_data", test_parameters, xfails={ - "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). " - "Requires dynamic output shape.", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", },