[fx] GraphModule copy top level attributes from root #45182

Closed
wants to merge 6 commits into from
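In short, after this change copy.copy and copy.deepcopy of a GraphModule carry over user-set top-level attributes from the original module instead of dropping them. A minimal usage sketch, mirroring the new test below (it only uses the torch.fx symbolic_trace API already exercised in this PR):

    import copy
    import torch
    from torch.fx import symbolic_trace

    gm = symbolic_trace(torch.nn.Conv2d(1, 1, 1))
    gm.attr = 3                         # user-defined attribute on the traced GraphModule
    copied = copy.deepcopy(gm)
    assert hasattr(copied, 'attr')      # preserved with this change
    assert hasattr(copied, 'training')  # nn.Module state is still present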
11 changes: 11 additions & 0 deletions test/test_fx.py
@@ -502,6 +502,17 @@ def forward(self, x):
         traced = symbolic_trace(baz)
         copied = copy.deepcopy(traced)

+    def test_deepcopy_preserve_attributes(self):
+        m = symbolic_trace(torch.nn.Conv2d(1, 1, 1))
+        m.attr = 3
+        self.assertTrue(hasattr(m, 'attr'))
+        m = copy.deepcopy(m)
+        self.assertTrue(hasattr(m, 'attr'))
+        self.assertTrue(hasattr(m, 'training'))
+        m = copy.copy(m)
+        self.assertTrue(hasattr(m, 'attr'))
+        self.assertTrue(hasattr(m, 'training'))
+
     def test_unpack_list_better_error(self):
         class SomeArgs(torch.nn.Module):
             def forward(self, a, b):
14 changes: 12 additions & 2 deletions torch/fx/graph_module.py
@@ -194,10 +194,20 @@ def __reduce__(self):
     def __deepcopy__(self, memo):
         fake_mod = torch.nn.Module()
         fake_mod.__dict__ = copy.deepcopy(self.__dict__)
-        return GraphModule(fake_mod, self.graph)
+        graph_module = GraphModule(fake_mod, self.graph)
+        # skip overwriting generated attributes
+        for attr in ['code', '_graph', '_modules']:
+            fake_mod.__dict__.pop(attr, None)
+        graph_module.__dict__.update(fake_mod.__dict__)
+        return graph_module

     def __copy__(self):
-        return GraphModule(self, self.graph)
+        graph_module = GraphModule(self, self.graph)
+        # skip overwriting generated attributes
+        for attr in self.__dict__:
+            if attr not in ['code', '_graph', '_modules']:
+                graph_module.__dict__[attr] = self.__dict__[attr]
+        return graph_module

     def __str__(self) -> str:
         orig_str = super().__str__()

Inline review comments on the added skip list (for attr in ['code', '_graph', '_modules']):

Contributor:
This list of fields is not complete. For instance, _parameters and _buffers are not included, and copying them over could break the nn.Module in the same way that copying _modules would. Enumerating these names here is also very fragile: changes to nn.Module would change what the list needs to contain, and there is no reliable way to tell someone editing nn.Module that this list must be updated as well.

Contributor (Author):
Actually, I'm not sure why I included _modules; we probably don't need it. The list is only meant to skip attributes that are generated when a GraphModule is constructed.
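Not part of this PR, but one way to sidestep a hard-coded name list, as the reviewer suggests, would be to treat every key the freshly constructed GraphModule already has as "generated" and copy over only the remaining, user-defined keys. A rough sketch under that assumption (the helper name _copy_user_attributes is hypothetical):

    import torch
    from torch.fx import GraphModule

    def _copy_user_attributes(src: torch.nn.Module, dst: GraphModule) -> GraphModule:
        # Anything GraphModule/nn.Module construction already put on dst
        # (code, _graph, _modules, _parameters, _buffers, training, ...) is
        # treated as generated; only the remaining keys are copied from src.
        generated = set(dst.__dict__)
        for name, value in src.__dict__.items():
            if name not in generated:
                dst.__dict__[name] = value
        return dst

The trade-off is that such an approach never overwrites anything the constructor set, whereas the __deepcopy__ in this PR deliberately re-applies the deep-copied state from fake_mod for everything outside the skip list.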