Skip to content

Commit

Permalink
Update on "[dynamo][compile-time][inlining-inbuilt-nn-modules] Manual…
Browse files Browse the repository at this point in the history
…ly implement nn.Module._call_impl"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
  • Loading branch information
anijain2305 committed Jun 22, 2024
2 parents 8ff70e6 + 24f3efe commit e435386
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 34 deletions.
7 changes: 7 additions & 0 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,13 @@ def _load_global(self, inst):
source = GlobalSource(name)
self.push(VariableBuilder(self, source)(value))

@functools.cached_property
def nn_modules_globals_vt(self):
module_name = "torch.nn.modules.module"
module_source = self.import_source(module_name)
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
return VariableBuilder(self, module_source)(fglobals_value)

def LOAD_GLOBAL(self, inst):
if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
self.PUSH_NULL(inst)
Expand Down
50 changes: 16 additions & 34 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,42 +838,24 @@ def call_function(
name = "_call_impl"
fn = getattr(self.value_type, name)

from torch.nn.utils.parametrize import is_parametrized

if (
not tx.output.side_effects.has_pending_mutation(self)
and "forward" not in mod.__dict__
and not is_parametrized(mod)
and fn is torch.nn.Module._call_impl
):
# Check if we can short circuit nn.Module._call_impl to the forward
# method. NB - This is done to reduce the compile time of Dynamo.
if fn is torch.nn.Module._call_impl and "forward" not in mod.__dict__:
forward_method = inspect.getattr_static(mod, "forward")
if isinstance(forward_method, types.FunctionType):
# TODO(anijain2305) - Load this once - should give ~3 seconds on MobileBert
import importlib

module_name = "torch.nn.modules.module"
module_source = tx.import_source(module_name)
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
from .builder import VariableBuilder

globals_vt = VariableBuilder(tx, module_source)(fglobals_value)

if not tx.output.side_effects.has_pending_mutation(globals_vt):
# Check if we can take the fast path of directly tracing forward.
if not (
self.var_getattr(tx, "_backward_hooks").realize().len()
or self.var_getattr(tx, "_backward_pre_hooks").realize().len()
or self.var_getattr(tx, "_forward_hooks").realize().len()
or self.var_getattr(tx, "_forward_pre_hooks").realize().len()
or globals_vt.var_getattr(
tx, "_global_backward_pre_hooks"
).len()
or globals_vt.var_getattr(tx, "_global_backward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len()
):
name = "forward"
fn = self.value_type.forward
globals_vt = tx.nn_modules_globals_vt
if not (
self.var_getattr(tx, "_backward_hooks").realize().len()
or self.var_getattr(tx, "_backward_pre_hooks").realize().len()
or self.var_getattr(tx, "_forward_hooks").realize().len()
or self.var_getattr(tx, "_forward_pre_hooks").realize().len()
or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len()
or globals_vt.var_getattr(tx, "_global_backward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len()
):
name = "forward"
fn = self.value_type.forward

if self.source:
source = AttrSource(AttrSource(self.source, "__class__"), name)
Expand Down

0 comments on commit e435386

Please sign in to comment.