Skip to content

Commit

Permalink
[dynamo] Support __getitem__ on NNModuleVariable __dict__
Browse files Browse the repository at this point in the history
ghstack-source-id: 708f0d79f6e60fd2f0754267b9c6fa2d5358d106
Pull Request resolved: #126956
  • Loading branch information
anijain2305 committed May 30, 2024
1 parent 7a994c9 commit 9b0e8a7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,13 @@ def call_method(
and self.name == "__dict__"
and not kwargs
and args[0].is_python_constant()
and isinstance(self.obj, variables.UserDefinedObjectVariable)
and isinstance(
self.obj,
(variables.UserDefinedObjectVariable, variables.NNModuleVariable),
)
):
obj = self.obj
key = args[0].as_python_constant()
obj._check_for_getattribute()
if obj.has_key_in_generic_dict(tx, key):
# redirect to var_getattr on the original obj
return obj.var_getattr(tx, key)
Expand All @@ -701,11 +703,13 @@ def call_method(
and len(args) == 1
and args[0].is_python_constant()
and not kwargs
and isinstance(self.obj, variables.UserDefinedObjectVariable)
and isinstance(
self.obj,
(variables.UserDefinedObjectVariable, variables.NNModuleVariable),
)
):
obj = self.obj
key = args[0].as_python_constant()
obj._check_for_getattribute()
if obj.has_key_in_generic_dict(tx, key):
return variables.ConstantVariable(True)
else:
Expand Down
16 changes: 16 additions & 0 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,19 @@ def convert_to_unspecialized(self, tx):
GenerationTracker.mark_class_dynamic(type(mod))
raise UnspecializeRestartAnalysis

def has_key_in_generic_dict(self, tx, key):
base = tx.output.get_submodule(self.module_key)

if object_has_getattribute(base):
unimplemented("NNModuleVariable with custom __getattribute__")

if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
return not isinstance(mutated_attr, variables.DeletedVariable)

base_dict = object.__getattribute__(base, "__dict__")
return key in base_dict

def _custom_getattr_fallback(self, base, tx, name, options):
"""Check for a __getattr__ and handle it specially if it is implemented"""
if object_has_getattribute(base):
Expand Down Expand Up @@ -223,6 +236,9 @@ def var_getattr(self, tx, name):
if not self.source:
unimplemented("GETATTR with no source")

if name == "__dict__":
return variables.GetAttrVariable(self, name, source=source)

if name in base_dict:
subobj = base_dict[name]
elif (
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ def _getattr_static(self, name):
return subobj

def has_key_in_generic_dict(self, tx, key):
self._check_for_getattribute()
if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
return not isinstance(mutated_attr, variables.DeletedVariable)
Expand Down

0 comments on commit 9b0e8a7

Please sign in to comment.