4 changes: 2 additions & 2 deletions docs/source/rpc.rst
@@ -20,11 +20,11 @@ backend, and must be initialized with `torch.distributed.init_process_group
before using other functions. See the `documentation for
torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for
additional details. Next, to initialize the RPC framework we need to use
`init_model_parallel` which would initialize the RPC framework, RRef framework
`init_rpc` which would initialize the RPC framework, RRef framework
and distributed autograd.

.. automodule:: torch.distributed.rpc
.. autofunction:: init_model_parallel
.. autofunction:: init_rpc

RRef
----
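For readers following the rename, here is a minimal sketch of the initialization order the updated docs describe: process group first, then init_rpc, then join_rpc when done. The keyword-free init_rpc("worker0") form, the gloo backend, and the call sequence are taken from the docstring examples later in this diff; exact arguments may differ across releases.

# Sketch only: two processes, run the same script with RANK=0 and RANK=1.
# Assumes MASTER_ADDR/MASTER_PORT are set for the default env:// rendezvous.
import os

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc

rank = int(os.environ["RANK"])

# Step 1: initialize the default process group (as the docs require).
dist.init_process_group(backend="gloo", rank=rank, world_size=2)

# Step 2: initialize the RPC framework, RRef framework and distributed autograd.
rpc.init_rpc("worker{}".format(rank))

if rank == 0:
    # A simple synchronous RPC once both workers are up.
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
    print(ret)  # tensor([4., 4.])

# Step 3: wait for outstanding RPCs, then shut down.
rpc.join_rpc()
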
10 changes: 5 additions & 5 deletions test/dist_utils.py
@@ -50,7 +50,7 @@ def set_termination_signal():
_TERMINATION_SIGNAL.set()


def dist_init(old_test_method=None, setup_model_parallel=True, clean_shutdown=True):
def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
"""
We use this decorator for setting up and tearing down state since
MultiProcessTestCase runs each `test*` method in a separate process and
@@ -67,7 +67,7 @@ def dist_init(old_test_method=None, setup_model_parallel=True, clean_shutdown=Tr
if old_test_method is None:
return partial(
dist_init,
setup_model_parallel=setup_model_parallel,
setup_rpc=setup_rpc,
clean_shutdown=clean_shutdown,
)

@@ -78,12 +78,12 @@ def new_test_method(self, *arg, **kwargs):
"worker{}".format(rank): rank for rank in range(self.world_size)
}

if setup_model_parallel:
if setup_rpc:
global _ALL_NODE_NAMES
_ALL_NODE_NAMES = self.worker_name_to_id.keys()

# Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
rpc.init_model_parallel(
rpc.init_rpc(
self_name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
init_method=self.init_method,
@@ -94,7 +94,7 @@ def new_test_method(self, *arg, **kwargs):

return_value = old_test_method(self, *arg, **kwargs)

if setup_model_parallel:
if setup_rpc:
if clean_shutdown:
# Follower reports done.
if self.rank == MASTER_RANK:
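The dist_init decorator above supports both @dist_init and @dist_init(setup_rpc=False) by returning a functools.partial of itself when it is invoked with options but no function. A standalone sketch of that pattern follows; the RPC setup and teardown bodies are elided and only indicated by comments.

from functools import partial, wraps


def dist_init_sketch(old_test_method=None, setup_rpc=True, clean_shutdown=True):
    # Used as @dist_init_sketch(setup_rpc=False): no function is passed yet,
    # so return a decorator with the chosen options pre-bound.
    if old_test_method is None:
        return partial(
            dist_init_sketch,
            setup_rpc=setup_rpc,
            clean_shutdown=clean_shutdown,
        )

    @wraps(old_test_method)
    def new_test_method(self, *args, **kwargs):
        if setup_rpc:
            pass  # call rpc.init_rpc(...) here, as the real dist_init does
        result = old_test_method(self, *args, **kwargs)
        if setup_rpc and clean_shutdown:
            pass  # coordinate shutdown across workers here
        return result

    return new_test_method
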
36 changes: 18 additions & 18 deletions test/rpc_test.py
@@ -277,7 +277,7 @@ def test_self_remote_rref_as_self_remote_arg(self):

@mock.patch.object(torch.distributed.autograd, "_init")
@mock.patch.object(torch.distributed.rpc.api, "_start_rpc_agent")
@dist_init(setup_model_parallel=False)
@dist_init(setup_rpc=False)
def test_register_rpc_backend_and_start_rpc_backend(
self, mock_rpc_agent, mock_dist_autograd_init
):
@@ -294,7 +294,7 @@ def test_register_rpc_backend_and_start_rpc_backend(
backend_name, stub_start_rpc_backend_handler
)

rpc.init_model_parallel(
rpc.init_rpc(
self_name="worker1",
backend=backend,
init_method=self.init_method,
@@ -303,13 +303,13 @@ )
)

@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
@dist_init(setup_model_parallel=False)
@dist_init(setup_rpc=False)
def test_duplicate_name(self):
with self.assertRaisesRegex(RuntimeError, "is not unique"):
store, _, _ = next(torch.distributed.rendezvous(
self.init_method, rank=self.rank, world_size=self.world_size
))
rpc._init_rpc(
rpc._init_rpc_backend(
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
store=store,
self_name="duplicate_name",
@@ -318,9 +318,9 @@ def test_duplicate_name(self):
)
rpc.join_rpc()

@dist_init(setup_model_parallel=False)
@dist_init(setup_rpc=False)
def test_reinit(self):
rpc.init_model_parallel(
rpc.init_rpc(
self_name="worker{}".format(self.rank),
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
init_method=self.init_method,
@@ -342,7 +342,7 @@ def test_reinit(self):
dist.barrier()

with self.assertRaisesRegex(RuntimeError, "is already initialized"):
rpc.init_model_parallel(
rpc.init_rpc(
self_name="worker{}".format(self.rank),
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
init_method=self.init_method,
@@ -351,13 +351,13 @@ def test_reinit(self):
)
rpc.join_rpc()

@dist_init(setup_model_parallel=False)
@dist_init(setup_rpc=False)
def test_invalid_names(self):
with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
store, _, _ = next(torch.distributed.rendezvous(
self.init_method, rank=self.rank, world_size=self.world_size
))
rpc._init_rpc(
rpc._init_rpc_backend(
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
store=store,
self_name="abc*",
@@ -374,7 +374,7 @@ def test_invalid_names(self):
store, _, _ = next(torch.distributed.rendezvous(
self.init_method, rank=self.rank, world_size=self.world_size
))
rpc._init_rpc(
rpc._init_rpc_backend(
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
store=store,
self_name=" ",
@@ -389,7 +389,7 @@ def test_invalid_names(self):
store, _, _ = next(torch.distributed.rendezvous(
self.init_method, rank=self.rank, world_size=self.world_size
))
rpc._init_rpc(
rpc._init_rpc_backend(
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
store=store,
self_name="",
@@ -406,7 +406,7 @@ def test_invalid_names(self):
store, _, _ = next(torch.distributed.rendezvous(
self.init_method, rank=self.rank, world_size=self.world_size
))
rpc._init_rpc(
rpc._init_rpc_backend(
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
store=store,
self_name="".join(["a" for i in range(500)]),
@@ -513,10 +513,10 @@ def test_sync_rpc(self):
self.assertEqual(ret1, torch.ones(n, n) * 2)
self.assertEqual(ret2, torch.ones(n, n) * 3)

@dist_init(setup_model_parallel=False)
@dist_init(setup_rpc=False)
def test_join_rpc(self):
# Initialize RPC.
rpc.init_model_parallel(
rpc.init_rpc(
self_name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
init_method=self.init_method,
@@ -1035,7 +1035,7 @@ def test_remote_same_worker(self):
self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)

@unittest.skip("Test is flaky on ASAN, see https://github.com/pytorch/pytorch/issues/29117")
@dist_init(setup_model_parallel=True)
@dist_init(setup_rpc=True)
def test_call_method_on_rref(self):
"""
Tests that it is possible to call an instance method on a remote object
@@ -1067,10 +1067,10 @@ def test_get_default_rpc_timeout(self):
timeout = rpc.get_rpc_timeout()
self.assertEqual(timeout, rpc.constants.DEFAULT_RPC_TIMEOUT)

@dist_init(setup_model_parallel=False)
@dist_init(setup_rpc=False)
def test_set_rpc_timeout(self):
timeout = timedelta(seconds=1)
rpc.init_model_parallel(
rpc.init_rpc(
self_name="worker{}".format(self.rank),
backend=rpc.backend_registry.BackendType[TEST_CONFIG.rpc_backend_name],
init_method=self.init_method,
@@ -1092,7 +1092,7 @@ def test_func():
self.assertEqual(test_func(), "expected result")

def test_dist_init_decorator(self):
@dist_init(setup_model_parallel=False)
@dist_init(setup_rpc=False)
def test_func(self):
return "expected result"

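test_set_rpc_timeout above passes a datetime.timedelta into init_rpc and reads it back with get_rpc_timeout(). A condensed single-worker sketch of that round trip follows; the rpc_timeout keyword comes from the init_rpc signature in this diff, while the other arguments are omitted here and assumed to take their defaults or to be supplied as in the tests.

from datetime import timedelta

import torch.distributed.rpc as rpc

# Rendezvous configuration (init_method, worker_name_to_id, backend, ...) is
# elided for brevity; the tests above pass these explicitly.
timeout = timedelta(seconds=1)
rpc.init_rpc("worker0", rpc_timeout=timeout)

# The default is rpc.constants.DEFAULT_RPC_TIMEOUT; after passing an explicit
# timeout, get_rpc_timeout() is expected to return the value we set.
assert rpc.get_rpc_timeout() == timeout

rpc.join_rpc()
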
10 changes: 5 additions & 5 deletions torch/distributed/rpc/__init__.py
@@ -17,11 +17,11 @@ def is_available():


if is_available():
from .api import _init_rpc
from .api import _init_rpc_backend
from .api import * # noqa: F401
import torch.distributed.autograd

def init_model_parallel(
def init_rpc(
self_name,
backend=backend_registry.BackendType.PROCESS_GROUP,
init_method=None,
@@ -31,7 +31,7 @@ def init_model_parallel(
rpc_timeout=DEFAULT_RPC_TIMEOUT,
):
r"""
Initializes model parallel primitives such as the local rpc agent
Initializes RPC primitives such as the local RPC agent
and distributed autograd.

Initializes the local RPC agent which immediately makes the current
@@ -63,7 +63,7 @@ def init_model_parallel(
)
store, _, _ = next(rendezvous_iterator)

# Initialize autograd before RPC since _init_rpc guarantees all
# Initialize autograd before RPC since _init_rpc_backend guarantees all
# processes sync via the store. If we initialize autograd after RPC,
# there could be a race where some nodes might have initialized autograd
# and others might not have. As a result, a node calling
@@ -72,7 +72,7 @@
torch.distributed.autograd._init(worker_name_to_id[self_name])

# Initialize RPC.
_init_rpc(
_init_rpc_backend(
backend,
store,
self_name,
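A condensed sketch of the control flow this hunk gives init_rpc: rendezvous for a shared store, initialize distributed autograd, then start the RPC backend. The rendezvous arguments and the full _init_rpc_backend parameter list are not visible in the hunk and are assumptions here; _init_rpc_backend and torch.distributed.autograd._init are private APIs and may change.

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc


def init_rpc_flow_sketch(self_name, backend, init_method, worker_name_to_id):
    # 1. Rendezvous over init_method to obtain a key/value store shared by all
    #    workers (the extra rendezvous arguments are elided in the hunk above).
    rendezvous_iterator = torch.distributed.rendezvous(init_method)
    store, _, _ = next(rendezvous_iterator)

    # 2. Initialize distributed autograd *before* the RPC backend: the backend
    #    init synchronizes all processes via the store, so doing autograd first
    #    avoids a race where a peer receives an RPC that needs autograd state
    #    it has not set up yet.
    dist_autograd._init(worker_name_to_id[self_name])

    # 3. Start the RPC agent through the registered backend handler
    #    (remaining arguments and their order are assumptions).
    rpc.api._init_rpc_backend(backend, store, self_name, worker_name_to_id)
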
18 changes: 9 additions & 9 deletions torch/distributed/rpc/api.py
@@ -22,7 +22,7 @@ def wrapper(*args, **kwargs):
if _agent is None:
raise RuntimeError(
"RPC has not been initialized. Call "
"torch.distributed.rpc.init_model_parallel first."
"torch.distributed.rpc.init_rpc first."
)
return func(*args, **kwargs)
return wrapper
@@ -60,8 +60,8 @@ def sync_rpc():



# TODO: add a context manager to wrap _init_rpc and join_rpc
def _init_rpc(
# TODO: add a context manager to wrap _init_rpc_backend and join_rpc
def _init_rpc_backend(
backend=backend_registry.BackendType.PROCESS_GROUP,
store=None,
self_name=None,
@@ -152,7 +152,7 @@ def remote(to, func, args=None, kwargs=None):
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> rpc.init_model_parallel("worker0")
>>> rpc.init_rpc("worker0")
>>> worker1 = rpc.get_worker_info("worker1")
>>> rref1 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote(worker1, torch.add, args=(torch.ones(2), 1))
@@ -162,7 +162,7 @@ def remote(to, func, args=None, kwargs=None):
On worker 1:
>>> import torch.distributed as dist
>>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
>>> dist.init_model_parallel("worker1")
>>> rpc.init_rpc("worker1")
>>> rpc.join_rpc()
"""
qualified_name = torch.jit._find_builtin(func)
@@ -227,15 +227,15 @@ def rpc_sync(to, func, args=None, kwargs=None):
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> rpc.init_model_parallel("worker0")
>>> rpc.init_rpc("worker0")
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.join_rpc()

On worker 1:
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
>>> rpc.init_model_parallel("worker1")
>>> rpc.init_rpc("worker1")
>>> rpc.join_rpc()
"""
fut = _invoke_rpc(to, func, args, kwargs)
@@ -269,7 +269,7 @@ def rpc_async(to, func, args=None, kwargs=None):
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> rpc.init_model_parallel("worker0")
>>> rpc.init_rpc("worker0")
>>> worker1 = rpc.get_worker_id("worker1")
>>> fut1 = rpc.rpc_async(worker1, torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async(worker1, min, args=(1, 2))
@@ -280,7 +280,7 @@ def rpc_async(to, func, args=None, kwargs=None):
>>> import torch.distributed as dist
>>> import torch.distributed.rpc as rpc
>>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
>>> rpc.init_model_parallel("worker1")
>>> rpc.init_rpc("worker1")
>>> rpc.join_rpc()
"""
fut = _invoke_rpc(to, func, args, kwargs)
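Putting the three call styles documented above side by side, based on the worker-0 docstring examples in this file: rpc_sync blocks for a result, rpc_async returns a future, and remote returns an RRef whose value stays on the callee. The fut.wait() call and the string destination for rpc_async are assumptions where the hunks are cut off.

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc

# Worker 0 of 2, mirroring the docstring examples above.
dist.init_process_group(backend="gloo", rank=0, world_size=2)
rpc.init_rpc("worker0")

worker1 = rpc.get_worker_info("worker1")

# rpc_sync: blocks and returns the result directly.
ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))   # tensor([4., 4.])

# rpc_async: returns a future; wait() retrieves the result.
fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 1))
res = fut.wait()                                                     # tensor([2., 2.])

# remote: returns an RRef; the value lives on worker1 until to_here().
rref = rpc.remote(worker1, torch.add, args=(torch.ones(2), 2))
val = rref.to_here()                                                 # tensor([3., 3.])

rpc.join_rpc()
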
4 changes: 2 additions & 2 deletions torch/distributed/rpc/backend_registry.py
@@ -19,7 +19,7 @@ def register_backend(backend_name, init_backend_handler):
Arguments:
backend (str): backend string to identify the handler.
handler (function): Handler that is invoked when the
`_init_rpc()` function is called with a backend.
`_init_rpc_backend()` function is called with a backend.
This returns the agent.
"""
global BackendType
@@ -54,7 +54,7 @@ def _process_group_init_backend_handler(
# Initialize ProcessGroup.
if dist.is_initialized():
raise RuntimeError(
"Default process group must not be initialized before `init_model_parallel`."
[Author's review comment on this line: removed backticks]
"Default process group must not be initialized before init_rpc."
)

world_size = len(worker_name_to_id)
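For completeness, a sketch of the registration hook the register_backend docstring above describes, mirroring the stub used in test_register_rpc_backend_and_start_rpc_backend. The handler signature and the exact meaning of the return value are assumptions, since the test only passes an opaque stub.

import torch.distributed.rpc as rpc


def stub_init_backend_handler(*args, **kwargs):
    # A real handler must construct and return an RPC agent; this stub only
    # exists to illustrate the registration call.
    raise NotImplementedError("illustration only, not a working backend")


# Adds a new member to rpc.backend_registry.BackendType and associates it with
# the handler that _init_rpc_backend() will invoke for this backend. The return
# value is used as the `backend` argument to init_rpc in the test above.
backend = rpc.backend_registry.register_backend(
    "STUB_BACKEND", stub_init_backend_handler
)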