[dynamo] Allow inlining of hooks for the top module
ghstack-source-id: 567f8dbd8d9b790726b2fb9a246c033d22035a44
Pull Request resolved: #124501
anijain2305 committed Apr 26, 2024
1 parent 69eb49f commit c08138e
Showing 9 changed files with 96 additions and 54 deletions.
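
The test updates below show two user-visible effects of this change: guard failure messages now name L['args'][0], L['args'][1], ... (the positional arguments of the traced call wrapper) instead of forward's named parameters, and hooks registered on the top-level module are traced by Dynamo rather than executed outside the compiled frame. A minimal sketch of the kind of setup this affects, assuming stock torch.compile usage (module and hook names are illustrative, not taken from this diff):

    import torch

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    def scale_hook(module, inputs, output):
        # With hook inlining, this body becomes part of the traced frame
        # instead of running eagerly around the compiled forward.
        return output * 2

    mod = MyModule()
    mod.register_forward_hook(scale_hook)
    compiled = torch.compile(mod, backend="eager")
    out = compiled(torch.randn(2, 4))
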
20 changes: 10 additions & 10 deletions test/dynamo/test_aot_autograd.py
@@ -307,7 +307,7 @@ def guard_fail_fn(failure):
compare_equal_outs_and_grads(self, F(), fxy, (x, y))
compare_equal_outs_and_grads(self, F(), fxy, (x, z))
self.assertIn(
"""tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
"""tensor 'L['args'][1]' requires_grad mismatch. expected requires_grad=1""",
failure_reason,
)

@@ -425,7 +425,7 @@ def guard_fail_fn(failure):
fxx(x3, x3)
fxx(x4, y4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['x'] is L['y']""", failure_reason)
self.assertIn("""L['args'][0] is L['args'][1]""", failure_reason)

@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
@@ -459,7 +459,7 @@ def guard_fail_fn(failure):
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""L['args'][0] is L['args'][1]""",
failure_reason,
)

@@ -476,7 +476,7 @@ def guard_fail_fn(failure):
f(c3, c3, 3, 3)
f(c4, d4, 3, 3)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['a'] is L['b']""", failure_reason)
self.assertIn("""L['args'][0] is L['args'][1]""", failure_reason)

@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
@@ -513,7 +513,7 @@ def guard_fail_fn(failure):
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""L['args'][0] is L['args'][1]""",
failure_reason,
)

@@ -549,7 +549,7 @@ def guard_fail_fn(failure):
f([3, 2, 1], [4, 5, 6], a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""L['args'][2] is L['args'][3]""",
failure_reason,
)

@@ -599,7 +599,7 @@ def guard_fail_fn(failure):
f(a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""L['args'][0] is L['args'][1]""",
failure_reason,
)

@@ -616,7 +616,7 @@ def guard_fail_fn(failure):
f(c3, c3)
f(c4, d4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['a'] is L['b']""", failure_reason)
self.assertIn("""L['args'][0] is L['args'][1]""", failure_reason)

@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args(self):
@@ -648,7 +648,7 @@ def guard_fail_fn(failure):
f(a2, b2, b2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
"""L['args'][0] is L['args'][1]""",
failure_reason,
)

@@ -665,7 +665,7 @@ def guard_fail_fn(failure):
f(a3, b3, c3, c3)
f(a4, b4, c4, d4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['c'] is L['d']""", failure_reason)
self.assertIn("""L['args'][2] is L['args'][3]""", failure_reason)

def test_alias_inputs(self):
def fn():
2 changes: 1 addition & 1 deletion test/dynamo/test_decorators.py
@@ -337,7 +337,7 @@ def global_context_capture_fn(frame_summary):
):
torch._dynamo.optimize("eager")(e)(x)

self.assertEqual(len(seen_frames), 0)
self.assertEqual(len(seen_frames), 2)

def test_torch_guards_stack_frame_register_inlining_partially_disable(self):
y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
5 changes: 4 additions & 1 deletion test/dynamo/test_hooks.py
@@ -674,7 +674,10 @@ def forward(self, x):

comp_out = comp_mod(x1)

self.assertEqual(cnts.frame_count, 1)
# Now the forward graph is recompiled because of this guard failure
# ___check_obj_id(L['fn'].forward.__closure__[0].cell_contents.__code__, 139779879079008)
# which is basically id(my_hook)
self.assertEqual(cnts.frame_count, 2)
comp_out[0].backward(torch.ones(4))
self.assertEqual(x0.grad, x1.grad)
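
For context on the guard-failure comment in this hunk: the extra compile comes from an identity guard that Dynamo installs on the hook captured in a closure cell. A rough standalone sketch of that mechanism (not the test in this file; it assumes torch._dynamo.testing.CompileCounter):

    import torch
    import torch._dynamo.testing

    cnt = torch._dynamo.testing.CompileCounter()

    def make_forward(hook):
        # The nested function reuses one code object across calls, but its
        # closure cell holds a different hook each time; Dynamo guards on the
        # identity of that captured function.
        def forward(x):
            return hook(x)
        return forward

    def hook_a(x):
        return x * 2

    def hook_b(x):
        return x * 3

    x = torch.randn(4)
    torch.compile(make_forward(hook_a), backend=cnt)(x)  # first compile
    torch.compile(make_forward(hook_a), backend=cnt)(x)  # same hook, guards pass
    torch.compile(make_forward(hook_b), backend=cnt)(x)  # guard fails, recompiles
    print(cnt.frame_count)  # expected 2
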

11 changes: 7 additions & 4 deletions test/dynamo/test_misc.py
@@ -7803,7 +7803,7 @@ def forward(self, input):

# Not an exhaustive test of dynamic shapes behavior, but some sanity
if torch._dynamo.config.assume_static_by_default:
base_checker().check("Recompile Reasons").check("'forward'").check(
base_checker().check("Recompile Reasons").check("'inner'").check(
"cache_size_limit to 1"
).run(prof.report())
else:
@@ -7812,10 +7812,10 @@ def forward(self, input):
new_shape_input = torch.rand((4, 3, 4))
_ = compiled(new_shape_input)

base_checker().check("Recompile Reasons").check("'forward'").check(
"tensor 'L['input']' size mismatch at index 0. expected 2, actual 3"
base_checker().check("Recompile Reasons").check("'inner'").check(
"tensor 'L['args'][0]' size mismatch at index 0. expected 2, actual 3"
).check(
"tensor 'L['input']' size mismatch at index 0. expected 3, actual 4"
"tensor 'L['args'][0]' size mismatch at index 0. expected 3, actual 4"
).run(
prof.report()
)
@@ -10144,6 +10144,9 @@ def test_linear_module_free(self):
def test_outside_linear_module_free(self):
# Compared to test_linear_module_free, the linear
# layer is not the code object that is directly compiled.

# functools.lru_cache causes the static test to fail; removing it makes the test pass.
# The dynamic variant still fails.
def model_inp_ctr():
fc = torch.nn.Linear(100, 100)

43 changes: 24 additions & 19 deletions test/dynamo/test_modules.py
@@ -1304,7 +1304,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

run()

@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
def test_nn_moduledict_contains(self):
class M(torch.nn.Module):
def __init__(self, module_dict):
@@ -1327,33 +1326,37 @@ def forward(self, x):
self.assertEqual(cnt.op_count, 2)
self.assertTrue(torch._dynamo.testing.same(out1, out2))

torch._dynamo.reset()
module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
m = M(module_dict)
data = torch.randn(1)
out1 = m(data)
cnt = torch._dynamo.testing.CompileCounter()
torch._dynamo.reset()
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
out2 = opt_m(data)

self.assertEqual(cnt.op_count, 1)
self.assertTrue(torch._dynamo.testing.same(out1, out2))

module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
pre = m(data)
cnt.clear()
torch._dynamo.reset()
cnt = torch._dynamo.testing.CompileCounter()
data = torch.randn(1)
module_dict1 = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
module_dict2 = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})

m1 = M(module_dict1)
m2 = M(module_dict2)

def fn():
out1 = m1(data)
out2 = m2(data)
return out1

with torch._dynamo.optimize(cnt, nopython=False):
opt_pre = m(data)
m = M(module_dict)
data = torch.randn(1)
out1 = m(data)
opt_fn = torch.compile(fn, backend=cnt)
opt_fn()

out_post = m(data)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 1)
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
self.assertTrue(torch._dynamo.testing.same(out1, out_post))
self.assertEqual(cnt.op_count, 3)
self.assertTrue(torch._dynamo.testing.same(fn(), opt_fn()))

# RuntimeError: SymIntArrayRef expected to contain only concrete integers
@expectedFailureDynamic
@@ -1763,7 +1766,9 @@ def forward(self, x):
]:
x = torch.randn(size)
mod(x)
self.assertEqual(cnts.frame_count, 2 * num_submodules)
# The extra recompilations happen because _wrapped_call_impl is now
# falling back to eager, and Dynamo is triggering on the forward method.
self.assertEqual(cnts.frame_count, 3 * num_submodules)

def test_recursion(self):
mod = MockModule()
@@ -2136,15 +2141,15 @@ def output_modifying_hook(mod, inp, out):
loss_bwd = loss.backward()

self.assertEqual(eager_loss_bwd, loss_bwd)
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.frame_count, 1)

# Ndim change, recompile
pred = model(torch.randn([10, 10, 10]))
self.assertEqual(cnt.frame_count, 4)
self.assertEqual(cnt.frame_count, 2)

# Stable
pred = model(torch.randn([10, 10, 10]))
self.assertEqual(cnt.frame_count, 4)
self.assertEqual(cnt.frame_count, 2)

def test_dunder_call_explicitly(self):
# hooks should be triggered if explicit calling `__call__`
24 changes: 12 additions & 12 deletions test/functorch/test_control_flow.py
@@ -682,33 +682,33 @@ def test_while_loop_simple_with_linear_compile_check_graph(self):
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor):
l_iter_ = L_iter_
l_x_ = L_x_
l__self___dec = self.L__self___dec
l__self___linear_weight = self.L__self___linear_weight
l__self___linear_bias = self.L__self___linear_bias
def forward(self, L_args_0_ : torch.Tensor, L_args_1_ : torch.Tensor):
l_args_0_ = L_args_0_
l_args_1_ = L_args_1_
l__fn___dec = self.L__fn___dec
l__fn___linear_weight = self.L__fn___linear_weight
l__fn___linear_bias = self.L__fn___linear_bias
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_args_0_, l_args_1_), (l__fn___dec, l__fn___linear_bias, l__fn___linear_weight)); cond_fn_0 = body_fn_0 = l_args_0_ = l_args_1_ = l__fn___dec = l__fn___linear_bias = l__fn___linear_weight = None
getitem = while_loop[0]
getitem_1 = while_loop[1]; while_loop = None
return (getitem, getitem_1)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None
def forward(self, l_args_0_, l_args_1_, l__fn___dec_cond_fn, l__fn___linear_bias_body_fn, l__fn___linear_weight_body_fn):
sub = l_args_0_ - l__fn___dec_cond_fn; l_args_0_ = l__fn___dec_cond_fn = None
gt = sub > 0; sub = None
return gt""", # noqa: B950
)
self.assertExpectedInline(
gm.body_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
sub = l_iter_ - 1; l_iter_ = None
linear = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None
def forward(self, l_args_0_, l_args_1_, l__fn___dec_cond_fn, l__fn___linear_bias_body_fn, l__fn___linear_weight_body_fn):
sub = l_args_0_ - 1; l_args_0_ = None
linear = torch._C._nn.linear(l_args_1_, l__fn___linear_weight_body_fn, l__fn___linear_bias_body_fn); l_args_1_ = l__fn___linear_weight_body_fn = l__fn___linear_bias_body_fn = None
return (sub, linear)""", # noqa: B950
)

5 changes: 4 additions & 1 deletion torch/_dynamo/compiled_autograd.py
@@ -207,7 +207,10 @@ def end_capture(self, outputs):
"compiled_autograd_graph",
payload_fn=lambda: graph.print_readable(print_output=False),
)
return self.compiler_fn(graph)

# Fix for test_module_backward_hooks_eager
with torch._dynamo.trace_rules.dont_wrap_top_module():
return self.compiler_fn(graph)

def reorder_accumulate_grad_nodes(self):
"""
21 changes: 15 additions & 6 deletions torch/_dynamo/eval_frame.py
@@ -137,15 +137,18 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx):

def _initialize(self):
# Do this stuff in constructor to lower overhead slightly
if isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
self._orig_mod.forward

if trace_rules.should_wrap_top_module() or (
isinstance(self._orig_mod.forward, types.MethodType)
and trace_rules.check(self._orig_mod.forward)
):
# This may be a torch.nn.* instance in trace_rules.py which
# won't trigger a frame evaluation; this is a workaround to add an
# extra frame we can capture.
# TODO(export-team) - the second part of the or condition is
# required for export tests. We should fix them and remove it.
self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
else:
# Invoke hooks outside of dynamo, then pick up the inner frame
# TODO(export-team/compiled-autograd) - This is because of test
# failures for export and compiled-autograd.
self.forward = self.dynamo_ctx(self._orig_mod.__call__)

if hasattr(self._orig_mod, "_initialize_hook"):
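
To make the branch above concrete: wrapping the top module gives Dynamo an extra frame whose body just calls the module, so tracing starts above __call__ and registered hooks are inlined, while the else branch keeps the old behavior of firing hooks eagerly and compiling the inner frame. A hedged sketch of the wrapping idea (a rough stand-in for external_utils.wrap_inline, not its exact code):

    import torch

    def wrap_inline(mod):
        # Extra frame for Dynamo to capture: tracing starts here and inlines
        # mod.__call__, so hooks registered on mod run inside the traced frame.
        def inner(*args, **kwargs):
            return mod(*args, **kwargs)
        return inner

    mod = torch.nn.Linear(4, 4)
    mod.register_forward_hook(lambda m, inp, out: out * 2)
    compiled = torch.compile(wrap_inline(mod), backend="eager")
    out = compiled(torch.randn(2, 4))  # the hook's "out * 2" is traced
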
@@ -1215,7 +1218,13 @@ def result_capturing_wrapper(*graph_inputs):
automatic_dynamic_shapes=False,
capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True,
):
), trace_rules.dont_wrap_top_module():
# TODO(export-team) - there is a discrepancy between torch.compile and
# torch.export: torch.compile plans to inline _call_impl (one level
# above forward) so that hooks are inlined, but doing that for export
# breaks many tests because (1) tests are hardcoded to assume that
# tracing starts from forward, and (2) there are discrepancies between
# strict and non-strict mode.
opt_f = optimize_assert(
dynamo_normalization_capturing_compiler,
hooks=Hooks(
19 changes: 19 additions & 0 deletions torch/_dynamo/trace_rules.py
@@ -32,6 +32,7 @@
import unittest
import weakref
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, cast, Dict, List, Optional, Set, Union

np: Optional[types.ModuleType] = None
@@ -127,6 +128,24 @@
"""

_TLS = threading.local()


@contextmanager
def dont_wrap_top_module():
old = getattr(_TLS, "wrap_top_module", True)
_TLS.wrap_top_module = False
try:
yield False
finally:
_TLS.wrap_top_module = old


def should_wrap_top_module():
return getattr(_TLS, "wrap_top_module", True)


manual_torch_name_rule_map = {
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
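
The two helpers added above form a small thread-local toggle: should_wrap_top_module() defaults to True, and dont_wrap_top_module() flips it to False on the current thread for the duration of the with block, restoring the previous value on exit (compiled_autograd and export use it in the hunks above). A short usage sketch against this patched build:

    from torch._dynamo import trace_rules

    assert trace_rules.should_wrap_top_module()       # default: True

    with trace_rules.dont_wrap_top_module():
        # Affects only the current thread; the prior value is saved and
        # restored in the finally block, so nesting behaves sensibly.
        assert not trace_rules.should_wrap_top_module()

    assert trace_rules.should_wrap_top_module()       # restored on exit
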
