diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index d3f9e415ff944..50fb3952cc080 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -4465,6 +4465,54 @@ def f(x, start, end):
         res = f(x, start, 0)
         self.assertEqual(res.shape, torch.Size([0]))
 
+    @skipIfTorchDynamo()
+    @torch.fx.experimental._config.patch("backed_size_oblivious", True)
+    def test_backed_size_oblivious_broadcast(self):
+        cnt = CompileCounterWithBackend("inductor")
+        torch._dynamo.reset()
+
+        def func(a, b):
+            torch.broadcast_shapes(a.size(), b.size())
+            return a + b
+
+        compiled = torch.compile(func, fullgraph=True, backend=cnt, dynamic=True)
+
+        def run(a, b):
+            self.assertEqual(compiled(a, b), func(a, b))
+
+        # No 0/1 specializations, no broadcasts,
+        # but a[0] == b[0] and a[1] == b[1] are asserted.
+        run(torch.rand(1, 10), torch.rand(1, 10))
+        run(torch.rand(1, 1), torch.rand(1, 1))
+        run(torch.rand(10, 10), torch.rand(10, 10))
+
+        self.assertEqual(cnt.frame_count, 1)
+        run(torch.rand(10, 10), torch.rand(1, 10))
+        self.assertEqual(cnt.frame_count, 2)
+
+        cnt.clear()
+        torch._dynamo.reset()
+
+        # specialize a[0] == 1. b[0] not specialized.
+        run(torch.rand(1, 10), torch.rand(9, 10))
+        run(torch.rand(1, 10), torch.rand(1, 10))
+        self.assertEqual(cnt.frame_count, 1)
+        # if we change a[0] we get a recompilation.
+        run(torch.rand(10, 10), torch.rand(10, 10))
+        self.assertEqual(cnt.frame_count, 2)
+
+        cnt.clear()
+        torch._dynamo.reset()
+
+        # TODO: duck sizing should probably be disabled when
+        # backed_size_oblivious is on.
+        # specialize b[0] == 1. a[0] not specialized.
+        run(torch.rand(10, 11), torch.rand(1, 11))
+        run(torch.rand(1, 10), torch.rand(1, 10))
+        self.assertEqual(cnt.frame_count, 1)
+        run(torch.rand(2, 10), torch.rand(2, 10))
+        self.assertEqual(cnt.frame_count, 2)
+
 
 instantiate_parametrized_tests(TestUnbacked)
 
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 9224643fe55ab..5137a65545b11 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -385,7 +385,13 @@ def handle_noncontiguous_outputs(input_tlist, output):
 
 
 def _broadcast_shapes(*_shapes):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        is_nested_int,
+        size_hint,
+    )
+
+    backed_so = torch.fx.experimental._config.backed_size_oblivious
 
     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x
@@ -418,6 +424,22 @@ def _broadcast_shapes(*_shapes):
             ):
                 continue
             else:
+                # When backed size oblivious is used, we specialize for broadcasting
+                # if it's the only way to compile the example input.
+                # i.e.: s0:1, s1:1 ==>
+                #           assert s0==s1, no specialization on ==1 or !=1.
+                #           The non-broadcast path is picked.
+                #       s0:1, s1:4 ==>
+                #           specialize(s0) to be 1.
+                #       s0:4, s1:1 ==>
+                #           specialize(s1) to be 1.
+                if backed_so:
+                    a = size_hint(shape[idx], allow_none=True)
+                    b = size_hint(common_shape[idx], allow_none=True)
+                    if a == 1 and b != 1:
+                        torch._check(shape[idx] == 1)
+                    if b == 1 and a != 1:
+                        torch._check(common_shape[idx] == 1)
                 if guard_or_false(shape[idx] == common_shape[idx]):
                     continue
 
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 693d25aea6130..c3b64f2d85a53 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -131,6 +131,7 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
 aten = torch._ops.ops.aten  # type: ignore[has-type]
 
 __all__ = [
+    "size_hint",
     "guard_or_false",
     "guard_or_true",
     "has_symbolic_sizes_strides",
@@ -255,6 +256,17 @@ def _nested_int_aware_sort(
     )
 
 
+def size_hint(x: int | torch.SymInt, *, allow_none: bool = False) -> int | None:
+    """Gets a size hint for a given expression from the underlying shapes we had.
+    Does not introduce a guard, so only use this when you can guarantee that
+    your code is still valid for arbitrary shapes (such as optimization decisions).
+    """
+    if isinstance(x, int):
+        return x
+    assert isinstance(x, torch.SymInt)
+    return x.node.shape_env.size_hint(x.node.expr, allow_none=allow_none)
+
+
 # Wrapper on lru_cache that reports statistics at process end
 def lru_cache(
     maxsize: Optional[int],