Implement timeout support for RRefs
Pull Request resolved: #38590

This PR implements timeout semantics for RRefs, for parity with rpc_sync and rpc_async. How it works:

- A timeout parameter is added to rpc.remote. If the rpc.remote call times out, the error won't be raised to the user in that call, since it is not blocking (similar to rpc_async). Instead, the timeout error is raised the next time the RRef is used, either by pickling or a to_here call; see the usage sketch after this list.
- Error handling semantics are added to RRef to deal with timeout errors. Previously, if there was an error creating the OwnerRRef, the callback on the local user would throw inside the callback, resulting in a call to `std::terminate`. Instead, the error is now caught and surfaced to the user the next time the RRef is used. As part of this, we have added an `RPCErrorType` enum and defined RRef error handlers for these error types (currently just timeout and unknown).
- A timeout parameter is added to `to_here()`, which gives the user control over the maximum amount of time it can block for.
- `ctx.prepareChildForFork()`, which is called when the RRef is pickled (i.e., used as an arg over RPC), checks whether the `rpc.remote()` call timed out, and if so, raises that error to the user.
- Tests are added, primarily via delay injection.
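
A minimal usage sketch of the behavior described above, from the Python side. The two-worker setup, the `slow_add` and `use_rref` helpers, and the specific timeout values are illustrative assumptions, not part of this PR; timeouts are assumed to surface as `RuntimeError`, as with other RPC errors.

```python
import torch
import torch.distributed.rpc as rpc

def slow_add(x, y):
    # Hypothetical user function that may take longer than the RPC timeout.
    return x + y

def use_rref(rref):
    # Hypothetical remote function that consumes an RRef argument.
    return rref.to_here() * 2

# Assumes rpc.init_rpc(...) has already been called and "worker1" exists.
# rpc.remote() returns immediately, so a timeout is NOT raised here even if
# the remote creation does not complete in time.
rref = rpc.remote(
    "worker1",
    slow_add,
    args=(torch.ones(2), torch.ones(2)),
    timeout=1.5,  # seconds; the new timeout argument added by this PR
)

try:
    # If the rpc.remote() call timed out, the error surfaces on first use,
    # e.g. in to_here(), which now also accepts its own timeout.
    result = rref.to_here(timeout=5.0)
except RuntimeError as e:
    print("RRef creation or fetch timed out:", e)

try:
    # The same error is raised if the timed-out RRef is pickled, i.e. passed
    # as an argument over RPC (via ctx.prepareChildForFork()).
    rpc.rpc_sync("worker1", use_rref, args=(rref,))
except RuntimeError as e:
    print("Using the timed-out RRef raised:", e)
```
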
ghstack-source-id: 105232837

Differential Revision: [D21588165](https://our.internmc.facebook.com/intern/diff/D21588165/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D21588165/)!
rohan-varma committed Jun 4, 2020
1 parent ec5d579 commit ba9528d
Showing 23 changed files with 737 additions and 107 deletions.
15 changes: 12 additions & 3 deletions torch/csrc/distributed/rpc/init.cpp
@@ -205,11 +205,19 @@ PyObject* rpc_init(PyObject* /* unused */) {
.def(
"to_here",
&PyRRef::toHere,
py::arg("timeout") = py::cast(kUnsetRpcTimeout),
py::call_guard<py::gil_scoped_release>(),
R"(
Blocking call that copies the value of the RRef from the owner
to the local node and returns it. If the current node is the
owner, returns a reference to the local value.
Arguments:
timeout (float, optional): Timeout for ``to_here``. If
the call does not complete within this timeframe, an
exception indicating so will be raised. If this argument
is not provided, the default RPC timeout (60s) will be
used.
)")
.def(
"local_value",
@@ -534,9 +542,9 @@ PyObject* rpc_init(PyObject* /* unused */) {
});

module.def(
"_delete_all_user_rrefs",
"_delete_all_user_and_unforked_owner_rrefs",
[](std::chrono::milliseconds timeoutMillis) {
RRefContext::getInstance().delAllUsers(timeoutMillis);
RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
},
py::arg("timeout") = kDeleteAllUsersTimeout);

@@ -626,7 +634,8 @@ PyObject* rpc_init(PyObject* /* unused */) {
py::call_guard<py::gil_scoped_release>(),
py::arg("dst"),
py::arg("pickledPythonUDF"),
py::arg("tensors"));
py::arg("tensors"),
py::arg("timeout"));

module.def(
"_invoke_remote_torchscript",
10 changes: 9 additions & 1 deletion torch/csrc/distributed/rpc/message.h
@@ -8,6 +8,14 @@ namespace torch {
namespace distributed {
namespace rpc {

// An enum denoting common RPC errors to allow specific error handling for them.
enum RPCErrorType {
UNKNOWN_ERROR = 0, /* Indicates that error type could not be parsed */
TIMEOUT = 1, /* Indicates that the RPC has timed out */
INTENTIONAL_FAILURE = 2 /* Deliberate failure, such as those injected by
FaultyProcessGroupAgent for testing */
};

enum MessageType {
// messages for dist.rpc on builtin operators
SCRIPT_CALL = 0,
@@ -20,7 +28,7 @@ enum MessageType {
// messages for dist.remote on builtin operators and Python UDF
SCRIPT_REMOTE_CALL = 4, // A remote call on a builtin operator
PYTHON_REMOTE_CALL = 5, // A remote call on a Python UDF
REMOTE_RET = 6, // A remote call on a Python UDF
REMOTE_RET = 6, // Response for remote calls for UDF, builtin, or script

// RRef related internal messages
SCRIPT_RREF_FETCH_CALL = 7, // A UserRRef<IValue> fetches value from owner
77 changes: 42 additions & 35 deletions torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -2,6 +2,7 @@

#include <c10/util/C++17.h>
#include <c10d/ProcessGroup.hpp>
#include <fmt/format.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/utils.h>

@@ -10,6 +11,8 @@
namespace torch {
namespace distributed {
namespace rpc {
const std::string kRPCTimeoutErrorStr =
"RPC ran for more than {} milliseconds and timed out.";

namespace {
constexpr auto kSecToMsConversion = 1000;
@@ -371,36 +374,7 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
// Sending to ourselves: bypass the send logic and enqueue directly
// to our receiving queue.
if (to.id_ == (worker_id_t)pg_->getRank()) {
threadPool_.run(std::bind(
[this, future](const Message& message) {
// Unlike the other cases, need to add a tensor deleter, since the
// data outlives the scope of this function. It's shared_ptr<> due
// to c++11 lambda capture limitations with unique_ptr<>.
std::unique_ptr<std::string> payload;
try {
payload = std::make_unique<std::string>(
wireSerialize(message.payload(), message.tensors()));
// only increment sendCounts when the message is indeed added into
// local recv.
sendCounts_.increment(pg_->getRank());
} catch (std::exception& e) {
markFutureWithError(message.id(), e.what());
return;
}
const char* data = payload->data();
size_t len = payload->length();
std::string* delete_when_done = payload.release();
enqueueRecv(RecvWork(
getWorkerInfo(pg_->getRank()),
message.type(),
message.id(),
torch::from_blob(
(void*)data,
len,
[delete_when_done](void*) { delete delete_when_done; },
{torch::kChar})));
},
std::move(message)));
sendToSelf(std::move(message));
return future;
}

@@ -479,6 +453,39 @@ void ProcessGroupAgent::handleSend(const SendWork& work) {
}
}

void ProcessGroupAgent::sendToSelf(Message&& message) {
threadPool_.run(std::bind(
[this](const Message& message) {
// Unlike the other cases, need to add a tensor deleter, since the
// data outlives the scope of this function. It's shared_ptr<> due
// to c++11 lambda capture limitations with unique_ptr<>.
std::unique_ptr<std::string> payload;
try {
payload = std::make_unique<std::string>(
wireSerialize(message.payload(), message.tensors()));
// only increment sendCounts when the message is indeed added into
// local recv.
sendCounts_.increment(pg_->getRank());
} catch (std::exception& e) {
markFutureWithError(message.id(), e.what());
return;
}
const char* data = payload->data();
size_t len = payload->length();
std::string* delete_when_done = payload.release();
enqueueRecv(RecvWork(
getWorkerInfo(pg_->getRank()),
message.type(),
message.id(),
torch::from_blob(
(void*)data,
len,
[delete_when_done](void*) { delete delete_when_done; },
{torch::kChar})));
},
std::move(message)));
}

void ProcessGroupAgent::enqueueSend(SendWork work) {
// NB: this can be changed to use a native move capture when moved to C++14
threadPool_.run(std::bind(
@@ -796,13 +803,13 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
futureCV_.notify_all();

for (const auto& timedOutFuture : timedOutFutures) {
auto err = c10::str(
"RPC ran for more than ",
timedOutFuture.timeout_.count(),
" milliseconds and timed out.");
auto errStr =
fmt::format(kRPCTimeoutErrorStr, timedOutFuture.timeout_.count());
auto err = makeRPCError(errStr, RPCErrorType::TIMEOUT);

if (!timedOutFuture.future_->hasError()) {
--clientActiveCalls_;
timedOutFuture.future_->setError(err);
timedOutFuture.future_->setError(std::move(err));
// The future timed out and will not be processed by handleRecv(), even
// if we eventually get a response. In order to keep track of all
// send/recv pairs, we increment the count here.
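
For completeness, a small sketch of how the timeout error produced above looks from Python. Only the error-string format (`kRPCTimeoutErrorStr`) comes from this diff; the worker names, the `sleepy` helper, and the choice of `rpc_sync` are assumptions for illustration.

```python
import time
import torch.distributed.rpc as rpc

def sleepy(seconds):
    # Hypothetical remote function that deliberately exceeds the timeout.
    time.sleep(seconds)
    return seconds

# Assumes rpc.init_rpc(...) has already set up "worker0"/"worker1".
try:
    rpc.rpc_sync("worker1", sleepy, args=(10,), timeout=0.5)
except RuntimeError as e:
    # The poller completes the future with an error built from
    # kRPCTimeoutErrorStr, so the message contains text along the lines of
    # "RPC ran for more than 500 milliseconds and timed out."
    assert "timed out" in str(e)
```
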
2 changes: 2 additions & 0 deletions torch/csrc/distributed/rpc/process_group_agent.h
@@ -93,6 +93,8 @@ class ProcessGroupAgent : public RpcAgent {

// put SendWork into a queue and notify the worker thread
virtual void enqueueSend(SendWork work);
// Bypass handleSend() logic and send a message to self rank
virtual void sendToSelf(Message&& message);

private:
class MessageCounter {
7 changes: 4 additions & 3 deletions torch/csrc/distributed/rpc/py_rref.cpp
@@ -153,14 +153,15 @@ std::string PyRRef::ownerName() const {
return rref_->ownerName();
}

py::object PyRRef::toHere() const {
py::object PyRRef::toHere(const float timeoutSeconds) const {
if (rref_->isOwner()) {
return localValue();
} else {
// toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
// a python object
IValue value =
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();
IValue value = c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere(
timeoutSeconds);

if (rref_->isPyObj()) {
// python_rpc_handler deserialization will acquire the GIL.
auto rfr_values = value.toTuple()->elements();
4 changes: 3 additions & 1 deletion torch/csrc/distributed/rpc/py_rref.h
@@ -23,7 +23,9 @@ class PyRRef {
bool confirmedByOwner() const;
WorkerInfo owner() const;
std::string ownerName() const;
py::object toHere() const;
py::object toHere(
const float timeoutSeconds =
torch::distributed::rpc::kUnsetRpcTimeout) const;
py::object localValue() const;
std::string str() const;
py::tuple pickle() const;
52 changes: 35 additions & 17 deletions torch/csrc/distributed/rpc/python_functions.cpp
@@ -95,7 +95,8 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall(
const WorkerInfo& dst,
SerializedPyObj serializedPyObj,
const IValue& rrefId,
const IValue& forkId) {
const IValue& forkId,
const float rpcTimeoutSeconds) {
auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
std::move(serializedPyObj), rrefId, forkId);

@@ -106,7 +107,8 @@
*agent,
dst,
std::move(*pythonRemoteCall).toMessage(),
true /*forceGradRecording*/);
true /*forceGradRecording*/,
rpcTimeoutSeconds);
}

} // namespace
@@ -225,6 +227,7 @@ c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
PyRRef pyRemoteBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs) {
DCHECK(PyGILState_Check());
@@ -244,7 +247,11 @@ PyRRef pyRemoteBuiltin(
op, std::move(stack), userRRef->rrefId(), userRRef->forkId());

auto fm = sendMessageWithAutograd(
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false);
*agent,
dst,
std::move(*scriptRemoteCall).toMessage(),
/*forceGradRecord */ false,
/* timeout */ rpcTimeoutSeconds);

userRRef->registerOwnerCreationFuture(fm);
ctx.addPendingUser(userRRef->forkId(), userRRef);
@@ -260,22 +267,29 @@ PyRRef pyRemoteBuiltin(
auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
auto fm = sendMessageWithAutograd(
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false);
*agent,
dst,
std::move(*scriptRemoteCall).toMessage(),
/* forceGradRecord */ false,
/* timeout */ rpcTimeoutSeconds);

ownerRRef->registerOwnerCreationFuture(fm);

// Builtin operators do not return py::object, and hence do not require
// GIL for destructing the potentially deleted OwnerRRef.
fm->addCallback(
[](const FutureMessage& fm) { callback::finishCreatingOwnerRRef(fm); });
[ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) {
callback::finishCreatingOwnerRRef(fm, ownerRRefId);
});
return PyRRef(ownerRRef);
}
}

PyRRef pyRemotePythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors) {
std::vector<torch::Tensor>& tensors,
const float rpcTimeoutSeconds) {
DCHECK(!PyGILState_Check());
auto& ctx = RRefContext::getInstance();
auto serializedPyObj =
@@ -286,7 +300,8 @@ PyRRef pyRemotePythonUdf(
dst,
std::move(serializedPyObj),
userRRef->rrefId().toIValue(),
userRRef->forkId().toIValue());
userRRef->forkId().toIValue(),
rpcTimeoutSeconds);

userRRef->registerOwnerCreationFuture(fm);

@@ -303,24 +318,27 @@
dst,
std::move(serializedPyObj),
ownerRRef->rrefId().toIValue(),
ownerRRef->rrefId().toIValue());
ownerRRef->rrefId().toIValue(),
rpcTimeoutSeconds);

ownerRRef->registerOwnerCreationFuture(fm);

fm->addCallback([](const FutureMessage& fm) {
auto deletedRRef = callback::finishCreatingOwnerRRef(fm);
if (deletedRRef && deletedRRef->isPyObj()) {
py::gil_scoped_acquire ag;
deletedRRef.reset();
}
});
fm->addCallback(
[ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) {
auto deletedRRef = callback::finishCreatingOwnerRRef(fm, ownerRRefId);
if (deletedRRef && deletedRRef->isPyObj()) {
py::gil_scoped_acquire ag;
deletedRRef.reset();
}
});
return PyRRef(ownerRRef);
}
}

PyRRef pyRemoteTorchscript(
const std::string& dstWorkerName,
const std::string& qualifiedNameStr,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs) {
DCHECK(!PyGILState_Check());
@@ -337,8 +355,8 @@
functionSchema, args, kwargs, c10::nullopt);
}
DCHECK(!PyGILState_Check());
auto rrefPtr =
remoteTorchscript(dstWorkerName, qualifiedName, functionSchema, stack);
auto rrefPtr = remoteTorchscript(
dstWorkerName, qualifiedName, functionSchema, stack, rpcTimeoutSeconds);
return PyRRef(rrefPtr);
}

5 changes: 4 additions & 1 deletion torch/csrc/distributed/rpc/python_functions.h
@@ -45,17 +45,20 @@ c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
PyRRef pyRemoteBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs);

PyRRef pyRemotePythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors);
std::vector<torch::Tensor>& tensors,
const float rpcTimeoutSeconds);

PyRRef pyRemoteTorchscript(
const std::string& dstWorkerName,
const std::string& qualifiedNameStr,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs);

