Skip to content

Commit

Permalink
Update base for Update on "[wip] Improve exception support"
Browse files Browse the repository at this point in the history
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
  • Loading branch information
anijain2305 committed May 23, 2024
1 parent dc34438 commit e2bc380
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def var_getattr(self, tx, name):
# For attributes (not methods) that were not caught in the special handling above,
# (e.g. tensor.real), we handle these generically, assuming that the output type is
# a tensor.
if result is None and name != "grad":
if result is None and name not in ("grad", "grad_fn"):

def try_generic_attr_handling():
from .builder import wrap_fx_proxy
Expand Down
4 changes: 4 additions & 0 deletions torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,10 @@ def var_getattr(self, tx, name):
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
return tx.output.side_effects.load_attr(self, name)

if name == "__dict__":
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)

try:
subobj = self._getattr_static(name)
except AttributeError:
Expand Down

0 comments on commit e2bc380

Please sign in to comment.