Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx] GraphModule copy top level attributes from root #45182

Closed
wants to merge 6 commits into from
11 changes: 11 additions & 0 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,17 @@ def forward(self, x):
traced = symbolic_trace(baz)
copied = copy.deepcopy(traced)

def test_deepcopy_preserve_attributes(self):
m = symbolic_trace(torch.nn.Conv2d(1, 1, 1))
m.attr = 3
self.assertTrue(hasattr(m, 'attr'))
m = copy.deepcopy(m)
self.assertTrue(hasattr(m, 'attr'))
self.assertTrue(hasattr(m, 'training'))
m = copy.copy(m)
self.assertTrue(hasattr(m, 'attr'))
self.assertTrue(hasattr(m, 'training'))

def test_unpack_list_better_error(self):
class SomeArgs(torch.nn.Module):
def forward(self, a, b):
Expand Down
13 changes: 11 additions & 2 deletions torch/fx/graph_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,19 @@ def __reduce__(self):
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return GraphModule(fake_mod, self.graph)
# delete generated attributes before creating a new
# GraphModule
generated_attrs = ['code', '_graph', '_modules']
for attr in generated_attrs:
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved
fake_mod.__dict__.pop(attr, None)
graph_module = GraphModule(fake_mod, self.graph)
graph_module.__dict__.update(fake_mod.__dict__)
return graph_module

def __copy__(self):
return GraphModule(self, self.graph)
graph_module = GraphModule(self, self.graph)
graph_module.__dict__ = self.__dict__
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved
return graph_module

def __str__(self) -> str:
orig_str = super().__str__()
Expand Down