[dynamo][inline inbuilt nn modules] Do not inline for export (#124814)
Pull Request resolved: #124814
Approved by: https://github.com/jansel
anijain2305 authored and pytorchmergebot committed Apr 25, 2024
1 parent 94af62b commit 59a1f1f
Showing 3 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions torch/_dynamo/mutation_guard.py
@@ -84,7 +84,7 @@ def check(cls, obj):
     )
 
 
-def is_dynamic_nn_module(obj):
+def is_dynamic_nn_module(obj, is_export):
     """Check for nn.Modules() created dynamically or mutated"""
     if isinstance(obj, torch.nn.Module) and "forward" in obj.__dict__:
         # A monkey patched `.forward` indicates something wacky is going on
@@ -93,7 +93,11 @@ def is_dynamic_nn_module(obj):
         return obj.torchdynamo_force_dynamic
     if is_lazy_module(obj):
         return False
-    if config.inline_inbuilt_nn_modules:
+    # For export, we will have to fix
+    # 1) Input signature problem because params are lifted as inputs
+    # 2) nn module stack info changes
+    # 3) adjust failing tests
+    if config.inline_inbuilt_nn_modules and not is_export:
         return True
     dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check(
         obj
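For context (not part of the commit), a minimal runnable reduction of the new gate. config.inline_inbuilt_nn_modules is the real Dynamo flag; the standalone gate() function below is illustrative only, since the actual check also consults GenerationTracker and per-object overrides.

import torch
import torch._dynamo.config as dynamo_config

def gate(is_export: bool) -> bool:
    # Reduction of the new condition in is_dynamic_nn_module: inbuilt
    # nn.Modules are treated as dynamic (and hence inlined) only when
    # the flag is on AND we are not exporting.
    return dynamo_config.inline_inbuilt_nn_modules and not is_export

dynamo_config.inline_inbuilt_nn_modules = True
assert gate(is_export=False) is True   # torch.compile: inline the module
assert gate(is_export=True) is False   # export: fall through to the old path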
2 changes: 1 addition & 1 deletion torch/_dynamo/output_graph.py
@@ -741,7 +741,7 @@ def register_attr_or_module(
         *names,
         **options,
     ):
-        if is_dynamic_nn_module(target):
+        if is_dynamic_nn_module(target, self.root_tx.export):
             return variables.UnspecializedNNModuleVariable(target, **options)
 
         options = dict(options)
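A usage-level sketch of the behavioral difference this call-site change preserves, assuming a PyTorch build that includes this commit (torch.compile, torch.export.export, and ExportedProgram.graph_signature are real APIs; the module M is illustrative):

import torch
import torch._dynamo.config as dynamo_config
import torch.export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

x = torch.randn(2, 4)

# Compile path: with the flag on, built-in modules like self.lin are
# treated as dynamic and traced as UnspecializedNNModuleVariable.
dynamo_config.inline_inbuilt_nn_modules = True
torch.compile(M())(x)

# Export path: the gate added in this commit keeps the old specialized
# module handling, so the exported input signature and nn_module_stack
# metadata are unchanged.
ep = torch.export.export(M(), (x,))
print(ep.graph_signature)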
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/builder.py
@@ -992,7 +992,7 @@ def wrap_module(self, value: torch.nn.Module):
             and not config.allow_rnn
         ):
             unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
-        if mutation_guard.is_dynamic_nn_module(value):
+        if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
            # created dynamically, don't specialize on it
             self.install_guards(GuardBuilder.TYPE_MATCH)
             result = UnspecializedNNModuleVariable(value, source=self.source)
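A hypothetical reduction of the builder-side design choice (GuardBuilder.TYPE_MATCH appears in the diff above; the NN_MODULE name for the specialized path is an assumption on my part, not confirmed by this diff):

from enum import Enum, auto

class Guard(Enum):
    TYPE_MATCH = auto()  # guard only on the module's class
    NN_MODULE = auto()   # guard on the exact module instance (specialize)

def module_guard(is_dynamic: bool) -> Guard:
    # Dynamic modules get the weaker TYPE_MATCH guard so Dynamo does not
    # recompile for every new instance of the same class; specialized
    # modules are guarded by identity instead.
    return Guard.TYPE_MATCH if is_dynamic else Guard.NN_MODULE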
