Enable GPU-to-GPU comm in TensorPipeAgent
Pull Request resolved: #44418


This commit uses TensorPipe's cuda_ipc channel to conduct
cross-process, same-machine GPU-to-GPU communication. On the sender
side, `TensorPipeAgent` grabs a stream for each device used by the
message, lets these streams wait for the current streams, and passes
them to TensorPipe's `CudaBuffer`. On the receiver side, it likewise
grabs a stream for each device used in the message and uses these
streams to receive tensors and run user functions. The same streams
are then used to send the response back to the sender. When receiving
the response, the sender grabs a new set of streams and uses them for
TensorPipe's `CudaBuffer`.
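
As a rough illustration of the sender-side stream handling described
above, a minimal sketch using ATen's public CUDA stream and event APIs
might look as follows. The helper name is hypothetical; the PR's actual
stream bookkeeping lives in its `LazyStreamContext` and may differ in
detail.

```cpp
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical helper: reserve a fresh stream for one device and make it
// wait for the work already enqueued on that device's current stream, so
// tensors produced by the caller are ready before TensorPipe copies them.
c10::cuda::CUDAStream reserveStreamForDevice(c10::DeviceIndex index) {
  c10::cuda::CUDAStream stream =
      c10::cuda::getStreamFromPool(/*isHighPriority=*/false, index);
  at::cuda::CUDAEvent event;
  event.record(c10::cuda::getCurrentCUDAStream(index));
  event.block(stream); // the reserved stream now waits for the current stream
  return stream;
}
```

On the receiver side, the reserved streams are installed as the current
streams (via the `MultiStreamGuard` added in this commit) while the user
function runs, so response tensors are produced on streams the agent
controls.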

If device maps are provided, `TensorPipeAgent::send` will return an
instance of a `CUDAFuture` subclass that is specifically tailored for
RPC messages.
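
The `CUDAFuture.h` change below removes `final` from the class and makes
`extractDataPtrs` virtual precisely to allow such a subclass. A minimal
sketch of what a Message-aware override could look like is shown here;
the class name `RpcCudaFuture` is hypothetical, and the real subclass in
the PR may be structured differently.

```cpp
#include <ATen/cuda/CUDAFuture.h>
#include <torch/csrc/distributed/rpc/message.h>

// Hypothetical sketch: an RPC-tailored CUDAFuture that pulls data pointers
// straight out of the Message's tensors instead of walking the generic
// IValue structure.
struct RpcCudaFuture : at::cuda::CUDAFuture {
  using at::cuda::CUDAFuture::CUDAFuture;

  std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
      const at::IValue& value) override {
    std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs;
    // The future's value wraps an rpc::Message holding the payload tensors.
    auto message = value.toCustomClass<torch::distributed::rpc::Message>();
    for (const auto& tensor : message->tensors()) {
      dataPtrs.emplace_back(tensor.storage().data_ptr());
    }
    return dataPtrs;
  }
};
```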

TODOs:
1. Enable sending CUDA RPC to the same process.
2. Add a custom CUDA stream pool.
3. Once TensorPipe addresses the error for `cudaPointerGetAttributes()`,
remove the `cuda:0` context initialization code in `backend_registry.py`.
4. Once TensorPipe can detect the availability of peer access, enable all
tests on platforms without peer access.

Differential Revision: [D23626207](https://our.internmc.facebook.com/intern/diff/D23626207/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D23626207/)!
ghstack-source-id: 119821241
mrshenli committed Jan 14, 2021
1 parent 2a60314 commit 120f934
Showing 10 changed files with 699 additions and 93 deletions.
34 changes: 17 additions & 17 deletions aten/src/ATen/cuda/CUDAFuture.h
@@ -21,7 +21,7 @@

namespace at { namespace cuda {

struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
public:
using at::ivalue::Future::Future;

@@ -106,22 +106,7 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
}
}

private:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
// distinct device on which the value's tensors reside.
std::vector<at::cuda::CUDAEvent> cudaEvents_;

// A cached version of the data ptrs extracted from the value when the future
// is first marked completed.
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
virtual std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
at::IValue::HashAliasedIValues sub_values;
// Prefer getSubValues() over visit() as the latter is a silent no-op for
@@ -136,6 +121,21 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
}
return data_ptrs;
}

private:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
// distinct device on which the value's tensors reside.
std::vector<at::cuda::CUDAEvent> cudaEvents_;

// A cached version of the data ptrs extracted from the value when the future
// is first marked completed.
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;
};

} // namespace cuda
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -302,6 +302,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
add_dependencies(process_group_agent torch c10d)

add_library(tensorpipe_agent
"${TORCH_SRC_DIR}/csrc/distributed/rpc/macros.h"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.cpp"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.h"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_utils.cpp"
5 changes: 5 additions & 0 deletions torch/csrc/distributed/rpc/macros.h
@@ -0,0 +1,5 @@
#pragma once

#if defined(USE_CUDA) && !defined(__HIP_PLATFORM_HCC__)
#define USE_CUDA_NOT_ROCM
#endif
111 changes: 86 additions & 25 deletions torch/csrc/distributed/rpc/tensorpipe_agent.cpp
@@ -10,6 +10,10 @@
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>

#ifdef USE_CUDA_NOT_ROCM
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#endif

namespace torch {
namespace distributed {
namespace rpc {
@@ -201,6 +205,30 @@ C10_REGISTER_CREATOR(

} // namespace

namespace {

// This is a wrapper of CUDAMultiStreamGuard to run in both CUDA-enabled and
// CPU-only environments. When CUDA is not available, all methods are no-ops.
struct MultiStreamGuard {
MultiStreamGuard(const MultiStreamGuard& other) = delete;
MultiStreamGuard(MultiStreamGuard&& other) = delete;
MultiStreamGuard& operator=(const MultiStreamGuard& rhs) = delete;
MultiStreamGuard& operator=(MultiStreamGuard&& rhs) = delete;

#ifndef USE_CUDA_NOT_ROCM
explicit MultiStreamGuard(
const std::shared_ptr<LazyStreamContext>& /* unused */) {}
#else
explicit MultiStreamGuard(const std::shared_ptr<LazyStreamContext>& ctx)
: guard(ctx->getReservedStreams()) {}

private:
at::cuda::CUDAMultiStreamGuard guard;
#endif
};

} // namespace

////////////////////////// MetricsTracker /////////////////////////////////

TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
@@ -412,26 +440,31 @@ void TensorPipeAgent::onListenerAccepted(

void TensorPipeAgent::pipeRead(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
std::function<void(const tensorpipe::Error&, Message&&)> fn) noexcept {
std::function<void(
const tensorpipe::Error&,
Message&&,
std::shared_ptr<LazyStreamContext>)> fn) noexcept {
pipe->readDescriptor([fn{std::move(fn)}, pipe](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
if (error) {
fn(error, Message());
fn(error, Message(), nullptr);
return;
}

TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage);
auto ctx = createLazyStreamContext();
TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage, ctx);

pipe->read(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
if (error) {
fn(error, Message());
fn(error, Message(), nullptr);
return;
}

@@ -440,7 +473,7 @@ void TensorPipeAgent::pipeRead(
Message rpcMessage = tensorpipeDeserialize(
std::move(tpMessage), std::move(*tpBuffers));

fn(error, std::move(rpcMessage));
fn(error, std::move(rpcMessage), std::move(ctx));
});
});
}
@@ -449,18 +482,20 @@ void TensorPipeAgent::pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
Message&& rpcMessage,
std::vector<c10::DeviceIndex>&& devices,
std::shared_ptr<LazyStreamContext> ctx,
std::function<void(const tensorpipe::Error&)> fn) noexcept {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers tpBuffers;

std::tie(tpMessage, tpBuffers) =
tensorpipeSerialize(std::move(rpcMessage), std::move(devices));
tensorpipeSerialize(std::move(rpcMessage), std::move(devices), ctx);

pipe->write(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error, tensorpipe::Message /* unused */) {
fn(error);
});
@@ -469,7 +504,8 @@ void TensorPipeAgent::sendCompletedResponseMessage(
void TensorPipeAgent::sendCompletedResponseMessage(
std::shared_ptr<tensorpipe::Pipe>& pipe,
std::shared_ptr<JitFuture>& futureResponseMessage,
uint64_t messageId) {
uint64_t messageId,
std::shared_ptr<LazyStreamContext> ctx) {
if (!rpcAgentRunning_.load()) {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " won't send response to request #" << messageId << " to "
@@ -496,6 +532,7 @@
pipe,
std::move(responseMessage),
std::move(devices),
std::move(ctx),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
@@ -515,7 +552,8 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
pipe,
createExceptionResponse(
futureResponseMessage->tryRetrieveErrorMessage(), messageId),
{},
/* devices */ {},
std::move(ctx),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
@@ -537,7 +575,9 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
pipeRead(
pipe,
[this, pipe](
const tensorpipe::Error& error, Message&& requestMessage) mutable {
const tensorpipe::Error& error,
Message&& requestMessage,
std::shared_ptr<LazyStreamContext> ctx) mutable {
if (error) {
// FIXME This is not a correct way to check whether this error was
// "intentionally" caused by the remote end shutting down. We should
@@ -570,7 +610,10 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
threadPool_.run([this,
pipe,
messageId,
requestMessage{std::move(requestMessage)}]() mutable {
requestMessage{std::move(requestMessage)},
ctx{std::move(ctx)}]() mutable {
// create guards again as this function runs on a different thread
MultiStreamGuard guard(ctx);
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " is running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
@@ -588,17 +631,20 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
if (futureResponseMessage->completed()) {
decreaseCallCount(serverActiveCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
pipe, futureResponseMessage, messageId, std::move(ctx));
} else {
// Not complete yet
increaseCallCount(serverActiveAsyncCalls_);
futureResponseMessage->addCallback(
[this, pipe, futureResponseMessage, messageId]() mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
});
futureResponseMessage->addCallback([this,
pipe,
futureResponseMessage,
messageId,
ctx{std::move(ctx)}]() mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId, std::move(ctx));
});
}

VLOG(1) << "RPC agent for " << workerInfo_.name_
@@ -641,7 +687,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
ClientPipe& clientPipe = it->second;
auto& pendingResponseMessage = clientPipe.pendingResponseMessage_;

auto futureResponseMessage = std::make_shared<AtomicJitFuture>();
auto futureResponseMessage = std::make_shared<AtomicJitFuture>(
reverseDeviceMaps_.empty() && opts_.deviceMaps.empty());
uint64_t messageId = nextMessageID_++;
requestMessage.setId(messageId);
pendingResponseMessage[messageId] = futureResponseMessage;
@@ -686,10 +733,13 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
<< messageId << " to " << clientPipe.pipe_->getRemoteName();

auto ctx = createLazyStreamContext();
ctx->waitForCurrentStreams(requestMessage.tensors());
pipeWrite(
clientPipe.pipe_,
std::move(requestMessage),
std::move(devices),
std::move(ctx),
[this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
@@ -716,7 +766,9 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
pipeRead(
clientPipe.pipe_,
[this, &clientPipe](
const tensorpipe::Error& error, Message&& responseMessage) {
const tensorpipe::Error& error,
Message&& responseMessage,
std::shared_ptr<LazyStreamContext> ctx) {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
!rpcAgentRunning_.load()) {
@@ -777,7 +829,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
} else {
markFutureAsComplete(
std::move(futureResponseMessage),
std::move(responseMessage));
std::move(responseMessage),
std::move(ctx));
}
});
});
@@ -1029,14 +1082,17 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) {

void TensorPipeAgent::markFutureAsComplete(
std::shared_ptr<AtomicJitFuture> atomicFuture,
Message message) {
Message message,
std::shared_ptr<LazyStreamContext> ctx) {
if (!atomicFuture->isComplete.test_and_set()) {
// Completing the future will run its callbacks, which could execute
// arbitrary user code. To prevent blocking or stalling the TensorPipe event
// loops, we defer this to a worker thread.
threadPool_.run([this,
atomicFuture{std::move(atomicFuture)},
message{std::move(message)}]() mutable {
message{std::move(message)},
ctx{std::move(ctx)}]() mutable {
MultiStreamGuard guard(ctx);
atomicFuture->jitFuture->markCompleted(
IValue(c10::make_intrusive<Message>(std::move(message))));
// The future's callbacks may schedule further RPCs, increasing the count.
@@ -1096,6 +1152,7 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
std::vector<c10::DeviceIndex> deviceIndices;
deviceIndices.reserve(message.tensors().size());
const auto& deviceMap = iter->second;
bool hasCudaTensor = false;
for (const auto& t : message.tensors()) {
if (t.device().is_cpu()) {
deviceIndices.push_back(-1);
Expand All @@ -1108,8 +1165,12 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
t.device(),
" but received a tensor on that device.");
deviceIndices.push_back(deviceIter->second);
hasCudaTensor = true;
}
}
if (!hasCudaTensor) {
deviceIndices.clear();
}
return deviceIndices;
}
}
