Skip to content

Commit

Permalink
Remote the list for the attributes that will be ignored for pickling (#…
Browse files Browse the repository at this point in the history
…58345)

Summary:
Pull Request resolved: #58345

1. Add a sanity check to make sure any new attribute added to the constructor should be added to either `_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING` pr `_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING`.
2. Update some comments and warning -- now if a new attribute is added after the construction, it will not be pickled. Previously it will trigger a runtime error, which is hard for unit test (one worker hit the runtime error, but the other worker will cause timeout).
Context: #58019 (comment)
ghstack-source-id: 129070358

Test Plan: unit test

Reviewed By: rohan-varma

Differential Revision: D28460744

fbshipit-source-id: 8028186fc447c88fbf2bf57f5c5d321f42ba54ed
  • Loading branch information
Yi Wang authored and facebook-github-bot committed May 15, 2021
1 parent 9def776 commit 2436377
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
25 changes: 20 additions & 5 deletions torch/distributed/nn/api/remote_module.py
Expand Up @@ -47,9 +47,10 @@

_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc]

# These attributes are mostly from RemoteModule's parent class and are not pickled.
# A new attribute of RemoteModule must be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
# or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
# Otherwise, it will not be pickled.
_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
"training",
"_parameters",
Expand Down Expand Up @@ -84,7 +85,10 @@ def _create_module(module_cls, args, kwargs, device):
module.to(device)
return module

def _create_module_with_interface(module_cls, args, kwargs, device, module_interface_cls):

def _create_module_with_interface(
module_cls, args, kwargs, device, module_interface_cls
):
module = _create_module(module_cls, args, kwargs, device)
if module_interface_cls is not None:
module = torch.jit.script(module)
Expand Down Expand Up @@ -270,6 +274,17 @@ def __init__(
method = torch.jit.export(method)
setattr(self, method_name, types.MethodType(method, self))

# Sanity check: whether to be pickled must be explicitly defined for every attribute.
for k in self.__dict__.keys():
if (
k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES
and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
):
raise AttributeError(
"Attribute {} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
"``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.".format(k)
)

def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
"""
Returns a list of :class:`~torch.distributed.rpc.RRef` pointing to the
Expand Down Expand Up @@ -544,8 +559,8 @@ def _remote_module_reducer(remote_module):
elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING: # type: ignore[attr-defined]
print(
"The new attribute ``{}`` of RemoteModule is ignored during RPC pickling. "
"To pickle this attribute, it must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
"``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.".format(k),
"To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
"Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.".format(k),
file=sys.stderr,
)

Expand Down
12 changes: 8 additions & 4 deletions torch/testing/_internal/distributed/nn/api/remote_module_test.py
Expand Up @@ -406,14 +406,18 @@ def hook(module, grad_input, grad_output):
remote_module.extra_repr()

@dist_utils.dist_init
def test_send_remote_module_with_a_new_attribute_ignored_over_the_wire(self):
def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self):
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

# If add a new attribute is added to this RemoteModule, which will be sent over the wire by RPC,
# this new field must be added to either _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
# to avoid runtime error.
# If a new attribute is added to this RemoteModule after the initialization,
# and it will be sent over the wire by RPC,
# this new field will not be pickled, because it's not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES.
# Note that adding a new attribute out of constructor should rarely happen.
# If a new attribute is added to RemoteModule constructor,
# there is a sanity check to enforce developers to add this attribute to either
# _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
for remote_module in self._create_remote_module_iter(
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
):
Expand Down

0 comments on commit 2436377

Please sign in to comment.