diff --git a/exir/tensor.py b/exir/tensor.py
index 1345067354f..8a6cd51b1ab 100644
--- a/exir/tensor.py
+++ b/exir/tensor.py
@@ -67,12 +67,15 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
 
     Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned value is (0, 2, 3, 1)
     """
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_size_oblivious,
+    )
+
     for _, s in enumerate(stride):
-        if s == 0:
+        if guard_or_false(s == 0):
             raise ValueError("0 in strides is not supported for ExecuTorch.")
 
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
     class K(NamedTuple):
         stride: int
 
diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS
index 855793d53a5..017a133b587 100644
--- a/exir/tests/TARGETS
+++ b/exir/tests/TARGETS
@@ -385,6 +385,7 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/exir:dim_order_utils",
+        "//executorch/exir:lib",
     ],
 )
 
diff --git a/exir/tests/test_dim_order_utils.py b/exir/tests/test_dim_order_utils.py
index bf515edb567..21a12429af1 100644
--- a/exir/tests/test_dim_order_utils.py
+++ b/exir/tests/test_dim_order_utils.py
@@ -8,6 +8,7 @@
 import unittest
 
 import torch
+from executorch.exir import to_edge_transform_and_lower
 from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
 
 
@@ -27,3 +28,22 @@ def test_get_dim_order(self) -> None:
             list(range(ndim)), get_dim_order(torch.contiguous_format, ndim)
         )
         self.assertEqual([0, 2, 3, 1], get_dim_order(torch.channels_last, 4))
+
+    def test_dim_order_from_stride(self):
+        class Test(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, t1, t2):
+                idx = torch.nonzero(t1).reshape(-1)
+                y = torch.index_select(t2, 0, idx)
+                return y
+
+        M = Test()
+        x = torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool)
+        y = torch.randn(5, 6)
+        M(x, y)
+
+        expo_prog = torch.export.export_for_training(M, (x, y))
+        edge_prog = to_edge_transform_and_lower(expo_prog)
+        edge_prog.to_executorch()
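
Note (not part of the patch): a minimal standalone sketch of the guard_or_false semantics the new stride check relies on, assuming a PyTorch build that exposes it from torch.fx.experimental.symbolic_shapes as the import added above does. On concrete values it behaves like bool(); on an undecidable comparison over unbacked SymInts (e.g. sizes coming from torch.nonzero, as in the new test) it returns False instead of raising a data-dependent guard error, so only provably-zero strides reach the ValueError.

# Sketch only, not part of the diff above.
from torch.fx.experimental.symbolic_shapes import guard_or_false

# Backed (concrete) values: guard_or_false is equivalent to bool().
assert guard_or_false(0 == 0)      # a provably-zero stride still triggers the ValueError
assert not guard_or_false(3 == 0)  # a provably-nonzero stride passes the check

# When a stride expression contains unbacked SymInts (data-dependent sizes,
# e.g. downstream of torch.nonzero), `s == 0` cannot be decided at export
# time; guard_or_false then falls back to False rather than raising, which
# lets to_edge_transform_and_lower / to_executorch proceed.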