Skip to content

Commit

Permalink
[dynamo][compile-time][inlining-inbuilt-nn-modules] Manually implement nn.Module._call_impl (#129285)
Browse files · Browse the repository at this point in the history

# Compile time for eager backend
## AlbertForMaskedLM
No inlining - 3.65 seconds
Inlining on main - 7.48 seconds
Inlining + this PR - 2.86 seconds

## MobileBertForMaskedLM
No inlining - 26.90 seconds
Inlining on main - 48.21 seconds
Inlining + this PR - 24.25 seconds

Pull Request resolved: #129285
Approved by: https://github.com/jansel
ghstack dependencies: #129316, #129315
  • Loading branch information
anijain2305 authored and pytorchmergebot committed Jun 25, 2024
1 parent 514f927 commit c4dd752
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
7 changes: 7 additions & 0 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,13 @@ def _load_global(self, inst):
source = GlobalSource(name)
self.push(VariableBuilder(self, source)(value))

@functools.cached_property
def nn_modules_globals_vt(self):
    """VariableTracker for the globals of ``torch.nn.modules.module``.

    Built lazily and cached for the lifetime of this translator, so the
    module-globals VT (used e.g. to inspect the global hook dicts) is
    constructed at most once per compile.
    """
    name = "torch.nn.modules.module"
    source = self.import_source(name)
    module_obj = importlib.import_module(name)  # type: ignore[assignment]
    return VariableBuilder(self, source)(module_obj)

def LOAD_GLOBAL(self, inst):
if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
self.PUSH_NULL(inst)
Expand Down
9 changes: 9 additions & 0 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def __contains__(self, vt):
and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
)

def len(self):
    """Return the number of live entries, skipping tombstoned (deleted) values."""
    live = 0
    for entry in self.items.values():
        # DeletedVariable marks a key removed during tracing; it must not count.
        if not isinstance(entry, variables.DeletedVariable):
            live += 1
    return live

def reconstruct(self, codegen):
# instructions to load collections.OrderedDict if necessary
if self.user_cls is collections.OrderedDict:
Expand Down
14 changes: 14 additions & 0 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,20 @@ def call_hasattr(self, tx, name):
return variables.ConstantVariable.create(result)
return super().call_hasattr(tx, name)

def var_getattr(self, tx, name):
    """Resolve attribute ``name`` on the wrapped value as a VariableTracker.

    Pending attribute mutations recorded in side_effects take precedence
    over the real object's state; otherwise the attribute is read from the
    underlying value and wrapped (with a source when one is available, so
    guards can be installed).
    """
    side_effects = tx.output.side_effects
    if side_effects.has_pending_mutation_of_attr(self, name):
        return side_effects.load_attr(self, name)

    from .builder import SourcelessBuilder, VariableBuilder

    attr_value = getattr(self.value, name)

    if not self.source:
        # No source to derive from: build an unguarded (sourceless) VT.
        return SourcelessBuilder.create(tx, attr_value)
    return VariableBuilder(tx, AttrSource(self.source, name))(attr_value)


class TypingVariable(VariableTracker):
def __init__(self, value, **kwargs):
Expand Down
20 changes: 20 additions & 0 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,26 @@ def call_function(
initialize_lazy_module(tx, mod, args, kwargs)
name = "_call_impl"
fn = getattr(self.value_type, name)

# Check if we can short circuit nn.Module._call_impl to the forward
# method. NB - This is done to reduce the compile time of Dynamo.
if fn is torch.nn.Module._call_impl and "forward" not in mod.__dict__:
forward_method = inspect.getattr_static(mod, "forward")
if isinstance(forward_method, types.FunctionType):
globals_vt = tx.nn_modules_globals_vt
if not (
self.var_getattr(tx, "_backward_hooks").realize().len()
or self.var_getattr(tx, "_backward_pre_hooks").realize().len()
or self.var_getattr(tx, "_forward_hooks").realize().len()
or self.var_getattr(tx, "_forward_pre_hooks").realize().len()
or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len()
or globals_vt.var_getattr(tx, "_global_backward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len()
):
name = "forward"
fn = self.value_type.forward

if self.source:
source = AttrSource(AttrSource(self.source, "__class__"), name)
else:
Expand Down

0 comments on commit c4dd752

Please sign in to comment.