From c08138e65ab98ee8c2c265acb3ba50ae2b8df2da Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 26 Apr 2024 14:02:58 -0700 Subject: [PATCH] [dynamo] Allow inlining of hooks for the top module ghstack-source-id: 567f8dbd8d9b790726b2fb9a246c033d22035a44 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124501 --- test/dynamo/test_aot_autograd.py | 20 +++++++------- test/dynamo/test_decorators.py | 2 +- test/dynamo/test_hooks.py | 5 +++- test/dynamo/test_misc.py | 11 +++++--- test/dynamo/test_modules.py | 43 ++++++++++++++++------------- test/functorch/test_control_flow.py | 24 ++++++++-------- torch/_dynamo/compiled_autograd.py | 5 +++- torch/_dynamo/eval_frame.py | 21 ++++++++++---- torch/_dynamo/trace_rules.py | 19 +++++++++++++ 9 files changed, 96 insertions(+), 54 deletions(-) diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 84b1e0a28c0ec..6bd0f0ddddd02 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -307,7 +307,7 @@ def guard_fail_fn(failure): compare_equal_outs_and_grads(self, F(), fxy, (x, y)) compare_equal_outs_and_grads(self, F(), fxy, (x, z)) self.assertIn( - """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""", + """tensor 'L['args'][1]' requires_grad mismatch. expected requires_grad=1""", failure_reason, ) @@ -425,7 +425,7 @@ def guard_fail_fn(failure): fxx(x3, x3) fxx(x4, y4) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['x'] is L['y']""", failure_reason) + self.assertIn("""L['args'][0] is L['args'][1]""", failure_reason) @patch("torch._functorch.config.debug_assert", True) def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self): @@ -459,7 +459,7 @@ def guard_fail_fn(failure): f(a2, b2, 2, 2) self.assertEqual(cc.frame_count, 2) self.assertIn( - """L['a'] is L['b']""", + """L['args'][0] is L['args'][1]""", failure_reason, ) @@ -476,7 +476,7 @@ def guard_fail_fn(failure): f(c3, c3, 3, 3) f(c4, d4, 3, 3) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['a'] is L['b']""", failure_reason) + self.assertIn("""L['args'][0] is L['args'][1]""", failure_reason) @patch("torch._functorch.config.debug_assert", True) def test_arg_dupe_via_dynamo_recompiles_many_with_global(self): @@ -513,7 +513,7 @@ def guard_fail_fn(failure): f(a2, b2, 2, 2) self.assertEqual(cc.frame_count, 2) self.assertIn( - """L['a'] is L['b']""", + """L['args'][0] is L['args'][1]""", failure_reason, ) @@ -549,7 +549,7 @@ def guard_fail_fn(failure): f([3, 2, 1], [4, 5, 6], a2, b2) self.assertEqual(cc.frame_count, 2) self.assertIn( - """L['a'] is L['b']""", + """L['args'][2] is L['args'][3]""", failure_reason, ) @@ -599,7 +599,7 @@ def guard_fail_fn(failure): f(a2, b2) self.assertEqual(cc.frame_count, 2) self.assertIn( - """L['a'] is L['b']""", + """L['args'][0] is L['args'][1]""", failure_reason, ) @@ -616,7 +616,7 @@ def guard_fail_fn(failure): f(c3, c3) f(c4, d4) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['a'] is L['b']""", failure_reason) + self.assertIn("""L['args'][0] is L['args'][1]""", failure_reason) @patch("torch._functorch.config.debug_assert", True) def test_arg_dupe_via_dynamo_recompiles_many_args(self): @@ -648,7 +648,7 @@ def guard_fail_fn(failure): f(a2, b2, b2, b2) self.assertEqual(cc.frame_count, 2) self.assertIn( - """L['a'] is L['b']""", + """L['args'][0] is L['args'][1]""", failure_reason, ) @@ -665,7 +665,7 @@ def guard_fail_fn(failure): f(a3, b3, c3, c3) f(a4, b4, c4, d4) self.assertEqual(cc.frame_count, 2) - self.assertIn("""L['c'] is 
L['d']""", failure_reason) + self.assertIn("""L['args'][2] is L['args'][3]""", failure_reason) def test_alias_inputs(self): def fn(): diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 890edca40ccc9..f7269c401addd 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -337,7 +337,7 @@ def global_context_capture_fn(frame_summary): ): torch._dynamo.optimize("eager")(e)(x) - self.assertEqual(len(seen_frames), 0) + self.assertEqual(len(seen_frames), 2) def test_torch_guards_stack_frame_register_inlining_partially_disable(self): y = torch.nn.Parameter(torch.tensor([0.25, 0.25])) diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index 17966cfb85a65..8e4e4dc2b4210 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -674,7 +674,10 @@ def forward(self, x): comp_out = comp_mod(x1) - self.assertEqual(cnts.frame_count, 1) + # Now the forward graph is recompiled because of this guard failure + # ___check_obj_id(L['fn'].forward.__closure__[0].cell_contents.__code__, 139779879079008) + # which is basically id(my_hook) + self.assertEqual(cnts.frame_count, 2) comp_out[0].backward(torch.ones(4)) self.assertEqual(x0.grad, x1.grad) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 2f27f168898fc..0b2bc2339e092 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -7803,7 +7803,7 @@ def forward(self, input): # Not an exhaustive test of dynamic shapes behavior, but some sanity if torch._dynamo.config.assume_static_by_default: - base_checker().check("Recompile Reasons").check("'forward'").check( + base_checker().check("Recompile Reasons").check("'inner'").check( "cache_size_limit to 1" ).run(prof.report()) else: @@ -7812,10 +7812,10 @@ def forward(self, input): new_shape_input = torch.rand((4, 3, 4)) _ = compiled(new_shape_input) - base_checker().check("Recompile Reasons").check("'forward'").check( - "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3" + base_checker().check("Recompile Reasons").check("'inner'").check( + "tensor 'L['args'][0]' size mismatch at index 0. expected 2, actual 3" ).check( - "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4" + "tensor 'L['args'][0]' size mismatch at index 0. expected 3, actual 4" ).run( prof.report() ) @@ -10144,6 +10144,9 @@ def test_linear_module_free(self): def test_outside_linear_module_free(self): # Compared to test_linear_module_free, the linear # layer is not the code object that is directly compiled. + + # functools.lru_cache causes the static test to fail. Removing it passes. + # Dynamic still fails. 
def model_inp_ctr(): fc = torch.nn.Linear(100, 100) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index de0e66e59fb72..cff738b774a00 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1304,7 +1304,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): run() - @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) def test_nn_moduledict_contains(self): class M(torch.nn.Module): def __init__(self, module_dict): @@ -1327,33 +1326,37 @@ def forward(self, x): self.assertEqual(cnt.op_count, 2) self.assertTrue(torch._dynamo.testing.same(out1, out2)) + torch._dynamo.reset() module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)}) m = M(module_dict) data = torch.randn(1) out1 = m(data) cnt = torch._dynamo.testing.CompileCounter() - torch._dynamo.reset() opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) out2 = opt_m(data) - self.assertEqual(cnt.op_count, 1) self.assertTrue(torch._dynamo.testing.same(out1, out2)) - module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)}) - pre = m(data) - cnt.clear() + torch._dynamo.reset() + cnt = torch._dynamo.testing.CompileCounter() + data = torch.randn(1) + module_dict1 = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)}) + module_dict2 = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)}) + + m1 = M(module_dict1) + m2 = M(module_dict2) + + def fn(): + out1 = m1(data) + out2 = m2(data) + return out1 - with torch._dynamo.optimize(cnt, nopython=False): - opt_pre = m(data) - m = M(module_dict) - data = torch.randn(1) - out1 = m(data) + opt_fn = torch.compile(fn, backend=cnt) + opt_fn() - out_post = m(data) self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 1) - self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) - self.assertTrue(torch._dynamo.testing.same(out1, out_post)) + self.assertEqual(cnt.op_count, 3) + self.assertTrue(torch._dynamo.testing.same(fn(), opt_fn())) # RuntimeError: SymIntArrayRef expected to contain only concrete integers @expectedFailureDynamic @@ -1763,7 +1766,9 @@ def forward(self, x): ]: x = torch.randn(size) mod(x) - self.assertEqual(cnts.frame_count, 2 * num_submodules) + # The extra recompilations happen because _wrapped_call_impl is now + # falling back to eager, and Dynamo is triggering on forward method. 
+ self.assertEqual(cnts.frame_count, 3 * num_submodules) def test_recursion(self): mod = MockModule() @@ -2136,15 +2141,15 @@ def output_modifying_hook(mod, inp, out): loss_bwd = loss.backward() self.assertEqual(eager_loss_bwd, loss_bwd) - self.assertEqual(cnt.frame_count, 2) + self.assertEqual(cnt.frame_count, 1) # Ndim change, recompile pred = model(torch.randn([10, 10, 10])) - self.assertEqual(cnt.frame_count, 4) + self.assertEqual(cnt.frame_count, 2) # Stable pred = model(torch.randn([10, 10, 10])) - self.assertEqual(cnt.frame_count, 4) + self.assertEqual(cnt.frame_count, 2) def test_dunder_call_explicitly(self): # hooks should be triggered if explicit calling `__call__` diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index a1aeb8c1de7d3..53911da868492 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -682,15 +682,15 @@ def test_while_loop_simple_with_linear_compile_check_graph(self): self.assertExpectedInline( gm.code.strip(), """\ -def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): - l_iter_ = L_iter_ - l_x_ = L_x_ - l__self___dec = self.L__self___dec - l__self___linear_weight = self.L__self___linear_weight - l__self___linear_bias = self.L__self___linear_bias +def forward(self, L_args_0_ : torch.Tensor, L_args_1_ : torch.Tensor): + l_args_0_ = L_args_0_ + l_args_1_ = L_args_1_ + l__fn___dec = self.L__fn___dec + l__fn___linear_weight = self.L__fn___linear_weight + l__fn___linear_bias = self.L__fn___linear_bias cond_fn_0 = self.cond_fn_0 body_fn_0 = self.body_fn_0 - while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_args_0_, l_args_1_), (l__fn___dec, l__fn___linear_bias, l__fn___linear_weight)); cond_fn_0 = body_fn_0 = l_args_0_ = l_args_1_ = l__fn___dec = l__fn___linear_bias = l__fn___linear_weight = None getitem = while_loop[0] getitem_1 = while_loop[1]; while_loop = None return (getitem, getitem_1)""", # noqa: B950 @@ -698,17 +698,17 @@ def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): self.assertExpectedInline( gm.cond_fn_0.code.strip(), """\ -def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): - sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None +def forward(self, l_args_0_, l_args_1_, l__fn___dec_cond_fn, l__fn___linear_bias_body_fn, l__fn___linear_weight_body_fn): + sub = l_args_0_ - l__fn___dec_cond_fn; l_args_0_ = l__fn___dec_cond_fn = None gt = sub > 0; sub = None return gt""", # noqa: B950 ) self.assertExpectedInline( gm.body_fn_0.code.strip(), """\ -def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): - sub = l_iter_ - 1; l_iter_ = None - linear = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None +def forward(self, l_args_0_, l_args_1_, l__fn___dec_cond_fn, l__fn___linear_bias_body_fn, l__fn___linear_weight_body_fn): + sub = l_args_0_ - 1; l_args_0_ = None + linear = torch._C._nn.linear(l_args_1_, l__fn___linear_weight_body_fn, l__fn___linear_bias_body_fn); l_args_1_ = l__fn___linear_weight_body_fn = 
l__fn___linear_bias_body_fn = None return (sub, linear)""", # noqa: B950 ) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 18385171fac8c..e006ea9bf6b5b 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -207,7 +207,10 @@ def end_capture(self, outputs): "compiled_autograd_graph", payload_fn=lambda: graph.print_readable(print_output=False), ) - return self.compiler_fn(graph) + + # Fix for test_module_backward_hooks_eager + with torch._dynamo.trace_rules.dont_wrap_top_module(): + return self.compiler_fn(graph) def reorder_accumulate_grad_nodes(self): """ diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 99a466523aadd..43460ec3843f6 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -137,15 +137,18 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx): def _initialize(self): # Do this stuff in constructor to lower overhead slightly - if isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check( - self._orig_mod.forward + + if trace_rules.should_wrap_top_module() or ( + isinstance(self._orig_mod.forward, types.MethodType) + and trace_rules.check(self._orig_mod.forward) ): - # This may be a torch.nn.* instance in trace_rules.py which - # won't trigger a frame evaluation workaround to add an extra - # frame we can capture + # TODO(export-team) - the second part of the or condition is + # required for export tests. We should fix them and remove it. self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod)) else: # Invoke hooks outside of dynamo then pickup the inner frame + # TODO(export-team/compiled-autograd) - This is because of test + # failures for export and compiled-autograd. self.forward = self.dynamo_ctx(self._orig_mod.__call__) if hasattr(self._orig_mod, "_initialize_hook"): @@ -1215,7 +1218,13 @@ def result_capturing_wrapper(*graph_inputs): automatic_dynamic_shapes=False, capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, - ): + ), trace_rules.dont_wrap_top_module(): + # TODO(export-team) - discrepancy between torch.compile and + # torch.export because torch.compile is planning to inline the + # _call_impl (one level above forward) to inline hooks. But doing + # that for export breaks many tests because (1) tests are hardcoded + # to assume that tracing starts from forward, and (2) some + # discrepancies between strict and non strict mode. opt_f = optimize_assert( dynamo_normalization_capturing_compiler, hooks=Hooks( diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 763f6482cb92e..3220ba8ddad0e 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -32,6 +32,7 @@ import unittest import weakref from collections import defaultdict +from contextlib import contextmanager from typing import Any, Callable, cast, Dict, List, Optional, Set, Union np: Optional[types.ModuleType] = None @@ -127,6 +128,24 @@ """ + +_TLS = threading.local() + + +@contextmanager +def dont_wrap_top_module(): + old = getattr(_TLS, "wrap_top_module", True) + _TLS.wrap_top_module = False + try: + yield False + finally: + _TLS.wrap_top_module = old + + +def should_wrap_top_module(): + return getattr(_TLS, "wrap_top_module", True) + + manual_torch_name_rule_map = { "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable, "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
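
Illustrative note, not part of the patch: with the top-level module now wrapped via external_utils.wrap_inline, Dynamo starts tracing from the module __call__ frame, so hooks are inlined into the graph and positional inputs are guarded as L['args'][i] instead of by parameter name. The sketch below mirrors the pattern the updated tests use to surface the new guard strings; the toy module F, the "aot_eager" backend choice, and the variable names are illustrative assumptions rather than code from this PR.

    # Minimal sketch, assuming a PyTorch build that includes this patch.
    import torch

    failure_reason = None

    def guard_fail_fn(failure):
        # Record the first guard-failure string, as the tests above do.
        global failure_reason
        failure_reason = failure[0]

    class F(torch.nn.Module):
        def forward(self, x, y):
            return (x + y).relu()

    opt_f = torch._dynamo.optimize("aot_eager", guard_fail_fn=guard_fail_fn)(F())

    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3, requires_grad=True)
    z = torch.randn(3, requires_grad=False)  # differs only in requires_grad

    opt_f(x, y)
    opt_f(x, z)  # triggers a recompile; per the updated tests, the guard string
                 # should now read "tensor 'L['args'][1]' requires_grad mismatch..."
                 # rather than referring to the parameter name 'y'.
    print(failure_reason)

Callers that still need the previous entry point (export and compiled autograd in this patch) construct the compiled object under torch._dynamo.trace_rules.dont_wrap_top_module(), which flips the thread-local toggle read by should_wrap_top_module() so that hooks are invoked outside Dynamo and compilation picks up the inner forward frame, as the eval_frame.py hunk above shows.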