Skip to content

Commit

Permalink
[fx] GraphModule copy top level attributes from root
Browse files Browse the repository at this point in the history
Summary:
Previously, only the attributes that are used in the graph were copied in the constructor
of GraphModule. This PR adds support for copying all attributes by overriding `__dict__`.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c9a0c7f570b44e99a83e46ccdc5e90139966ad70
Pull Request resolved: #45182
  • Loading branch information
jerryzh168 committed Sep 23, 2020
1 parent 2a37f3f commit 48291de
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 8 additions & 0 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ def forward(self, x):
traced = symbolic_trace(baz)
copied = copy.deepcopy(traced)

def test_deepcopy_preserve_attributes(self):
    """Attributes set on a GraphModule after tracing survive copy.deepcopy.

    Regression test for GraphModule copying all of the root module's
    attributes (via `__dict__`), not only those referenced by the graph.
    """
    m = symbolic_trace(torch.nn.Conv2d(1, 1, 1))
    # Attach an ad-hoc attribute that the traced graph itself never uses.
    m.attr = 3
    self.assertTrue(hasattr(m, 'attr'))
    m = copy.deepcopy(m)
    # After deepcopy, both the ad-hoc attribute and standard nn.Module
    # state (e.g. `training`) must still be present.
    self.assertTrue(hasattr(m, 'attr'))
    self.assertTrue(hasattr(m, 'training'))

def test_unpack_list_better_error(self):
class SomeArgs(torch.nn.Module):
def forward(self, a, b):
Expand Down
3 changes: 1 addition & 2 deletions torch/fx/graph_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
"""
super().__init__()
if isinstance(root, torch.nn.Module):
if hasattr(root, 'training'):
self.training = root.training
self.__dict__ = root.__dict__
for node in graph.nodes:
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
Expand Down

0 comments on commit 48291de

Please sign in to comment.