Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions exir/passes/constant_prop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def get_propagated_const_tensor_dict(
with torch.no_grad():
# Execute the `node.target` and create a new propagated constant tensor.
prop_constant_tensor = node.target(*args_data, **kwargs_data)

# ExecuTorch doesn't support zero strides, so we need to ensure the tensor is contiguous
# if it has any zero strides from broadcasting/expansion operations
if (
isinstance(prop_constant_tensor, torch.Tensor)
and 0 in prop_constant_tensor.stride()
):
prop_constant_tensor = prop_constant_tensor.contiguous()
const_node_to_tensor[node] = prop_constant_tensor

return const_node_to_tensor
Expand Down
73 changes: 72 additions & 1 deletion exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,17 @@
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
QuantizationConfig,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
from executorch.backends.xnnpack.utils.configs import (
get_xnnpack_executorch_backend_config,
)

from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
memory,
to_edge,
to_edge_transform_and_lower,
)
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.emit import emit_program
Expand Down Expand Up @@ -2022,3 +2032,64 @@ def forward(self, x):
pass_result = constant_prop_pass(edge.exported_program())
# 1 constant: a (= self.w @ self.cst)
self.assertEqual(1, len(pass_result.constants))

def test_constant_prop_pass_zero_stride_tensors(self) -> None:
"""
Test that constant propagation correctly handles tensors with zero strides
by converting them to contiguous tensors. Zero-stride tensors can be created
by operations like expand() and are not supported by ExecuTorch.
"""

class ZeroStrideModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.const_param = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))

def forward(self, x):
unsqueezed = self.const_param.unsqueeze(
1
) # Shape: (3, 1), strides: (1, 1)
# expand creates zero-stride tensor
expanded = unsqueezed.expand(3, 5) # Shape: (3, 5), strides: (1, 0)

# Use the expanded tensor with the input to prevent elimination
result = x + expanded.sum()
return result

model = ZeroStrideModel()
x = torch.randn(3, 5)
exported = torch.export.export(model, (x,))

# Before constant prop: verify we have the parameter
self.assertIn("const_param", exported.state_dict)

const_prop_result = constant_prop_pass(exported)
lowered = to_edge_transform_and_lower(
const_prop_result,
partitioner=[XnnpackPartitioner()],
)

# Should go through
lowered.to_executorch(get_xnnpack_executorch_backend_config([SpecPropPass()]))
self.assertGreater(len(const_prop_result.constants), 0)

# Find the propagated constant tensor
prop_tensor = None
for constant_name, constant_tensor in const_prop_result.constants.items():
if constant_name.startswith("_prop_tensor_constant"):
prop_tensor = constant_tensor
break

# Verify the propagated tensor exists and has no zero strides
self.assertIsNotNone(prop_tensor)
self.assertNotIn(
0,
prop_tensor.stride(),
f"Propagated tensor still has zero stride: {prop_tensor.stride()}",
)

# Verify the tensor is contiguous
self.assertTrue(
prop_tensor.is_contiguous(),
f"Propagated tensor is not contiguous: {prop_tensor.stride()}",
)
Loading