New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement timeout support for RRefs #38590
Changes from 8 commits
8041352
8da2867
67aa5f0
2591953
b298026
0b40c25
4b17581
e809bc5
7161737
c882b0a
6625743
b004366
7d16178
3a2eb2e
0590f83
56700d1
a72eeed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,14 @@ namespace torch { | |
namespace distributed { | ||
namespace rpc { | ||
|
||
// An enum denoting common RPC errors to allow specific error handling for them. | ||
enum RPCErrorType { | ||
UNKNOWN_ERROR = 0, /* Indicates that error type could not be parsed */ | ||
TIMEOUT = 1, /* Indicates that the RPC has timed out */ | ||
INTENTIONAL_FAILURE = 2 /* Deliberate failure, such as those injected by | ||
FaultyProcessGroupAgent for testing */ | ||
}; | ||
|
||
enum MessageType { | ||
// messages for dist.rpc on builtin operators | ||
SCRIPT_CALL = 0, | ||
|
@@ -20,7 +28,7 @@ enum MessageType { | |
// messages for dist.remote on builtin operators and Python UDF | ||
SCRIPT_REMOTE_CALL = 4, // A remote call on a builtin operator | ||
PYTHON_REMOTE_CALL = 5, // A remote call on a Python UDF | ||
REMOTE_RET = 6, // A remote call on a Python UDF | ||
REMOTE_RET = 6, // Response for remote calls for UDF, builtin, or script | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for fixing this! |
||
|
||
// RRef related internal messages | ||
SCRIPT_RREF_FETCH_CALL = 7, // A UserRRef<IValue> fetches value from owner | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,7 +95,8 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall( | |
const WorkerInfo& dst, | ||
SerializedPyObj serializedPyObj, | ||
const IValue& rrefId, | ||
const IValue& forkId) { | ||
const IValue& forkId, | ||
const float rpcTimeoutSeconds) { | ||
auto pythonRemoteCall = std::make_unique<PythonRemoteCall>( | ||
std::move(serializedPyObj), rrefId, forkId); | ||
|
||
|
@@ -106,7 +107,8 @@ std::shared_ptr<FutureMessage> sendPythonRemoteCall( | |
*agent, | ||
dst, | ||
std::move(*pythonRemoteCall).toMessage(), | ||
true /*forceGradRecording*/); | ||
true /*forceGradRecording*/, | ||
rpcTimeoutSeconds); | ||
} | ||
|
||
} // namespace | ||
|
@@ -223,6 +225,7 @@ c10::intrusive_ptr<JitFuture> pyRpcTorchscript( | |
PyRRef pyRemoteBuiltin( | ||
const WorkerInfo& dst, | ||
const std::string& opName, | ||
const float rpcTimeoutSeconds, | ||
const py::args& args, | ||
const py::kwargs& kwargs) { | ||
DCHECK(PyGILState_Check()); | ||
|
@@ -242,7 +245,11 @@ PyRRef pyRemoteBuiltin( | |
op, std::move(stack), userRRef->rrefId(), userRRef->forkId()); | ||
|
||
auto fm = sendMessageWithAutograd( | ||
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false); | ||
*agent, | ||
dst, | ||
std::move(*scriptRemoteCall).toMessage(), | ||
/*forceGradRecord */ false, | ||
/* timeout */ rpcTimeoutSeconds); | ||
|
||
userRRef->registerOwnerCreationFuture(fm); | ||
ctx.addPendingUser(userRRef->forkId(), userRRef); | ||
|
@@ -258,22 +265,29 @@ PyRRef pyRemoteBuiltin( | |
auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>( | ||
op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId()); | ||
auto fm = sendMessageWithAutograd( | ||
*agent, dst, std::move(*scriptRemoteCall).toMessage(), false); | ||
*agent, | ||
dst, | ||
std::move(*scriptRemoteCall).toMessage(), | ||
/* forceGradRecord */ false, | ||
/* timeout */ rpcTimeoutSeconds); | ||
|
||
ownerRRef->registerOwnerCreationFuture(fm); | ||
|
||
// Builtin operators does not return py::object, and hence does not require | ||
// GIL for destructing the potentially deleted OwerRRef. | ||
fm->addCallback( | ||
[](const FutureMessage& fm) { callback::finishCreatingOwnerRRef(fm); }); | ||
[ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) { | ||
callback::finishCreatingOwnerRRef(fm, ownerRRefId); | ||
}); | ||
return PyRRef(ownerRRef); | ||
} | ||
} | ||
|
||
PyRRef pyRemotePythonUdf( | ||
const WorkerInfo& dst, | ||
std::string& pickledPythonUDF, | ||
std::vector<torch::Tensor>& tensors) { | ||
std::vector<torch::Tensor>& tensors, | ||
const float rpcTimeoutSeconds) { | ||
DCHECK(!PyGILState_Check()); | ||
auto& ctx = RRefContext::getInstance(); | ||
auto serializedPyObj = | ||
|
@@ -284,7 +298,8 @@ PyRRef pyRemotePythonUdf( | |
dst, | ||
std::move(serializedPyObj), | ||
userRRef->rrefId().toIValue(), | ||
userRRef->forkId().toIValue()); | ||
userRRef->forkId().toIValue(), | ||
rpcTimeoutSeconds); | ||
|
||
userRRef->registerOwnerCreationFuture(fm); | ||
|
||
|
@@ -301,24 +316,27 @@ PyRRef pyRemotePythonUdf( | |
dst, | ||
std::move(serializedPyObj), | ||
ownerRRef->rrefId().toIValue(), | ||
ownerRRef->rrefId().toIValue()); | ||
ownerRRef->rrefId().toIValue(), | ||
rpcTimeoutSeconds); | ||
|
||
ownerRRef->registerOwnerCreationFuture(fm); | ||
|
||
fm->addCallback([](const FutureMessage& fm) { | ||
auto deletedRRef = callback::finishCreatingOwnerRRef(fm); | ||
if (deletedRRef && deletedRRef->isPyObj()) { | ||
py::gil_scoped_acquire ag; | ||
deletedRRef.reset(); | ||
} | ||
}); | ||
fm->addCallback( | ||
[ownerRRefId = ownerRRef->rrefId()](const FutureMessage& fm) { | ||
auto deletedRRef = callback::finishCreatingOwnerRRef(fm, ownerRRefId); | ||
if (deletedRRef && deletedRRef->isPyObj()) { | ||
py::gil_scoped_acquire ag; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This probably is no longer necessary. Not requesting changes in this PR. I added #39355 to track and will investigate after this PR is merged. |
||
deletedRRef.reset(); | ||
} | ||
}); | ||
return PyRRef(ownerRRef); | ||
} | ||
} | ||
|
||
PyRRef pyRemoteTorchscript( | ||
const std::string& dstWorkerName, | ||
const std::string& qualifiedNameStr, | ||
const float rpcTimeoutSeconds, | ||
const py::args& args, | ||
const py::kwargs& kwargs) { | ||
DCHECK(!PyGILState_Check()); | ||
|
@@ -335,8 +353,8 @@ PyRRef pyRemoteTorchscript( | |
functionSchema, args, kwargs, c10::nullopt); | ||
} | ||
DCHECK(!PyGILState_Check()); | ||
auto rrefPtr = | ||
remoteTorchscript(dstWorkerName, qualifiedName, functionSchema, stack); | ||
auto rrefPtr = remoteTorchscript( | ||
dstWorkerName, qualifiedName, functionSchema, stack, rpcTimeoutSeconds); | ||
return PyRRef(rrefPtr); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
timeout
inrpc/api.py
are written astimeout (float, optional)