[RPC Framework] Support remote device format "<workername>/<device>" #46773

Closed
wants to merge 3 commits
41 changes: 24 additions & 17 deletions torch/distributed/nn/api/remote_module.py
@@ -17,6 +17,7 @@
import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch.distributed.rpc.utils import _parse_remote_device
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle

@@ -64,8 +65,7 @@ def _raise_not_supported(name):
class _RemoteModule(nn.Module):
def __init__(
self,
on: str,
device: torch.device,
remote_device: str,
module_cls: nn.Module,
args: Tuple = None,
kwargs: Dict[str, Any] = None,
@@ -100,8 +100,10 @@ def __init__(
``def forward_async(input: Tensor) -> Future[Tensor]:``.

Arguments:
on (str or WorkerInfo): id or name of the destination worker.
device (torch.device): Device on the destination worker where we'd like to place this module.
remote_device (str): Device on the destination worker where we'd like to place this module.
The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
E.g., "trainer0/cpu", "ps0/cuda:0".
In addition, the device field is optional, and the default value is "cpu".
module_cls (nn.Module): For example,
>>> class MyModule(nn.Module):
>>> def forward(input):
@@ -132,7 +134,7 @@ def __init__(
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>> "worker1", "cpu", nn.Linear, args=(20, 30),
>>> "worker1/cpu", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
@@ -155,18 +157,22 @@ def __init__(
args = args if args is not None else ()
kwargs = kwargs if kwargs is not None else {}

self.on = on
self.on, self.device = _parse_remote_device(remote_device)

if _module_interface_cls is not None:
# Users rely on this field to know if this generated RemoteModule is TorchScript-able.
self.is_scriptable = True

# Instantiate template on remote side.
fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,))
fut = rpc.rpc_async(
self.on, _instantiate_template, (_module_interface_cls,)
)

# Instantiate template on local side.
generated_module = instantiator.instantiate_scriptable_remote_module_template(
_module_interface_cls
generated_module = (
instantiator.instantiate_scriptable_remote_module_template(
_module_interface_cls
)
)
generated_methods = generated_module._generated_methods

@@ -178,9 +184,9 @@

# Create the module on the remote side.
self.module_rref = rpc.rpc_sync(
on,
self.on,
_create_module,
(module_cls, args, kwargs, device, _module_interface_cls),
(module_cls, args, kwargs, self.device, _module_interface_cls),
)

# Install generated methods.
@@ -329,8 +335,10 @@ class RemoteModule(_RemoteModule):
``def forward_async(input: Tensor) -> Future[Tensor]:``.

Arguments:
to (str or WorkerInfo): id or name of the destination worker.
device (torch.device): Device on the destination worker where we'd like to place this module.
remote_device (str): Device on the destination worker where we'd like to place this module.
The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
E.g., "trainer0/cpu", "ps0/cuda:0".
In addition, the device field is optional, and the default value is "cpu".
module_cls (nn.Module): For example,
>>> class MyModule(nn.Module):
>>> def forward(input):
@@ -357,7 +365,7 @@ class RemoteModule(_RemoteModule):
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>> "worker1", nn.Linear, args=(20, 30),
>>> "worker1/cpu", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
@@ -374,10 +382,9 @@ class RemoteModule(_RemoteModule):

def __init__(
self,
on: str,
device: torch.device,
remote_device: str,
module_cls: nn.Module,
args: Tuple = None,
kwargs: Dict[str, Any] = None,
):
super().__init__(on, device, module_cls, args, kwargs)
super().__init__(remote_device, module_cls, args, kwargs)
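
For reference, a minimal sketch of the two accepted remote_device forms after this change (illustrative only: it assumes rpc.init_rpc has already been called on this process and that a worker named "worker1" exists):

>>> import torch
>>> from torch import nn
>>> from torch.distributed.nn.api.remote_module import RemoteModule
>>>
>>> # Explicit device field in "<workername>/<device>" format.
>>> explicit = RemoteModule("worker1/cpu", nn.Linear, args=(20, 30))
>>> # Device field omitted; it should default to "cpu".
>>> implicit = RemoteModule("worker1", nn.Linear, args=(20, 30))
>>>
>>> out = explicit.forward(torch.randn(128, 20))
>>> fut = implicit.forward_async(torch.randn(128, 20))
>>> print(out.shape, fut.wait().shape)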
37 changes: 37 additions & 0 deletions torch/distributed/rpc/utils.py
@@ -0,0 +1,37 @@
def _parse_remote_device(remote_device: str):
r"""
Parses the remote device.

Arguments:
remote_device (str): Device on the destination worker where we'd like to place this module.
The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
E.g., "trainer0/cpu", "ps0/cuda:0".
In addition, the device field is optional, and the default value is "cpu".

Returns:
A workername and a device.
"""
fields = remote_device.split("/")
if len(fields) == 2:
[on, device] = fields
elif len(fields) == 1:
on = fields[0]
device = "cpu"
else:
raise RuntimeError(
"Could not parse remote_device: {}. The valid format is '<workername>/<device>'".format(
remote_device
)
)

# Since the workername in the input remote device won't be validated until the created remote module is executed,
# only do a very basic sanity check on the workername at module creation time.
# As there is currently no regex describing the format of a workername, just check whether the workername is empty.
if not on:
raise RuntimeError(
"The workername in remote_device '{}' cannot be empty. The valid format is '<workername>/<device>'".format(
remote_device
)
)

return on, device
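
As a quick doctest-style sketch of the parsing behavior above (the expected outputs follow directly from the helper as written; note that the device string itself is not validated here, only when torch.device is constructed on the remote side):

>>> from torch.distributed.rpc.utils import _parse_remote_device
>>> _parse_remote_device("trainer0/cpu")
('trainer0', 'cpu')
>>> _parse_remote_device("ps0/cuda:0")
('ps0', 'cuda:0')
>>> _parse_remote_device("ps0")  # device field omitted, defaults to "cpu"
('ps0', 'cpu')
>>> _parse_remote_device("worker0/cuda:0/cuda:1")  # raises RuntimeError: could not parse remote_device
>>> _parse_remote_device("/cuda:0")  # raises RuntimeError: workername cannot be empty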
65 changes: 34 additions & 31 deletions torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py
@@ -20,7 +20,7 @@
skip_if_lt_x_gpu,
skip_if_rocm,
)
from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
@@ -619,35 +619,38 @@ def test_ddp_dist_autograd_local_vs_remote(self):
rank=self.rank,
)

remote_layer1 = RemoteModule(
"worker0", device="cpu", module_cls=nn.Linear, args=(10, 5, False)
)
layer1 = nn.Linear(10, 5, False)
# Start with the same parameters for remote and local
layer1.weight = remote_layer1.module_rref.to_here().weight

# Run local case.
layer2 = nn.Linear(5, 1)
inputs = torch.rand((10, 10))
ddp_model = DistributedDataParallel(layer2)
loss = ddp_model(layer1(inputs)).sum()
loss.backward()

# Run remote case.
with dist_autograd.context() as context_id:
loss = ddp_model(remote_layer1(inputs)).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
dist.barrier()
self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
self.assertEqual(
layer1.weight.grad,
rpc.rpc_sync(
"worker0",
DdpComparisonTest.get_remote_grads,
args=(remote_layer1.module_rref, context_id),
),
# Use two different remote device input strings, with and without the
# default device string "cpu", respectively.
for remote_device in ["worker0/cpu", "worker0"]:
remote_layer1 = RemoteModule(
remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
)
layer1 = nn.Linear(10, 5, False)
# Start with the same parameters for remote and local
layer1.weight = remote_layer1.module_rref.to_here().weight

# Run local case.
layer2 = nn.Linear(5, 1)
inputs = torch.rand((10, 10))
ddp_model = DistributedDataParallel(layer2)
loss = ddp_model(layer1(inputs)).sum()
loss.backward()

# Run remote case.
with dist_autograd.context() as context_id:
loss = ddp_model(remote_layer1(inputs)).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
dist.barrier()
self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
self.assertEqual(
layer1.weight.grad,
rpc.rpc_sync(
"worker0",
DdpComparisonTest.get_remote_grads,
args=(remote_layer1.module_rref, context_id),
),
)

@skip_if_lt_x_gpu(NUM_TRAINERS)
@requires_nccl()
@@ -667,7 +670,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self):
)

remote_layer1 = RemoteModule(
"worker0", device="cpu", module_cls=nn.Linear, args=(10, 7, False)
remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
)
layer1 = nn.Linear(10, 7, False)
# Start with the same parameters for remote and local
@@ -677,7 +680,7 @@
ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])

remote_layer3 = RemoteModule(
"worker0", device="cpu", module_cls=nn.Linear, args=(5, 3, False)
remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
)
layer3 = nn.Linear(5, 3, False)
# Start with the same parameters for remote and local
69 changes: 52 additions & 17 deletions torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -78,23 +78,20 @@ def world_size(self): # Override setting in RpcAgentTestFixture
return 2

@staticmethod
def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None):
def _create_remote_module_iter(remote_device, modes=None):
if modes is None:
modes = ModuleCreationMode.__members__.values()

args = (1,)
kwargs = dict(first_kwarg=2)

if ModuleCreationMode.MODULE_CTOR in modes:
remote_module = RemoteModule(
dst_worker_name, device, MyModule, args, kwargs
)
remote_module = RemoteModule(remote_device, MyModule, args, kwargs)
yield remote_module

if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
remote_module = _RemoteModule(
dst_worker_name,
device,
remote_device,
create_scripted_module,
args,
kwargs,
@@ -108,20 +105,21 @@ def test_bad_module(self):
if self.rank != 0:
return
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
remote_device = "{}/cpu".format(dst_worker_name)
args = (1,)
kwargs = dict(first_kwarg=2)

with self.assertRaisesRegex(
ValueError,
r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
):
RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs)
RemoteModule(remote_device, BadModule, args, kwargs)

with self.assertRaisesRegex(
ValueError,
r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
):
RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs)
RemoteModule(remote_device, BadModule, args, kwargs)

@dist_utils.dist_init
def test_forward_async(self):
@@ -227,7 +225,7 @@ def test_valid_device(self):
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

for remote_module in self._create_remote_module_iter(
dst_worker_name, device="cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
"{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR]
):
device = rpc.rpc_sync(
dst_worker_name, remote_device, (remote_module.module_rref,)
Expand All @@ -248,8 +246,7 @@ def test_invalid_devices(self):
):
list(
self._create_remote_module_iter(
dst_worker_name,
device="foo",
"{}/foo".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)
@@ -259,18 +256,16 @@ def test_invalid_devices(self):
):
list(
self._create_remote_module_iter(
dst_worker_name,
device="cuda:100",
"{}/cuda:100".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)

with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
list(
self._create_remote_module_iter(
dst_worker_name,
"{}/cpu2".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
device="cpu2",
)
)

@@ -279,8 +274,48 @@ def test_invalid_devices(self):
):
list(
self._create_remote_module_iter(
dst_worker_name,
device="cpu:2",
"{}/cpu:2".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)

with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"):
list(
self._create_remote_module_iter(
"{}/".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)

with self.assertRaisesRegex(
RuntimeError,
r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'",
):
list(
self._create_remote_module_iter(
"{}/cuda:0/cuda:1".format(dst_worker_name),
modes=[ModuleCreationMode.MODULE_CTOR],
)
)

with self.assertRaisesRegex(
RuntimeError,
r"The workername in remote_device '/' cannot be empty. The valid format is '<workername>/<device>'",
):
list(
self._create_remote_module_iter(
"/",
modes=[ModuleCreationMode.MODULE_CTOR],
)
)

with self.assertRaisesRegex(
RuntimeError,
r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '<workername>/<device>'",
):
list(
self._create_remote_module_iter(
"/cuda:0",
modes=[ModuleCreationMode.MODULE_CTOR],
)
)