[dynamo] Allow inlining of hooks for the top module
ghstack-source-id: d0e3ef1e8145ffd5479328c2db1989791e3d47b7
Pull Request resolved: #124501
anijain2305 committed Apr 19, 2024
1 parent bad8d25 commit 3483d56
Showing 2 changed files with 30 additions and 7 deletions.
19 changes: 12 additions & 7 deletions torch/_dynamo/eval_frame.py
@@ -137,15 +137,14 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx):

def _initialize(self):
# Do this stuff in constructor to lower overhead slightly
if isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
self._orig_mod.forward
):
# This may be a torch.nn.* instance in trace_rules.py which
# won't trigger a frame evaluation workaround to add an extra
# frame we can capture

if trace_rules.should_wrap_top_module():
self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
else:
# Invoke hooks outside of dynamo then pickup the inner frame
# TODO(export-team) - This part is only run for export. Remove the
# if condition when export has made adjustments to account for
# wrapping top module.
self.forward = self.dynamo_ctx(self._orig_mod.__call__)

if hasattr(self._orig_mod, "_initialize_hook"):
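
For reference, `external_utils.wrap_inline` is what gives Dynamo an extra, non-skipped frame around the top module, so tracing can begin one level above `forward` and capture the module's hooks. A minimal sketch of such a pass-through wrapper (a conceptual illustration, not code taken from this PR):

    import functools

    def wrap_inline(fn):
        # Add one extra Python frame around fn. Dynamo traces from this
        # frame, so the wrapped module's __call__ (and any registered
        # hooks) ends up inside the compiled region.
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            return fn(*args, **kwargs)

        return inner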
@@ -1215,7 +1214,13 @@ def result_capturing_wrapper(*graph_inputs):
automatic_dynamic_shapes=False,
capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True,
):
), trace_rules.dont_wrap_top_module():
# TODO(export-team) - discrepancy between torch.compile and
# torch.export because torch.compile is planning to inline the
# _call_impl (one level above forward) to inline hooks. But doing
# that for export breaks many tests because (1) tests are hardcoded
# to assume that tracing starts from forward, and (2) some
# discrepancies between strict and non strict mode.
opt_f = optimize_assert(
dynamo_normalization_capturing_compiler,
hooks=Hooks(
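The user-visible effect for `torch.compile` is that hooks registered on the top-level module are now traced into the compiled frame rather than firing eagerly outside it, while export keeps the old entry point via `dont_wrap_top_module()` for the reasons in the TODO above. A small illustration of the compile-side behavior (the module, hook, and the "eager" backend choice are illustrative, not part of this PR):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    m = M()
    # Forward hook on the top module; with the top module wrapped via
    # wrap_inline, Dynamo can trace _call_impl and inline this hook.
    m.register_forward_hook(lambda mod, args, out: out * 2)

    opt_m = torch.compile(m, backend="eager")
    print(opt_m(torch.ones(3)))  # tensor([4., 4., 4.])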
18 changes: 18 additions & 0 deletions torch/_dynamo/trace_rules.py
@@ -32,6 +32,7 @@
import unittest
import weakref
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, cast, Dict, List, Optional, Set, Union

np: Optional[types.ModuleType] = None
@@ -127,6 +128,23 @@
"""

_TLS = threading.local()

@contextmanager
def dont_wrap_top_module():
old = getattr(_TLS, "wrap_top_module", True)
_TLS.wrap_top_module = False
try:
yield False
finally:
_TLS.wrap_top_module = old


def should_wrap_top_module():
return getattr(_TLS, "wrap_top_module", True)


manual_torch_name_rule_map = {
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
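A quick usage sketch of the new toggle, assuming the definitions above and the usual `torch._dynamo.trace_rules` import path: the thread-local default keeps wrapping enabled, and the context manager restores the previous value on exit.

    from torch._dynamo import trace_rules

    print(trace_rules.should_wrap_top_module())      # True by default

    with trace_rules.dont_wrap_top_module():
        # Export-style tracing runs here; OptimizedModule._initialize
        # falls back to compiling self._orig_mod.__call__ directly.
        print(trace_rules.should_wrap_top_module())  # False inside the block

    print(trace_rules.should_wrap_top_module())      # True again on exit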
