Skip to content

Commit

Permalink
[fx] GraphModule copy top level attributes from root
Browse files Browse the repository at this point in the history
Summary:
Previously only the attributes that's used in the graph is copied in the constructor
of GraphModule, this PR adds the support for copying all attributes by overriding `__dict__`

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3466ebdd4cafbca5ed305618307d6fb09b36fd89
Pull Request resolved: #45182
  • Loading branch information
jerryzh168 committed Sep 23, 2020
1 parent 2a37f3f commit 09a7758
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ def forward(self, x):
traced = symbolic_trace(baz)
copied = copy.deepcopy(traced)

def test_deepcopy_preserve_attributes(self):
m = symbolic_trace(torch.nn.Conv2d(1, 1, 1))
m.attr = 3
self.assertTrue(hasattr(m, 'attr'))
m = copy.deepcopy(m)
self.assertTrue(hasattr(m, 'attr'))
self.assertTrue(hasattr(m, 'training'))

def test_unpack_list_better_error(self):
class SomeArgs(torch.nn.Module):
def forward(self, a, b):
Expand Down
4 changes: 3 additions & 1 deletion torch/fx/graph_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def __reduce__(self):
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return GraphModule(fake_mod, self.graph)
graph_module = GraphModule(fake_mod, self.graph)
graph_module.__dict__ = fake_mod.__dict__
return graph_module

def __copy__(self):
return GraphModule(self, self.graph)
Expand Down

0 comments on commit 09a7758

Please sign in to comment.