Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RPC] Support timeout in rref._get_type() #50498

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions torch/csrc/distributed/rpc/init.cpp
Expand Up @@ -373,13 +373,22 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
// Intentionally not releasing GIL, as most accesses just
// retrieve cached type py::object
&PyRRef::getRRefType,
py::arg("timeout") = kUnsetRpcTimeout,
R"(
Returns the type of the data object referenced by this
``RRef``. On the owner, this is same as
``type(rref.local_value())``. On a user, this will trigger an
RPC to fetch the ``type`` object from the owner. After this
function is run once, the ``type`` object is cached by the
``RRef``, and subsequent invocations no longer trigger RPC.

Args:
rref (torch.distributed.rpc.RRef): The RRef to get type of.
timeout (float, optional): Timeout, in seconds for
``_get_type``. If the call does not complete within
this timeframe, an exception indicating so will be
raised. If this argument is not provided, the default
RPC timeout will be used.
)")
.def(
"_get_future",
Expand Down
5 changes: 3 additions & 2 deletions torch/csrc/distributed/rpc/py_rref.cpp
Expand Up @@ -249,14 +249,15 @@ py::object PyRRef::createRRefProxy(const RRefProxyType& type) const {
}
}

py::object PyRRef::getRRefType() {
py::object PyRRef::getRRefType(float timeout) {
// GIL is not released when calling this function.
if (!type_.has_value()) {
pybind11::gil_scoped_release release;
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions();
pybind11::gil_scoped_acquire acquire;
type_ = isOwner() ? typeFuncs.onOwner_(*this) : typeFuncs.onUser_(*this);
type_ = isOwner() ? typeFuncs.onOwner_(*this)
: typeFuncs.onUser_(*this, timeout);
}

return *type_;
Expand Down
6 changes: 4 additions & 2 deletions torch/csrc/distributed/rpc/py_rref.h
Expand Up @@ -48,8 +48,10 @@ class PYBIND11_EXPORT PyRRef {
// of this RRef to run functions on the object referenced by this RRef.
py::object createRRefProxy(const RRefProxyType& mode) const;

// get the type of the data object referenced by this RRef.
py::object getRRefType();
// get the type of the data object referenced by this RRef. Timeout argument
// is only used in the first invocation of this function as an argument to the
// RPC to the owner node of the RRef.
py::object getRRefType(float timeout = rpc::kUnsetRpcTimeout);

// Run the backward pass with the RRef as the root.
void backward(int64_t autogradContextId, bool retainGraph);
Expand Down
5 changes: 3 additions & 2 deletions torch/distributed/rpc/api.py
Expand Up @@ -365,11 +365,12 @@ def _rref_typeof_on_owner(rref):
return type(rref.local_value())


def _rref_typeof_on_user(rref):
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT):
return rpc_sync(
rref.owner(),
_rref_typeof_on_owner,
args=(rref,)
args=(rref,),
timeout=timeout
)


Expand Down
17 changes: 16 additions & 1 deletion torch/testing/_internal/distributed/rpc/rpc_test.py
Expand Up @@ -154,8 +154,11 @@ def __setstate__(self, obj):


class MyClass:
def __init__(self, a):
def __init__(self, a, delay=False):
self.a = a
# delay initialization to simulate errors if specified
if delay:
time.sleep(2)

def my_instance_method(self, b):
return self.a + b
Expand Down Expand Up @@ -5191,6 +5194,18 @@ def test_custom_stream_nested_multi(self):
{"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
)

@dist_init
def test_rref_get_type_timeout(self):
# Test where we try to get the type of a RRef from an owner, but RRef
# creation is slower than timeout passed into _get_type.
dst_rank = (self.rank + 1) % self.world_size
dst = worker_name(dst_rank)
slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
timeout = 0.5
expected_err = self.get_timeout_error_regex()
with self.assertRaisesRegex(RuntimeError, expected_err):
slow_rref._get_type(timeout=timeout)

@dist_init
def test_op_with_invalid_args(self):
dst = worker_name((self.rank + 1) % self.world_size)
Expand Down