Skip to content

Commit

Permalink
Check for none for NNModuleVariable.__module__ (#93326)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #93326

Test Plan: CI

Differential Revision: D42869182

fbshipit-source-id: 153788f46f6b504ee9c5d0623510d017f0789ed1
  • Loading branch information
tugsbayasgalan authored and facebook-github-bot committed Jan 31, 2023
1 parent 1a45431 commit ec65a13
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions torch/_dynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,12 @@ def python_type(self):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if (
isinstance(self.obj, variables.NNModuleVariable)
and getattr(self.fn, "__module__", "").startswith("torch.nn.")
or self.is_constant
):
return self.obj.call_method(
tx, self.fn.__name__, args, kwargs, constant=self.is_constant
).add_options(self)
if isinstance(self.obj, variables.NNModuleVariable):
module_attr = getattr(self.fn, "__module__", "")
if module_attr is not None and module_attr.startswith("torch.nn.") or self.is_constant:
return self.obj.call_method(
tx, self.fn.__name__, args, kwargs, constant=self.is_constant
).add_options(self)
return super().call_function(tx, args, kwargs)

def num_parameters(self):
Expand Down

0 comments on commit ec65a13

Please sign in to comment.