From 9f1ff69e7b812ca6ff1473fa94e5129b64d74416 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Thu, 2 Feb 2023 00:53:45 +0000 Subject: [PATCH 1/4] Small refactor of shape guards to allow for 1:1 code_parts --- test/dynamo/test_misc.py | 33 ++++++++++++++++++++++++ test/test_proxy_tensor.py | 2 +- torch/_dynamo/guards.py | 6 ++--- torch/fx/experimental/symbolic_shapes.py | 14 +++++----- 4 files changed, 44 insertions(+), 11 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index e6d4cfbc9d73f..273ff86f6f489 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3282,6 +3282,39 @@ def guard_failures(failure): self.assertTrue(guard_failure is not None) self.assertEqual(guard_failure[0], "k == 3") + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + def test_guard_failure_fn_shape_control(self): + def fn(x, y): + if x.shape[0] < 3: + if y.shape[0] < 3: + return x * y + else: + return x + y + else: + return -1 + + x = torch.randn([2, 2]) + y = torch.randn([2, 2]) + + guard_failure = None + + def guard_failures(failure): + nonlocal guard_failure + guard_failure = failure + + opt_fn = torch._dynamo.optimize( + "eager", nopython=True, guard_fail_fn=guard_failures + )(fn) + + x2 = torch.randn([5, 5]) + y2 = torch.randn([5, 5]) + + opt_fn(x, y) + opt_fn(x2, y2) + + self.assertTrue(guard_failure is not None) + self.assertEqual(guard_failure[0], "x.size()[0] < 3") + def test_guard_failure_fn2(self): def fn(x, y): x = x + 1 diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 6c8478a4a64b0..64ba5a6b00e0e 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1081,7 +1081,7 @@ def f(a, b): fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8)) from torch._dynamo.source import LocalSource self.assertExpectedInline( - fx_g.shape_env.codegen_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")]), + 
fx_g.shape_env.codegen_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")]), +            fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")]), """a.size()[0] == 2*b.size()[0] and a.stride()[0] == 1 and a.storage_offset() == 0 and b.stride()[0] == 1 and b.storage_offset() == 0 and b.size()[0] != 0 and b.size()[0] != 1""" # noqa: B950 ) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 9e9599e546ab3..351b2182fb0e2 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -401,13 +401,13 @@ def SHAPE_ENV(self, guard: Guard): output_graph = self.guarded_code.output_graph # NB: self.output_graph can be None in the debug_nops tests fs = output_graph.tracked_fakes - code = output_graph.shape_env.codegen_guards( + guards = output_graph.shape_env.produce_guards( [a.fake for a in fs], [a.source for a in fs], source_ref=self.source_ref, ) - if code != "True": - self._produce_guard_code(guard, [code], shape_env=True) + for shape_guard in guards: + self._produce_guard_code(guard, [shape_guard], shape_env=True) def TENSOR_MATCH(self, guard: Guard): if guard.is_nn_module(): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index a7037550ed14b..cefc7929093a5 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -821,7 +821,7 @@ def duck_int(self, val): # on if the guards evaluated to True or not. Primarily used by Dynamo, # but this is also helpful for manual testing of guards (see # evaluate_guards_for_args) - def codegen_guards(self, placeholders, sources, + def produce_guards(self, placeholders, sources, source_ref=lambda n: n.name()): # It took a lot of sweat to figure out the algorithm here. Let's # explain how it works. 
@@ -963,16 +963,16 @@ def track_symint(source, val): # negative inferences on shape variables exprs.append(f"{source_ref(sources[0])} != 0 and {source_ref(sources[0])} != 1") - if exprs: - return " and ".join(exprs) - else: - return "True" + return exprs def evaluate_guards_for_args(self, placeholders, args): from torch._dynamo.source import GlobalSource arg_names = [f"t{i}" for i in range(len(args))] - code = self.codegen_guards(placeholders, [GlobalSource(a) for a in arg_names]) - return eval(code, {}, dict(zip(arg_names, args))) + guards = self.produce_guards(placeholders, [GlobalSource(a) for a in arg_names]) + if guards: + code = " and ".join(guards) + return eval(code, {}, dict(zip(arg_names, args))) + return True def bind_symbols(self, placeholders, args): # Given a paired list of placeholders (fake tensors with From 5391d6fda3e1e71d75669a8aa8f084486fa7f0ba Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Thu, 2 Feb 2023 01:04:12 +0000 Subject: [PATCH 2/4] A little doc pizzazz --- torch/fx/experimental/symbolic_shapes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index cefc7929093a5..196e0fe9979db 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -816,13 +816,13 @@ def duck_int(self, val): ) return self.val_to_var[val] - # Generates a Python string which, when evaluated in a context that + # Generates a list of guard strings which, when evaluated in a context that # defines tensors for all the sources, returns True or False depending - # on if the guards evaluated to True or not. Primarily used by Dynamo, + # on if the guards in the list evaluated to True or not. 
Primarily used by Dynamo, # but this is also helpful for manual testing of guards (see # evaluate_guards_for_args) def produce_guards(self, placeholders, sources, - source_ref=lambda n: n.name()): + source_ref=lambda n: n.name()) -> List[str]: # It took a lot of sweat to figure out the algorithm here. Let's # explain how it works. # From 5b3a817e74336366f2e06a1ac39a445ea83a93a7 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Thu, 2 Feb 2023 02:00:37 +0000 Subject: [PATCH 3/4] Test missed --- test/test_proxy_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 64ba5a6b00e0e..7c78217abbecb 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1081,8 +1081,8 @@ def f(a, b): fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8)) from torch._dynamo.source import LocalSource self.assertExpectedInline( - fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")]), - """a.size()[0] == 2*b.size()[0] and a.stride()[0] == 1 and a.storage_offset() == 0 and b.stride()[0] == 1 and b.storage_offset() == 0 and b.size()[0] != 0 and b.size()[0] != 1""" # noqa: B950 + str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")])), + """['a.size()[0] == 2*b.size()[0]', 'a.stride()[0] == 1', 'a.storage_offset() == 0', 'b.stride()[0] == 1', 'b.storage_offset() == 0', 'b.size()[0] != 0 and b.size()[0] != 1']""" # noqa: B950 ) def test_sym_storage_offset(self): From bc5c9bf1176c9ff16c2ee8592c029615089e45db Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Thu, 2 Feb 2023 08:21:06 +0000 Subject: [PATCH 4/4] test fix 2 --- test/dynamo/test_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 59bf9b8145393..1e24284d3e810 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ 
-98,7 +98,7 @@ def func(x): for guard in out_guards: if guard.source == GuardSource.SHAPE_ENV: hit = True - self.assertTrue("x.size()[0] <= 10" in guard.code_list[0]) + self.assertTrue("x.size()[0] <= 10" in guard.code_list) self.assertTrue(hit)