Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement timeout support for RRefs #38590

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 10 additions & 1 deletion torch/csrc/distributed/rpc/init.cpp
Expand Up @@ -204,11 +204,19 @@ PyObject* rpc_init(PyObject* /* unused */) {
.def(
"to_here",
&PyRRef::toHere,
py::arg("timeout") = py::cast(kUnsetRpcTimeout),
py::call_guard<py::gil_scoped_release>(),
R"(
Blocking call that copies the value of the RRef from the owner
to the local node and returns it. If the current node is the
owner, returns a reference to the local value.

Arguments:
timeout (Optional, float): Timeout for ``to_here``. If
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

timeouts in rpc/api.py are written as ``timeout (float, optional)``

the call does not complete within this timeframe, an
exception indicating so will be raised. If this argument
is not provided, the default RPC timeout (60s) will be
used.
)")
.def(
"local_value",
Expand Down Expand Up @@ -619,7 +627,8 @@ PyObject* rpc_init(PyObject* /* unused */) {
py::call_guard<py::gil_scoped_release>(),
py::arg("dst"),
py::arg("pickledPythonUDF"),
py::arg("tensors"));
py::arg("tensors"),
py::arg("timeout"));

module.def(
"_invoke_remote_torchscript",
Expand Down
10 changes: 9 additions & 1 deletion torch/csrc/distributed/rpc/message.h
Expand Up @@ -8,6 +8,14 @@ namespace torch {
namespace distributed {
namespace rpc {

// An enum denoting common RPC errors to allow specific error handling for them.
// Error categories are encoded into agent error strings via makeRPCError (see
// process_group_agent.cpp) so callers can distinguish, e.g., a timeout from a
// deliberately injected test failure.
enum RPCErrorType {
  UNKNOWN_ERROR = 0, /* Indicates that error type could not be parsed */
  TIMEOUT = 1, /* Indicates that the RPC has timed out */
  INTENTIONAL_FAILURE = 2 /* Deliberate failure, such as those injected by
                             FaultyProcessGroupAgent for testing */
};

enum MessageType {
// messages for dist.rpc on builtin operators
SCRIPT_CALL = 0,
Expand All @@ -20,7 +28,7 @@ enum MessageType {
// messages for dist.remote on builtin operators and Python UDF
SCRIPT_REMOTE_CALL = 4, // A remote call on a builtin operator
PYTHON_REMOTE_CALL = 5, // A remote call on a Python UDF
REMOTE_RET = 6, // A remote call on a Python UDF
REMOTE_RET = 6, // Response for remote calls for UDF, builtin, or script
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!


// RRef related internal messages
SCRIPT_RREF_FETCH_CALL = 7, // A UserRRef<IValue> fetches value from owner
Expand Down
77 changes: 42 additions & 35 deletions torch/csrc/distributed/rpc/process_group_agent.cpp
Expand Up @@ -2,6 +2,7 @@

#include <c10/util/C++17.h>
#include <c10d/ProcessGroup.hpp>
#include <fmt/format.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/utils.h>

Expand All @@ -10,6 +11,8 @@
namespace torch {
namespace distributed {
namespace rpc {
const std::string kRPCTimeoutErrorStr =
"RPC ran for more than {} milliseconds and timed out.";

namespace {
constexpr auto kSecToMsConversion = 1000;
Expand Down Expand Up @@ -371,36 +374,7 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
// Sending to ourselves: bypass the send logic and enqueue directly
// to our receiving queue.
if (to.id_ == (worker_id_t)pg_->getRank()) {
threadPool_.run(std::bind(
[this, future](const Message& message) {
// Unlike the other cases, need to add a tensor deleter, since the
// data outlives the scope of this function. It's shared_ptr<> due
// to c++11 lambda capture limitations with unique_ptr<>.
std::unique_ptr<std::string> payload;
try {
payload = std::make_unique<std::string>(
wireSerialize(message.payload(), message.tensors()));
// only increment sendCounts when the message is indeed added into
// local recv.
sendCounts_.increment(pg_->getRank());
} catch (std::exception& e) {
markFutureWithError(message.id(), e.what());
return;
}
const char* data = payload->data();
size_t len = payload->length();
std::string* delete_when_done = payload.release();
enqueueRecv(RecvWork(
getWorkerInfo(pg_->getRank()),
message.type(),
message.id(),
torch::from_blob(
(void*)data,
len,
[delete_when_done](void*) { delete delete_when_done; },
{torch::kChar})));
},
std::move(message)));
sendToSelf(std::move(message));
return future;
}

Expand Down Expand Up @@ -479,6 +453,39 @@ void ProcessGroupAgent::handleSend(const SendWork& work) {
}
}

// Delivers a message addressed to our own rank by serializing it and
// enqueueing it directly onto the local receive queue, bypassing the
// handleSend()/ProcessGroup transport path. On serialization failure the
// message's pending future is marked with the error and nothing is enqueued;
// sendCounts_ is only incremented once the message is actually handed to the
// local recv path.
void ProcessGroupAgent::sendToSelf(Message&& message) {
  threadPool_.run(std::bind(
      [this](const Message& message) {
        // The serialized payload must outlive this function because the
        // tensor created below aliases it; heap-allocate it and hand
        // ownership to the from_blob deleter. (A lambda-captured
        // unique_ptr<> is not usable here pre-C++14, hence the
        // release()/raw-pointer handoff.)
        std::unique_ptr<std::string> payload;
        try {
          payload = std::make_unique<std::string>(
              wireSerialize(message.payload(), message.tensors()));
          // only increment sendCounts when the message is indeed added into
          // local recv.
          sendCounts_.increment(pg_->getRank());
        } catch (std::exception& e) {
          // Serialization failed: surface the error on the future that
          // send() returned for this message id, and bail out.
          markFutureWithError(message.id(), e.what());
          return;
        }
        const char* data = payload->data();
        size_t len = payload->length();
        // Transfer ownership of the payload to the tensor's deleter so the
        // bytes stay alive for as long as the tensor does.
        std::string* delete_when_done = payload.release();
        enqueueRecv(RecvWork(
            getWorkerInfo(pg_->getRank()),
            message.type(),
            message.id(),
            torch::from_blob(
                (void*)data,
                len,
                [delete_when_done](void*) { delete delete_when_done; },
                {torch::kChar})));
      },
      std::move(message)));
}

void ProcessGroupAgent::enqueueSend(SendWork work) {
// NB: this can be changed to use a native move capture when moved to C++14
threadPool_.run(std::bind(
Expand Down Expand Up @@ -796,13 +803,13 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
futureCV_.notify_all();

for (const auto& timedOutFuture : timedOutFutures) {
auto err = c10::str(
"RPC ran for more than ",
timedOutFuture.timeout_.count(),
" milliseconds and timed out.");
auto err = makeRPCError(
fmt::format(kRPCTimeoutErrorStr, timedOutFuture.timeout_.count()),
RPCErrorType::TIMEOUT);

if (!timedOutFuture.future_->hasError()) {
--clientActiveCalls_;
timedOutFuture.future_->setError(err);
timedOutFuture.future_->setError(std::move(err));
// The future timed out and will not be processed by handleRecv(), even
// if we eventually get a response. In order to keep track of all
// send/recv pairs, we increment the count here.
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/distributed/rpc/process_group_agent.h
Expand Up @@ -93,6 +93,8 @@ class ProcessGroupAgent : public RpcAgent {

// put SendWork into a queue and notify the worker thread
virtual void enqueueSend(SendWork work);
// Bypass handleSend() logic and send a message to our own rank
virtual void sendToSelf(Message&& message);

private:
class MessageCounter {
Expand Down
7 changes: 4 additions & 3 deletions torch/csrc/distributed/rpc/py_rref.cpp
Expand Up @@ -153,14 +153,15 @@ std::string PyRRef::ownerName() const {
return rref_->ownerName();
}

py::object PyRRef::toHere() const {
py::object PyRRef::toHere(const float timeoutSeconds) const {
if (rref_->isOwner()) {
return localValue();
} else {
// toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
// a python object
IValue value =
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();
IValue value = c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere(
timeoutSeconds);

if (rref_->isPyObj()) {
// python_rpc_handler deserialization will acquires GIL.
auto rfr_values = value.toTuple()->elements();
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/distributed/rpc/py_rref.h
Expand Up @@ -23,7 +23,9 @@ class PyRRef {
bool confirmedByOwner() const;
WorkerInfo owner() const;
std::string ownerName() const;
py::object toHere() const;
py::object toHere(
const float timeoutSeconds =
torch::distributed::rpc::kUnsetRpcTimeout) const;
py::object localValue() const;
std::string str() const;
py::tuple pickle() const;
Expand Down
52 changes: 35 additions & 17 deletions torch/csrc/distributed/rpc/python_functions.cpp
Expand Up @@ -95,7 +95,8 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall(
const WorkerInfo& dst,
SerializedPyObj serializedPyObj,
const IValue& rrefId,
const IValue& forkId) {
const IValue& forkId,
const float rpcTimeoutSeconds) {
auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
std::move(serializedPyObj), rrefId, forkId);

Expand All @@ -106,7 +107,8 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall(
*agent,
dst,
std::move(*pythonRemoteCall).toMessage(),
true /*forceGradRecording*/);
true /*forceGradRecording*/,
rpcTimeoutSeconds);
}

} // namespace
Expand Down Expand Up @@ -223,6 +225,7 @@ c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
PyRRef pyRemoteBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs) {
DCHECK(PyGILState_Check());
Expand All @@ -242,7 +245,11 @@ PyRRef pyRemoteBuiltin(
op, std::move(stack), userRRef->rrefId(), userRRef->forkId());

auto fm = sendMessageWithAutograd(
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false);
*agent,
dst,
std::move(*scriptRemoteCall).toMessage(),
/*forceGradRecord */ false,
/* timeout */ rpcTimeoutSeconds);

userRRef->registerOwnerCreationFuture(fm);
ctx.addPendingUser(userRRef->forkId(), userRRef);
Expand All @@ -258,22 +265,29 @@ PyRRef pyRemoteBuiltin(
auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
auto fm = sendMessageWithAutograd(
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false);
*agent,
dst,
std::move(*scriptRemoteCall).toMessage(),
/* forceGradRecord */ false,
/* timeout */ rpcTimeoutSeconds);

ownerRRef->registerOwnerCreationFuture(fm);

// Builtin operators does not return py::object, and hence does not require
// GIL for destructing the potentially deleted OwerRRef.
fm->addCallback(
[](const FutureMessage& fm) { callback::finishCreatingOwnerRRef(fm); });
[ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) {
callback::finishCreatingOwnerRRef(fm, ownerRRefId);
});
return PyRRef(ownerRRef);
}
}

PyRRef pyRemotePythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors) {
std::vector<torch::Tensor>& tensors,
const float rpcTimeoutSeconds) {
DCHECK(!PyGILState_Check());
auto& ctx = RRefContext::getInstance();
auto serializedPyObj =
Expand All @@ -284,7 +298,8 @@ PyRRef pyRemotePythonUdf(
dst,
std::move(serializedPyObj),
userRRef->rrefId().toIValue(),
userRRef->forkId().toIValue());
userRRef->forkId().toIValue(),
rpcTimeoutSeconds);

userRRef->registerOwnerCreationFuture(fm);

Expand All @@ -301,24 +316,27 @@ PyRRef pyRemotePythonUdf(
dst,
std::move(serializedPyObj),
ownerRRef->rrefId().toIValue(),
ownerRRef->rrefId().toIValue());
ownerRRef->rrefId().toIValue(),
rpcTimeoutSeconds);

ownerRRef->registerOwnerCreationFuture(fm);

fm->addCallback([](const FutureMessage& fm) {
auto deletedRRef = callback::finishCreatingOwnerRRef(fm);
if (deletedRRef && deletedRRef->isPyObj()) {
py::gil_scoped_acquire ag;
deletedRRef.reset();
}
});
fm->addCallback(
[ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) {
auto deletedRRef = callback::finishCreatingOwnerRRef(fm, ownerRRefId);
if (deletedRRef && deletedRRef->isPyObj()) {
py::gil_scoped_acquire ag;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably is no longer necessary. Not requesting changes in this PR. I added #39355 to track and will investigate after this PR is merged.

deletedRRef.reset();
}
});
return PyRRef(ownerRRef);
}
}

PyRRef pyRemoteTorchscript(
const std::string& dstWorkerName,
const std::string& qualifiedNameStr,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs) {
DCHECK(!PyGILState_Check());
Expand All @@ -335,8 +353,8 @@ PyRRef pyRemoteTorchscript(
functionSchema, args, kwargs, c10::nullopt);
}
DCHECK(!PyGILState_Check());
auto rrefPtr =
remoteTorchscript(dstWorkerName, qualifiedName, functionSchema, stack);
auto rrefPtr = remoteTorchscript(
dstWorkerName, qualifiedName, functionSchema, stack, rpcTimeoutSeconds);
return PyRRef(rrefPtr);
}

Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/distributed/rpc/python_functions.h
Expand Up @@ -44,17 +44,20 @@ c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
PyRRef pyRemoteBuiltin(
const WorkerInfo& dst,
const std::string& opName,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs);

PyRRef pyRemotePythonUdf(
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors);
std::vector<torch::Tensor>& tensors,
const float rpcTimeoutSeconds);

PyRRef pyRemoteTorchscript(
const std::string& dstWorkerName,
const std::string& qualifiedNameStr,
const float rpcTimeoutSeconds,
const py::args& args,
const py::kwargs& kwargs);

Expand Down