Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx] GraphModule copy top level attributes from root #45182

Closed
wants to merge 6 commits into from
8 changes: 8 additions & 0 deletions test/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ def forward(self, x):
traced = symbolic_trace(baz)
copied = copy.deepcopy(traced)

def test_deepcopy_preserve_attributes(self):
m = symbolic_trace(torch.nn.Conv2d(1, 1, 1))
m.attr = 3
self.assertTrue(hasattr(m, 'attr'))
m = copy.deepcopy(m)
self.assertTrue(hasattr(m, 'attr'))
self.assertTrue(hasattr(m, 'training'))

def test_unpack_list_better_error(self):
class SomeArgs(torch.nn.Module):
def forward(self, a, b):
Expand Down
3 changes: 1 addition & 2 deletions torch/fx/graph_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
"""
super().__init__()
if isinstance(root, torch.nn.Module):
if hasattr(root, 'training'):
self.training = root.training
self.__dict__ = root.__dict__
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved
for node in graph.nodes:
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
Expand Down