[RPC] Add option to make rref.get_type not block. #50977

Closed
wants to merge 9 commits
20 changes: 14 additions & 6 deletions torch/csrc/distributed/rpc/init.cpp
@@ -402,13 +402,17 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
// retrieve cached type py::object
&PyRRef::getRRefType,
py::arg("timeout") = kUnsetRpcTimeout,
py::arg("blocking") = true,
R"(
Returns the type of the data object referenced by this
``RRef``. On the owner, this is same as
``type(rref.local_value())``. On a user, this will trigger an
RPC to fetch the ``type`` object from the owner. After this
function is run once, the ``type`` object is cached by the
``RRef``, and subsequent invocations no longer trigger RPC.
If ``blocking=True``, returns the type of the data object
referenced by this ``RRef``. On the owner, this is the same as
``type(rref.local_value())``; on a user, this will trigger an
RPC to fetch the ``type`` object from the owner. If
``blocking=False``, a ``Future`` to this result is returned
instead. After this function has run once, the ``type`` object
is cached by the ``RRef``, and subsequent invocations no longer
trigger an RPC. Note that this holds regardless of the
``blocking`` argument of subsequent calls.

Args:
rref (torch.distributed.rpc.RRef): The RRef to get type of.
@@ -417,6 +421,10 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
this timeframe, an exception indicating so will be
raised. If this argument is not provided, the default
RPC timeout will be used.
blocking (bool, optional): Whether to synchronously wait on
the RPC triggered by the first call and return the type.
If ``False``, a ``Future`` to the type is returned instead.
Default is ``True``.
)")
.def(
"_get_future",
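
For context, a minimal usage sketch of the behavior documented above, assuming an RPC group has already been initialized and a peer named "worker1" exists (the worker name and setup are illustrative, not part of this change):

import torch
import torch.distributed.rpc as rpc

# Assumes rpc.init_rpc(...) has already been called on this worker and
# that "worker1" is another member of the group.
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))

# Default blocking call: returns the type directly (torch.Tensor here).
t = rref._get_type()

# Non-blocking call: returns a torch.futures.Future wrapping the type.
fut = rref._get_type(blocking=False)
t = fut.wait()
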
8 changes: 4 additions & 4 deletions torch/csrc/distributed/rpc/py_rref.cpp
@@ -251,17 +251,17 @@ py::object PyRRef::createRRefProxy(
}
}

py::object PyRRef::getRRefType(float timeout) {
py::object PyRRef::getRRefType(float timeout, bool blocking) {
// GIL is not released when calling this function.
if (!type_.has_value()) {
pybind11::gil_scoped_release release;
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions();
pybind11::gil_scoped_acquire acquire;
type_ = isOwner() ? typeFuncs.onOwner_(*this)
: typeFuncs.onUser_(*this, timeout);
type_ = isOwner() ? typeFuncs.onOwner_(*this, blocking)
: typeFuncs.onUser_(*this, timeout, blocking);
}

// Returns a py::object that can be a Python type or a Future.
return *type_;
}

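
Because getRRefType stores whatever the first call produced in type_ and returns that cached object afterwards, only the very first call can launch an RPC. A rough Python-level sketch of the observable behavior (assuming an initialized RPC group with a peer "worker1"; the setup is illustrative):

import torch
import torch.distributed.rpc as rpc

rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))

fut1 = rref._get_type(blocking=False)  # first call: may launch an RPC to the owner
assert fut1.wait() is torch.Tensor

fut2 = rref._get_type(blocking=False)  # later calls: served from the cache, no RPC
assert fut2 is fut1                    # the same completed Future object comes back
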
4 changes: 3 additions & 1 deletion torch/csrc/distributed/rpc/py_rref.h
@@ -53,7 +53,9 @@ class PYBIND11_EXPORT PyRRef {
// get the type of the data object referenced by this RRef. Timeout argument
// is only used in the first invocation of this function as an argument to the
// RPC to the owner node of the RRef.
py::object getRRefType(float timeout = rpc::kUnsetRpcTimeout);
py::object getRRefType(
float timeout = rpc::kUnsetRpcTimeout,
bool blocking = true);

// Run the backward pass with the RRef as the root.
void backward(int64_t autogradContextId, bool retainGraph);
23 changes: 19 additions & 4 deletions torch/distributed/rpc/api.py
@@ -7,6 +7,7 @@
from typing import Generic, TypeVar, Set, Any

import torch
from torch.futures import Future

from torch._C._distributed_rpc import (
PyRRef,
@@ -361,17 +362,31 @@ def _to_worker_info(to):
raise ValueError("Cannot get WorkerInfo from name {}".format(to))


def _rref_typeof_on_owner(rref):
return type(rref.local_value())
def _rref_typeof_on_owner(rref, blocking=True):
rref_type = type(rref.local_value())
if blocking:
return rref_type
else:
# Wrap the result into a completed Future. This is so that if `blocking=False`
# is specified, we return a future regardless of whether this call is on the
# user or the owner.
future = Future[type]()
future.set_result(rref_type)
return future
Contributor: if this returns a future, do we need to annotate this function with an @rpc.functions.async_execution decorator?

Contributor: I recall Future is not picklable, and we didn't plan to make Future picklable, as communicating a future to a remote process does not seem reasonable? E.g., if the future is not completed yet, do we need to update both local and remote futures when the local one is marked as completed?

Member Author: This won't return a future when called remotely, as the only remote call is from _rref_typeof_on_user below, and we call that with blocking=True (we just need to get the type in that case).

I didn't decorate it with rpc.functions.async_execution for that reason, since it always runs synchronously when called over RPC. So I guess this means the picklability of Future isn't a concern?

I think this could potentially cause issues later if this function is run over RPC with blocking=False, but since it's private I'm assuming the use cases will be like the ones below.
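
As a standalone illustration of the completed-Future pattern used in _rref_typeof_on_owner above (a minimal sketch; the value wrapped here is arbitrary):

import torch
from torch.futures import Future

# Wrap an already-known value in a Future that is complete from the start,
# so callers see the same interface whether or not an RPC was involved.
fut = Future()
fut.set_result(torch.Tensor)
assert fut.done()
assert fut.wait() is torch.Tensor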



def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT):
return rpc_sync(
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
fut = rpc_async(
rref.owner(),
_rref_typeof_on_owner,
args=(rref,),
timeout=timeout
)
if blocking:
return fut.wait()
else:
return fut
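
The switch from rpc_sync to rpc_async above is what enables the non-blocking path: rpc_sync blocks until the result arrives, while rpc_async hands back a Future immediately and defers blocking to the caller. A rough sketch of the relationship (assuming the same illustrative "worker1" setup as above):

import torch
import torch.distributed.rpc as rpc

fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 1))
result = fut.wait()  # roughly what rpc_sync does internally

same_result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))
assert torch.equal(result, same_result)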



T = TypeVar("T")
100 changes: 84 additions & 16 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -2340,47 +2340,102 @@ def test_return_local_rrefs(self):
self.assertEqual(rets, [11, 12, 13])

@dist_init
def test_rref_type(self):
def _test_rref_type(self, blocking):

def launched_rpc(events):
expected_name = "rpc_sync#_rref_typeof_on_owner"
expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner"
return any([e.name.startswith(expected_name) for e in events])

dst = worker_name((self.rank + 1) % self.world_size)
rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1))

with torch.autograd.profiler.profile() as p:
t = rref._get_type()
t = rref._get_type(blocking=blocking)
if not blocking:
t = t.wait()

self.assertTrue(launched_rpc(p.function_events))
self.assertEqual(t, type(torch.ones(2)))
expected_type = type(torch.ones(2))
self.assertEqual(t, expected_type)

futs = []

def verify(fut):
self.assertEqual(fut.value(), expected_type)

with torch.autograd.profiler.profile() as p:
for _ in range(10):
t = rref._get_type()

t = rref._get_type(blocking=blocking)
if not blocking:
futs.append(t)
t.add_done_callback(verify)
t = t.wait()
self.assertEqual(t, expected_type)

if not blocking:
# Note that cached calls with blocking=False all return the same
# cached original future.
first_fut = futs[0]
for f in futs[1:]:
self.assertTrue(f is first_fut)
# Ensure we never launch another RPC, other than for the very
# first call.
self.assertFalse(launched_rpc(p.function_events))
self.assertEqual(t, type(torch.ones(2)))

rref = rpc.remote(dst, MyClass, args=(0,))
self.assertEqual(rref._get_type(), MyClass)
rref_type = rref._get_type(blocking=blocking)
if not blocking:
rref_type = rref_type.wait()
self.assertEqual(rref_type, MyClass)

def test_rref_type_blocking(self):
self._test_rref_type(blocking=True)

def test_rref_type_non_blocking(self):
self._test_rref_type(blocking=False)

@dist_init
def test_rref_type_with_error(self):
def _test_rref_type_with_error(self, blocking):
dst = worker_name((self.rank + 1) % self.world_size)
# 10 ms timeout
rref = rpc.remote(dst, raise_func)
# Blocking: error raised inline
if blocking:
with self.assertRaisesRegex(ValueError, "Expected error"):
rref._get_type(blocking=blocking)
else:
# Non-blocking: Immediately return future, block on wait
fut = rref._get_type(blocking=blocking)
with self.assertRaisesRegex(ValueError, "Expected error"):
fut.wait()

with self.assertRaisesRegex(ValueError, "Expected error"):
rref._get_type()

def test_rref_type_with_error_blocking(self):
self._test_rref_type_with_error(blocking=True)

def test_rref_type_with_error_non_blocking(self):
self._test_rref_type_with_error(blocking=False)

@dist_init
def test_rref_type_owner(self):
def _test_rref_type_owner(self, blocking):
rref = RRef(torch.ones(2) + 1)
self.assertEqual(rref._get_type(), type(torch.ones(2)))
rref_type = rref._get_type(blocking=blocking)
if not blocking:
rref_type = rref_type.wait()
self.assertEqual(rref_type, type(torch.ones(2)))

rref = RRef(MyClass(0))
self.assertEqual(rref._get_type(), MyClass)
rref_type = rref._get_type(blocking=blocking)
if not blocking:
rref_type = rref_type.wait()
self.assertEqual(rref_type, MyClass)

def test_rref_type_owner_blocking(self):
self._test_rref_type_owner(blocking=True)

def test_rref_type_owner_non_blocking(self):
self._test_rref_type_owner(blocking=False)

@staticmethod
def _slow_add(x, y):
@@ -5247,16 +5302,29 @@ def test_custom_stream_nested_multi(self):
)

@dist_init
def test_rref_get_type_timeout(self):
def _test_rref_get_type_timeout(self, blocking):
# Test where we try to get the type of a RRef from an owner, but RRef
# creation is slower than timeout passed into _get_type.
dst_rank = (self.rank + 1) % self.world_size
dst = worker_name(dst_rank)
slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
timeout = 0.5
expected_err = self.get_timeout_error_regex()
with self.assertRaisesRegex(RuntimeError, expected_err):
slow_rref._get_type(timeout=timeout)
# Blocking: blocks on inline call
if blocking:
with self.assertRaisesRegex(RuntimeError, expected_err):
slow_rref._get_type(timeout=timeout, blocking=blocking)
# Non-blocking: blocks on wait
else:
fut = slow_rref._get_type(timeout=timeout, blocking=blocking)
with self.assertRaisesRegex(RuntimeError, expected_err):
fut.wait()

def test_rref_get_type_timeout_blocking(self):
self._test_rref_get_type_timeout(blocking=True)

def test_rref_get_type_timeout_non_blocking(self):
self._test_rref_get_type_timeout(blocking=False)

@dist_init
def test_op_with_invalid_args(self):