From 5d31f61da7f05e035db769fa61935cb77e830c6c Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Wed, 1 Oct 2025 21:16:01 -0700 Subject: [PATCH] Fix const prop pass when a const prop tensor has zero stride, make it contiguous (#14725) Summary: Found out that xnnpack lowering fails due to rejecting a zero stride tensor by executorch during to_backend phase. The fix is to identify such tensors and make them contiguous. For context: https://fb.workplace.com/groups/764610762549390/permalink/1868907967337731/ Differential Revision: D83631522 --- exir/passes/constant_prop_pass.py | 8 ++++ exir/tests/test_passes.py | 73 ++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py index 7daa3a247e8..06c1c78ee21 100644 --- a/exir/passes/constant_prop_pass.py +++ b/exir/passes/constant_prop_pass.py @@ -164,6 +164,14 @@ def get_propagated_const_tensor_dict( with torch.no_grad(): # Execute the `node.target` and create a new propagated constant tensor. prop_constant_tensor = node.target(*args_data, **kwargs_data) + + # ExecuTorch doesn't support zero strides, so we need to ensure the tensor is contiguous + # if it has any zero strides from broadcasting/expansion operations + if ( + isinstance(prop_constant_tensor, torch.Tensor) + and 0 in prop_constant_tensor.stride() + ): + prop_constant_tensor = prop_constant_tensor.contiguous() const_node_to_tensor[node] = prop_constant_tensor return const_node_to_tensor diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 716b808b087..14f105e8205 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -24,7 +24,17 @@ from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( QuantizationConfig, ) -from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge +from executorch.backends.xnnpack.utils.configs import ( + get_xnnpack_executorch_backend_config, +) + +from executorch.exir import ( + EdgeCompileConfig, + EdgeProgramManager, + memory, + to_edge, + to_edge_transform_and_lower, +) from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.emit import emit_program @@ -2022,3 +2032,64 @@ def forward(self, x): pass_result = constant_prop_pass(edge.exported_program()) # 1 constant: a (= self.w @ self.cst) self.assertEqual(1, len(pass_result.constants)) + + def test_constant_prop_pass_zero_stride_tensors(self) -> None: + """ + Test that constant propagation correctly handles tensors with zero strides + by converting them to contiguous tensors. Zero-stride tensors can be created + by operations like expand() and are not supported by ExecuTorch. + """ + + class ZeroStrideModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.const_param = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0])) + + def forward(self, x): + unsqueezed = self.const_param.unsqueeze( + 1 + ) # Shape: (3, 1), strides: (1, 1) + # expand creates zero-stride tensor + expanded = unsqueezed.expand(3, 5) # Shape: (3, 5), strides: (1, 0) + + # Use the expanded tensor with the input to prevent elimination + result = x + expanded.sum() + return result + + model = ZeroStrideModel() + x = torch.randn(3, 5) + exported = torch.export.export(model, (x,)) + + # Before constant prop: verify we have the parameter + self.assertIn("const_param", exported.state_dict) + + const_prop_result = constant_prop_pass(exported) + lowered = to_edge_transform_and_lower( + const_prop_result, + partitioner=[XnnpackPartitioner()], + ) + + # Should go through + lowered.to_executorch(get_xnnpack_executorch_backend_config([SpecPropPass()])) + self.assertGreater(len(const_prop_result.constants), 0) + + # Find the propagated constant tensor + prop_tensor = None + for constant_name, constant_tensor in const_prop_result.constants.items(): + if constant_name.startswith("_prop_tensor_constant"): + prop_tensor = constant_tensor + break + + # Verify the propagated tensor exists and has no zero strides + self.assertIsNotNone(prop_tensor) + self.assertNotIn( + 0, + prop_tensor.stride(), + f"Propagated tensor still has zero stride: {prop_tensor.stride()}", + ) + + # Verify the tensor is contiguous + self.assertTrue( + prop_tensor.is_contiguous(), + f"Propagated tensor is not contiguous: {prop_tensor.stride()}", + )