[RPC Framework] Support remote device format "<workername>/<device>" #46773

Closed · wants to merge 3 commits
torch/distributed/nn/api/remote_module.py (38 changes: 21 additions & 17 deletions)
@@ -64,8 +64,7 @@ def _raise_not_supported(name):
class _RemoteModule(nn.Module):
def __init__(
self,
on: str,
device: torch.device,
remote_device: str,
module_cls: nn.Module,
args: Tuple = None,
kwargs: Dict[str, Any] = None,
@@ -100,8 +99,9 @@ def __init__(
``def forward_async(input: Tensor) -> Future[Tensor]:``.

Arguments:
on (str or WorkerInfo): id or name of the destination worker.
device (torch.device): Device on the destination worker where we'd like to place this module.
remote_device (str): Device on the destination worker where we'd like to place this module.
The format should be "<workername>/<device>", where the device field can be parsed as a torch.device type.
E.g., "trainer0/cpu", "ps0/cuda:0".
module_cls (nn.Module): For example,
>>> class MyModule(nn.Module):
>>> def forward(input):
@@ -132,7 +132,7 @@ def __init__(
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>> "worker1", "cpu", nn.Linear, args=(20, 30),
>>> "worker1/cpu", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
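As an illustration of the new argument format (a minimal sketch, not part of this diff): a "<workername>/<device>" string can be split into its worker and device halves, assuming exactly one "/" separator; the helper name parse_remote_device is hypothetical.

import torch

def parse_remote_device(remote_device: str):
    # "trainer0/cuda:0" -> ("trainer0", device(type='cuda', index=0)).
    # Assumes exactly one "/" separates the worker name from the device string.
    worker_name, device_str = remote_device.split("/")
    return worker_name, torch.device(device_str)

print(parse_remote_device("ps0/cuda:0"))    # ('ps0', device(type='cuda', index=0))
print(parse_remote_device("trainer0/cpu"))  # ('trainer0', device(type='cpu'))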
@@ -155,18 +155,22 @@ def __init__(
args = args if args is not None else ()
kwargs = kwargs if kwargs is not None else {}

self.on = on
[self.on, self.device] = remote_device.split("/")

if _module_interface_cls is not None:
# Users rely on this field to know if this generated RemoteModule is TorchScript-able.
self.is_scriptable = True

# Instantiate template on remote side.
fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,))
fut = rpc.rpc_async(
self.on, _instantiate_template, (_module_interface_cls,)
)

# Instantiate template on local side.
generated_module = instantiator.instantiate_scriptable_remote_module_template(
_module_interface_cls
generated_module = (
instantiator.instantiate_scriptable_remote_module_template(
_module_interface_cls
)
)
generated_methods = generated_module._generated_methods

@@ -178,9 +182,9 @@ def __init__(

# Create the module on the remote side.
self.module_rref = rpc.rpc_sync(
on,
self.on,
_create_module,
(module_cls, args, kwargs, device, _module_interface_cls),
(module_cls, args, kwargs, self.device, _module_interface_cls),
)

# Install generated methods.
@@ -329,8 +333,9 @@ class RemoteModule(_RemoteModule):
``def forward_async(input: Tensor) -> Future[Tensor]:``.

Arguments:
to (str or WorkerInfo): id or name of the destination worker.
device (torch.device): Device on the destination worker where we'd like to place this module.
remote_device (str): Device on the destination worker where we'd like to place this module.
The format should be "<workername>/<device>", where the device field can be parsed as a torch.device type.
E.g., "trainer0/cpu", "ps0/cuda:0".
module_cls (nn.Module): For example,
>>> class MyModule(nn.Module):
>>> def forward(input):
@@ -357,7 +362,7 @@ class RemoteModule(_RemoteModule):
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>> "worker1", nn.Linear, args=(20, 30),
>>> "worker1/cpu", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
@@ -374,10 +379,9 @@ class RemoteModule(_RemoteModule):

def __init__(
self,
on: str,
device: torch.device,
remote_device: str,
module_cls: nn.Module,
args: Tuple = None,
kwargs: Dict[str, Any] = None,
):
super().__init__(on, device, module_cls, args, kwargs)
super().__init__(remote_device, module_cls, args, kwargs)
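For reference, a minimal end-to-end sketch of the constructor under the new signature (not part of the diff), mirroring the docstring example above; it assumes a two-worker RPC group and that RemoteModule is imported from torch.distributed.nn.api.remote_module as modified here.

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.nn.api.remote_module import RemoteModule

# Runs on the caller, "worker0"; "worker1" is assumed to have joined the same RPC group.
rpc.init_rpc("worker0", rank=0, world_size=2)

# Worker name and target device are now packed into a single "<workername>/<device>" string.
remote_linear = RemoteModule("worker1/cpu", nn.Linear, args=(20, 30))

ret = remote_linear.forward_async(torch.randn(128, 20)).wait()
print(ret.shape)  # torch.Size([128, 30])

rpc.shutdown()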
@@ -20,7 +20,7 @@
skip_if_lt_x_gpu,
skip_if_rocm,
)
from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
@@ -620,7 +620,7 @@ def test_ddp_dist_autograd_local_vs_remote(self):
)

remote_layer1 = RemoteModule(
"worker0", device="cpu", module_cls=nn.Linear, args=(10, 5, False)
remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 5, False)
)
layer1 = nn.Linear(10, 5, False)
# Start with the same parameters for remote and local
@@ -667,7 +667,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self):
)

remote_layer1 = RemoteModule(
"worker0", device="cpu", module_cls=nn.Linear, args=(10, 7, False)
remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
)
layer1 = nn.Linear(10, 7, False)
# Start with the same parameters for remote and local
@@ -677,7 +677,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self):
ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])

remote_layer3 = RemoteModule(
"worker0", device="cpu", module_cls=nn.Linear, args=(5, 3, False)
remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
)
layer3 = nn.Linear(5, 3, False)
# Start with the same parameters for remote and local
torch/testing/_internal/distributed/nn/api/remote_module_test.py (11 changes: 6 additions & 5 deletions)
@@ -84,17 +84,17 @@ def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None):

args = (1,)
kwargs = dict(first_kwarg=2)
remote_device = "{}/{}".format(dst_worker_name, device)

if ModuleCreationMode.MODULE_CTOR in modes:
remote_module = RemoteModule(
dst_worker_name, device, MyModule, args, kwargs
remote_device, MyModule, args, kwargs
)
yield remote_module

if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
remote_module = _RemoteModule(
dst_worker_name,
device,
remote_device,
create_scripted_module,
args,
kwargs,
@@ -108,20 +108,21 @@ def test_bad_module(self):
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
remote_device = "{}/cpu".format(dst_worker_name)
args = (1,)
kwargs = dict(first_kwarg=2)

with self.assertRaisesRegex(
ValueError,
r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
):
RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs)
RemoteModule(remote_device, BadModule, args, kwargs)

with self.assertRaisesRegex(
ValueError,
r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
):
RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs)
RemoteModule(remote_device, BadModule, args, kwargs)

@dist_utils.dist_init
def test_forward_async(self):