Fix CUDA RPC Stream Synchronization
When converting an RPC Message into Python objects, we were not using
a CUDAFuture for the chained Future. As a result, the streams were
not synchronized when calling `rpc_async(...).wait()`. This commit
uses the `Future::then` API to create the chained Future, which
creates a CUDAFuture if the existing Future is a CUDA one.

fixes #50881
fixes #50839

ghstack-source-id: 56c79004a6250bb608d473300260d181a3b11cc9
Pull Request resolved: #50949
mrshenli committed Jan 22, 2021
1 parent 5f07b53 commit 08fd1b0
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions torch/csrc/distributed/rpc/python_functions.cpp
@@ -138,36 +138,29 @@ c10::intrusive_ptr<JitFuture> toPyJitFuture(
     const std::shared_ptr<JitFuture>& messageJitFuture,
     bool hasValue) {
   if (hasValue) {
-    c10::intrusive_ptr<JitFuture> pyJitFuture =
-        c10::make_intrusive<JitFuture>(PyObjectType::get());
     std::weak_ptr<JitFuture> wp = messageJitFuture;
-    messageJitFuture->addCallback(
-        at::wrapPropagateTLSState<void>([pyJitFuture, wp]() {
+    return messageJitFuture->then(
+        at::wrapPropagateTLSState<IValue>([wp]() {
           auto future = wp.lock();
           if (future->hasError()) {
-            pyJitFuture->setError(future->exception_ptr());
+            std::rethrow_exception(future->exception_ptr());
           } else {
-            pyJitFuture->markCompleted(
-                toPyIValue(*future->value().toCustomClass<Message>()));
+            return toPyIValue(*future->value().toCustomClass<Message>());
           }
-        }));
-
-    return pyJitFuture;
+        }),
+        PyObjectType::get());
   } else {
-    c10::intrusive_ptr<JitFuture> pyJitFuture =
-        c10::make_intrusive<JitFuture>(NoneType::get());
     std::weak_ptr<JitFuture> wp = messageJitFuture;
-    messageJitFuture->addCallback(
-        at::wrapPropagateTLSState<void>([wp, pyJitFuture]() {
+    return messageJitFuture->then(
+        at::wrapPropagateTLSState<IValue>([wp]() {
           auto future = wp.lock();
           if (future->hasError()) {
-            pyJitFuture->setError(future->exception_ptr());
+            std::rethrow_exception(future->exception_ptr());
           } else {
-            pyJitFuture->markCompleted(IValue());
+            return IValue();
           }
-        }));
-
-    return pyJitFuture;
+        }),
+        NoneType::get());
   }
 }
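
The commit message's central point, that chaining through `then` yields a future of the same kind as its parent while a hand-built future does not, can be illustrated with a small toy model. This is a sketch only and does not use the PyTorch API: `ToyFuture`, `ToyCudaFuture`, `createInstance`, `chainOld`, and `chainNew` are hypothetical names, and the virtual factory method is an assumed mechanism chosen for illustration, not something shown in this diff. The toy also mirrors how the new callback reports failure by rethrowing, leaving `then` to record the error on the child.

```cpp
#include <cassert>
#include <exception>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Toy stand-in for a future holding an int. NOT the real JitFuture.
class ToyFuture {
 public:
  virtual ~ToyFuture() = default;
  virtual std::string kind() const { return "cpu"; }

  void markCompleted(int v) { value_ = v; finish(); }
  void setError(std::exception_ptr e) { error_ = e; finish(); }
  bool hasError() const { return error_ != nullptr; }
  std::exception_ptr exception_ptr() const { return error_; }
  int value() const {
    if (error_) std::rethrow_exception(error_);
    return value_;
  }

  // Run cb once this future finishes (immediately if it already has).
  void addCallback(std::function<void()> cb) {
    if (completed_) cb(); else callbacks_.push_back(std::move(cb));
  }

  // Chain a child future. The child comes from createInstance(), so a
  // subclass parent produces a subclass child (assumed mechanism).
  std::shared_ptr<ToyFuture> then(std::function<int()> cb) {
    auto child = createInstance();
    addCallback([child, cb]() {
      try {
        child->markCompleted(cb());
      } catch (...) {
        // An exception thrown by cb becomes the child's error.
        child->setError(std::current_exception());
      }
    });
    return child;
  }

 protected:
  virtual std::shared_ptr<ToyFuture> createInstance() const {
    return std::make_shared<ToyFuture>();
  }

 private:
  void finish() {
    completed_ = true;
    auto cbs = std::move(callbacks_);
    callbacks_.clear();
    for (auto& cb : cbs) cb();
  }
  int value_ = 0;
  std::exception_ptr error_;
  bool completed_ = false;
  std::vector<std::function<void()>> callbacks_;
};

// Hypothetical "device-aware" subclass; only its kind() differs here.
class ToyCudaFuture : public ToyFuture {
 public:
  std::string kind() const override { return "cuda"; }
 protected:
  std::shared_ptr<ToyFuture> createInstance() const override {
    return std::make_shared<ToyCudaFuture>();
  }
};

// Old pattern: build a plain future by hand and fill it from a callback.
// The parent's dynamic type is ignored.
std::shared_ptr<ToyFuture> chainOld(const std::shared_ptr<ToyFuture>& parent) {
  auto child = std::make_shared<ToyFuture>();
  std::weak_ptr<ToyFuture> wp = parent;
  parent->addCallback([child, wp]() {
    auto f = wp.lock();
    if (f->hasError()) child->setError(f->exception_ptr());
    else child->markCompleted(f->value() + 1);
  });
  return child;
}

// New pattern: let then() create the child; errors are rethrown.
std::shared_ptr<ToyFuture> chainNew(const std::shared_ptr<ToyFuture>& parent) {
  std::weak_ptr<ToyFuture> wp = parent;
  return parent->then([wp]() -> int {
    auto f = wp.lock();
    if (f->hasError()) std::rethrow_exception(f->exception_ptr());
    return f->value() + 1;
  });
}
```

In this toy, passing a `ToyCudaFuture` to `chainOld` gives back a child whose `kind()` is "cpu", whereas `chainNew` gives back one whose `kind()` is "cuda". That is the analogue of the chained future losing, and then regaining, its CUDA behavior.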
