Skip to content

Commit

Permalink
Revert "[dynamo][disable] Move disable impl to its own __call__ method (
Browse files Browse the repository at this point in the history
#125486)"

This reverts commit d474d79.

Reverted #125486 on behalf of https://github.com/izaitsevfb due to Fails internal tests, see D57216402 ([comment](#125486 (comment)))
  • Loading branch information
pytorchmergebot committed May 11, 2024
1 parent a174c53 commit d547074
Showing 1 changed file with 17 additions and 55 deletions.
72 changes: 17 additions & 55 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,7 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx):

def _initialize(self):
# Do this stuff in constructor to lower overhead slightly
if isinstance(self.dynamo_ctx, DisableContext):
# No need to check trace rules
self.forward = self.dynamo_ctx(self._orig_mod.__call__)
elif isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
if isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
self._orig_mod.forward
):
# This may be a torch.nn.* instance in trace_rules.py which
Expand Down Expand Up @@ -356,9 +353,14 @@ def get_compiler_config():
# User has wrapped the class with compile/disable decorator. Apply
# disable to init/call method.
cls_obj = fn
if isinstance(self, DisableContext):
# Disable on init is useful for reconstruction of bytecodes where we
# want to prevent Dynamo from tracing into the init function. Check
# test_reconstruction in test_model_output.py.
cls_obj.__init__ = self(cls_obj.__init__)
cls_obj.__call__ = self(cls_obj.__call__)
if issubclass(cls_obj, torch.nn.Module):
# NN module variable tracker directly inlines the _call_impl.
# NN module variable tracker directly inlines the _call_impl. Disable it.
cls_obj._call_impl = self(cls_obj._call_impl)
return cls_obj

Expand All @@ -381,8 +383,12 @@ def get_compiler_config():

callback = self.callback

is_jit_tracing = torch._C._is_tracing
is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
if isinstance(self, DisableContext):
is_jit_tracing = always_false
is_fx_tracing = always_false
else:
is_jit_tracing = torch._C._is_tracing
is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing

@functools.wraps(fn)
def _fn(*args, **kwargs):
Expand Down Expand Up @@ -418,7 +424,10 @@ def _fn(*args, **kwargs):
cleanup()

# hooks to properly handle inlining
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]
if isinstance(self, DisableContext):
_fn._torchdynamo_disable = True # type: ignore[attr-defined]
else:
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]

# Save the function pointer to find the original callable while nesting
# of decorators.
Expand Down Expand Up @@ -510,53 +519,6 @@ class DisableContext(_TorchDynamoContext):
def __init__(self):
super().__init__(callback=None)

def __call__(self, fn):
# Earlier this code was in the base class _TorchDynamoContext. But we
# moved it here to have better code organization. For disable, we just
# want the callback to be None. We don't have to check trace_rules or
# create any wrapper.
fn = innermost_fn(fn)

if isinstance(fn, torch.nn.Module):
mod = fn
new_mod = OptimizedModule(mod, self)
new_mod._torchdynamo_orig_callable = mod.forward
return new_mod

if inspect.isclass(fn):
# User has wrapped the class with compile/disable decorator. Apply
# disable to init/call method.
cls_obj = fn
# Disable on init is useful for reconstruction of bytecodes where we
# want to prevent Dynamo from tracing into the init function. Check
# test_reconstruction in test_model_output.py.
cls_obj.__init__ = self(cls_obj.__init__)
cls_obj.__call__ = self(cls_obj.__call__)
if issubclass(cls_obj, torch.nn.Module):
# NN module variable tracker directly inlines the _call_impl. Disable it.
cls_obj._call_impl = self(cls_obj._call_impl)
return cls_obj

assert callable(fn)

callback = self.callback

@functools.wraps(fn)
def _fn(*args, **kwargs):
prior = set_eval_frame(callback)
try:
return fn(*args, **kwargs)
finally:
set_eval_frame(prior)

_fn._torchdynamo_disable = True # type: ignore[attr-defined]

# Save the function pointer to find the original callable while nesting
# of decorators.
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]

return _fn


def _optimize_catch_errors(
compile_fn,
Expand Down

0 comments on commit d547074

Please sign in to comment.