Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4465,6 +4465,54 @@ def f(x, start, end):
res = f(x, start, 0)
self.assertEqual(res.shape, torch.Size([0]))

@skipIfTorchDynamo()
@torch.fx.experimental._config.patch("backed_size_oblivious", True)
def test_backed_size_oblivious_broadcast(self):
    counter = CompileCounterWithBackend("inductor")
    torch._dynamo.reset()

    def func(a, b):
        torch.broadcast_shapes(a.size(), b.size())
        return a + b

    compiled = torch.compile(func, fullgraph=True, backend=counter, dynamic=True)

    def check(a, b):
        # Compiled and eager results must agree.
        self.assertEqual(compiled(a, b), func(a, b))

    # Same-shaped inputs: no 0/1 specialization and no broadcast path,
    # but a[0] == b[0] and a[1] == b[1] are asserted.
    check(torch.rand(1, 10), torch.rand(1, 10))
    check(torch.rand(1, 1), torch.rand(1, 1))
    check(torch.rand(10, 10), torch.rand(10, 10))

    self.assertEqual(counter.frame_count, 1)
    # A genuine broadcast (a[0] != b[0]) forces a recompile.
    check(torch.rand(10, 10), torch.rand(1, 10))
    self.assertEqual(counter.frame_count, 2)

    counter.clear()
    torch._dynamo.reset()

    # a[0] gets specialized to 1; b[0] stays symbolic.
    check(torch.rand(1, 10), torch.rand(9, 10))
    check(torch.rand(1, 10), torch.rand(1, 10))
    self.assertEqual(counter.frame_count, 1)
    # Changing a[0] triggers a recompilation.
    check(torch.rand(10, 10), torch.rand(10, 10))
    self.assertEqual(counter.frame_count, 2)

    counter.clear()
    torch._dynamo.reset()

    # TODO: duck sizing should probably be disabled when
    # backed_size_oblivious is on.
    # b[0] gets specialized to 1; a[0] stays symbolic.
    check(torch.rand(10, 11), torch.rand(1, 11))
    check(torch.rand(1, 10), torch.rand(1, 10))
    self.assertEqual(counter.frame_count, 1)
    check(torch.rand(2, 10), torch.rand(2, 10))
    self.assertEqual(counter.frame_count, 2)


instantiate_parametrized_tests(TestUnbacked)

Expand Down
24 changes: 23 additions & 1 deletion torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,13 @@ def handle_noncontiguous_outputs(input_tlist, output):


def _broadcast_shapes(*_shapes):
from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
is_nested_int,
size_hint,
)

backed_so = torch.fx.experimental._config.backed_size_oblivious

shapes = tuple(
(x,) if isinstance(x, IntLike) else x
Expand Down Expand Up @@ -418,6 +424,22 @@ def _broadcast_shapes(*_shapes):
):
continue
else:
# When backed size oblivious is used, we specialize for broadcasting
# if it's the only way to compile the example input.
# i.e: s0:1, s1:1 ==>
# assert s0==s1, no specialization on ==1 or !=1.
# The non-broadcast path is picked
# s0:1, s1:4 ==>
# specialize(s0) to be 1.
# s0:4, s1:1 ==>
# specialize(s1) to be 1.
if backed_so:
a = size_hint(shape[idx], allow_none=True)
b = size_hint(common_shape[idx], allow_none=True)
if a == 1 and b != 1:
torch._check(shape[idx] == 1)
if b == 1 and a != 1:
torch._check(common_shape[idx] == 1)
if guard_or_false(shape[idx] == common_shape[idx]):
continue

Expand Down
12 changes: 12 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
aten = torch._ops.ops.aten # type: ignore[has-type]

__all__ = [
"size_hint",
"guard_or_false",
"guard_or_true",
"has_symbolic_sizes_strides",
Expand Down Expand Up @@ -255,6 +256,17 @@ def _nested_int_aware_sort(
)


def size_hint(x: int | torch.SymInt, *, allow_none: bool = False) -> int | None:
"""Gets a size hint for a given expression from the underlying shapes we had.
Does not introduce a guard, so only use this when you can guarantee that
your code is still valid for arbitrary shapes (such as optimization decisions)
"""
if isinstance(x, int):
return x
assert isinstance(x, torch.SymInt)
return x.node.shape_env.size_hint(x.node.expr, allow_none=allow_none)


# Wrapper on lru_cache that reports statistics at process end
def lru_cache(
maxsize: Optional[int],
Expand Down
Loading