Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
ezyang committed Jun 9, 2024
1 parent 8e8ca4c commit 0b36b03
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ def f(src_tokens):
gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
# Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
# 1 ok)
self.assertEqual(len(gm.shape_env.guards), 2)
self.assertEqual(len(gm.shape_env.guards), 1)

@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self):
Expand Down
9 changes: 7 additions & 2 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,10 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, statically_known_true
from torch.fx.experimental.symbolic_shapes import (
guard_size_oblivious,
statically_known_true,
)

ndim = self.dim()
if ndim == 0:
Expand Down Expand Up @@ -762,7 +765,9 @@ def slice_forward(

if end_val < start_val:
end_val = start_val
elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(end_val > sizes[dim]):
elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
end_val > sizes[dim]
):
end_val = sizes[dim]

storage_offset = self.storage_offset() + start_val * strides[dim]
Expand Down

0 comments on commit 0b36b03

Please sign in to comment.