From 08fd1b0f108570e30b3256ff95a187c9dc5009f3 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 22 Jan 2021 09:08:08 -0800 Subject: [PATCH] Fix CUDA RPC Stream Synchronization When converting RPC Message into Python objects, we were not using a CUDAFuture for the chained Future. As a result, the streams are not synchronized when calling `rpc_async(...).wait()`. This commit uses `Future::then` API to create the chained Future, which will be creating a CUDAFuture if the existing Future is a CUDA one. fixes #50881 fixes #50839 ghstack-source-id: 56c79004a6250bb608d473300260d181a3b11cc9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/50949 --- .../csrc/distributed/rpc/python_functions.cpp | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 1a399b403ab1..1d5c93ab9904 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -138,36 +138,29 @@ c10::intrusive_ptr toPyJitFuture( const std::shared_ptr& messageJitFuture, bool hasValue) { if (hasValue) { - c10::intrusive_ptr pyJitFuture = - c10::make_intrusive(PyObjectType::get()); std::weak_ptr wp = messageJitFuture; - messageJitFuture->addCallback( - at::wrapPropagateTLSState([pyJitFuture, wp]() { + return messageJitFuture->then( + at::wrapPropagateTLSState([wp]() { auto future = wp.lock(); if (future->hasError()) { - pyJitFuture->setError(future->exception_ptr()); + std::rethrow_exception(future->exception_ptr()); } else { - pyJitFuture->markCompleted( - toPyIValue(*future->value().toCustomClass())); + return toPyIValue(*future->value().toCustomClass()); } - })); - - return pyJitFuture; + }), + PyObjectType::get()); } else { - c10::intrusive_ptr pyJitFuture = - c10::make_intrusive(NoneType::get()); std::weak_ptr wp = messageJitFuture; - messageJitFuture->addCallback( - at::wrapPropagateTLSState([wp, pyJitFuture]() { + return messageJitFuture->then( + at::wrapPropagateTLSState([wp]() { auto future = wp.lock(); if (future->hasError()) { - pyJitFuture->setError(future->exception_ptr()); + std::rethrow_exception(future->exception_ptr()); } else { - pyJitFuture->markCompleted(IValue()); + return IValue(); } - })); - - return pyJitFuture; + }), + NoneType::get()); } }