From 87dc06dafa711dd487251f7bb757120819605149 Mon Sep 17 00:00:00 2001 From: Prashant Rawat Date: Mon, 3 Nov 2025 17:50:21 -0800 Subject: [PATCH] Don't assert for symbolic stride in dim_order_from_stride() (#15472) Summary: Currently dim_order_from_stride() checks if any stride is 0, and fails if so. N7613577 shows a min repro from factorized joiner use case with symbolic stride, where this check fails: P2015309933 This failure is blocking us from migrating live translation models to ExecuTorch. This diff removes the blocker by skipping the assert for symbolic strides. Reviewed By: angelayi Differential Revision: D85875885 --- exir/tensor.py | 9 ++++++--- exir/tests/TARGETS | 1 + exir/tests/test_dim_order_utils.py | 20 ++++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/exir/tensor.py b/exir/tensor.py index 1345067354f..8a6cd51b1ab 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -67,12 +67,15 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]: Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned value is (0, 2, 3, 1) """ + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_size_oblivious, + ) + for _, s in enumerate(stride): - if s == 0: + if guard_or_false(s == 0): raise ValueError("0 in strides is not supported for ExecuTorch.") - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious - class K(NamedTuple): stride: int diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 855793d53a5..017a133b587 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -385,6 +385,7 @@ python_unittest( deps = [ "//caffe2:torch", "//executorch/exir:dim_order_utils", + "//executorch/exir:lib", ], ) diff --git a/exir/tests/test_dim_order_utils.py b/exir/tests/test_dim_order_utils.py index bf515edb567..21a12429af1 100644 --- a/exir/tests/test_dim_order_utils.py +++ b/exir/tests/test_dim_order_utils.py @@ -8,6 +8,7 @@ import unittest import torch +from 
executorch.exir import to_edge_transform_and_lower from executorch.exir.dim_order_utils import get_dim_order, get_memory_format @@ -27,3 +28,22 @@ def test_get_dim_order(self) -> None: list(range(ndim)), get_dim_order(torch.contiguous_format, ndim) ) self.assertEqual([0, 2, 3, 1], get_dim_order(torch.channels_last, 4)) + + def test_dim_order_from_stride(self): + class Test(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, t1, t2): + idx = torch.nonzero(t1).reshape(-1) + y = torch.index_select(t2, 0, idx) + return y + + M = Test() + x = torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool) + y = torch.randn(5, 6) + M(x, y) + + expo_prog = torch.export.export_for_training(M, (x, y)) + edge_prog = to_edge_transform_and_lower(expo_prog) + edge_prog.to_executorch()