Enable GPU-to-GPU comm in TensorPipeAgent

This commit uses TensorPipe's `cuda_ipc` channel to perform cross-process,
same-machine GPU-to-GPU communication. On the sender side, `TensorPipeAgent`
grabs a stream for each device used by the message, makes these streams
wait for the current streams, and passes them to TensorPipe's `CudaBuffer`.
On the receiver side, it likewise grabs a stream for each device used in
the message and uses these streams to receive tensors and run user
functions. The same streams are then used to send the response back to the
sender. When receiving the response, the sender grabs a new set of streams
and uses them for TensorPipe's `CudaBuffer`.
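
As a rough illustration of the stream handling described above, here is a
minimal sketch written against ATen's public stream and event APIs. It is not
the agent's code (the commit's actual logic lives behind `LazyStreamContext`
in `tensorpipe_utils` and the agent), and `reserveStreamsFor` is a made-up
helper name:

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

#include <set>
#include <vector>

// Hypothetical helper: reserve one pool stream per CUDA device used by the
// given tensors and make each reserved stream wait for the work already
// enqueued on that device's current stream.
std::vector<c10::cuda::CUDAStream> reserveStreamsFor(
    const std::vector<at::Tensor>& tensors) {
  std::vector<c10::cuda::CUDAStream> streams;
  std::set<c10::DeviceIndex> seenDevices;
  for (const auto& tensor : tensors) {
    if (!tensor.is_cuda() ||
        !seenDevices.insert(tensor.device().index()).second) {
      continue; // skip CPU tensors and devices already handled
    }
    const c10::DeviceIndex device = tensor.device().index();
    auto stream =
        c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device);
    // Record an event on the device's current stream and block the reserved
    // stream on it, so work launched on `stream` sees the producing kernels.
    at::cuda::CUDAEvent event;
    event.record(c10::cuda::getCurrentCUDAStream(device));
    event.block(stream);
    streams.push_back(stream);
  }
  // These streams would then back the TensorPipe transfers (and, on the
  // receiver, the user function), keeping the GPU work correctly ordered.
  return streams;
}
```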

If device maps are provided, `TensorPipeAgent::send` will return an instance
of a class derived from `CUDAFuture` that is specifically tailored for RPC
`Message`s.
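
For illustration, a rough sketch of what such a subclass might look like,
relying on `CUDAFuture` becoming non-final and `extractDataPtrs` becoming
virtual in the diff below. The class name and the `toCustomClass<Message>()`
access are assumptions, not necessarily what this commit implements:

```cpp
#include <ATen/cuda/CUDAFuture.h>
#include <torch/csrc/distributed/rpc/message.h>

#include <functional>
#include <vector>

// Hypothetical subclass; assumes the completed value wraps an rpc::Message
// reachable via IValue::toCustomClass. Instead of walking the IValue
// generically, read the DataPtrs straight off the message's tensors.
struct MessageCudaFuture : at::cuda::CUDAFuture {
  using at::cuda::CUDAFuture::CUDAFuture;

  std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
      const at::IValue& value) override {
    auto message = value.toCustomClass<torch::distributed::rpc::Message>();
    std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs;
    dataPtrs.reserve(message->tensors().size());
    for (const auto& tensor : message->tensors()) {
      dataPtrs.emplace_back(tensor.storage().data_ptr());
    }
    return dataPtrs;
  }
};
```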

TODOs:
1. Enable sending CUDA RPCs to the same process.
2. Add a custom CUDA stream pool.
3. Once TensorPipe has addressed the `cudaPointerGetAttributes()` error,
   remove the `cuda:0` context initialization code in `backend_registry.py`.
4. Once TensorPipe can detect whether peer access is available, enable all
   tests on platforms without peer access.

ghstack-source-id: 86feb4664a7101318efb1e2ac477ba76e43e38d7
Pull Request resolved: #44418
mrshenli committed Jan 13, 2021
1 parent 2a60314 commit cea0297
Showing 9 changed files with 704 additions and 93 deletions.
34 changes: 17 additions & 17 deletions aten/src/ATen/cuda/CUDAFuture.h
@@ -21,7 +21,7 @@

namespace at { namespace cuda {

struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
public:
using at::ivalue::Future::Future;

@@ -106,22 +106,7 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
}
}

private:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
// distinct device on which the value's tensors reside.
std::vector<at::cuda::CUDAEvent> cudaEvents_;

// A cached version of the data ptrs extracted from the value when the future
// is first marked completed.
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
virtual std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
at::IValue::HashAliasedIValues sub_values;
// Prefer getSubValues() over visit() as the latter is a silent no-op for
@@ -136,6 +121,21 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
}
return data_ptrs;
}

private:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
// distinct device on which the value's tensors reside.
std::vector<at::cuda::CUDAEvent> cudaEvents_;

// A cached version of the data ptrs extracted from the value when the future
// is first marked completed.
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;
};

} // namespace cuda
5 changes: 5 additions & 0 deletions torch/csrc/distributed/rpc/macros.h
@@ -0,0 +1,5 @@
#pragma once

#if defined(USE_CUDA) && !defined(__HIP_PLATFORM_HCC__)
#define USE_CUDA_NOT_ROCM
#endif
111 changes: 86 additions & 25 deletions torch/csrc/distributed/rpc/tensorpipe_agent.cpp
@@ -10,6 +10,10 @@
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>

#ifdef USE_CUDA_NOT_ROCM
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#endif

namespace torch {
namespace distributed {
namespace rpc {
@@ -201,6 +205,30 @@ C10_REGISTER_CREATOR(

} // namespace

namespace {

// This is a wrapper of CUDAMultiStreamGuard to run in both CUDA-enabled and
// CPU-only environments. When CUDA is not available, all methods are no-ops.
struct MultiStreamGuard {
MultiStreamGuard(const MultiStreamGuard& other) = delete;
MultiStreamGuard(MultiStreamGuard&& other) = delete;
MultiStreamGuard& operator=(const MultiStreamGuard& rhs) = delete;
MultiStreamGuard& operator=(MultiStreamGuard&& rhs) = delete;

#ifndef USE_CUDA_NOT_ROCM
explicit MultiStreamGuard(
const std::shared_ptr<LazyStreamContext>& /* unused */) {}
#else
explicit MultiStreamGuard(const std::shared_ptr<LazyStreamContext>& ctx)
: guard(ctx->getReservedStreams()) {}

private:
at::cuda::CUDAMultiStreamGuard guard;
#endif
};

} // namespace

////////////////////////// MetricsTracker /////////////////////////////////

TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
@@ -412,26 +440,31 @@ void TensorPipeAgent::onListenerAccepted(

void TensorPipeAgent::pipeRead(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
std::function<void(const tensorpipe::Error&, Message&&)> fn) noexcept {
std::function<void(
const tensorpipe::Error&,
Message&&,
std::shared_ptr<LazyStreamContext>)> fn) noexcept {
pipe->readDescriptor([fn{std::move(fn)}, pipe](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
if (error) {
fn(error, Message());
fn(error, Message(), nullptr);
return;
}

TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage);
auto ctx = createLazyStreamContext();
TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage, ctx);

pipe->read(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
if (error) {
fn(error, Message());
fn(error, Message(), nullptr);
return;
}

@@ -440,7 +473,7 @@ void TensorPipeAgent::pipeRead(
Message rpcMessage = tensorpipeDeserialize(
std::move(tpMessage), std::move(*tpBuffers));

fn(error, std::move(rpcMessage));
fn(error, std::move(rpcMessage), std::move(ctx));
});
});
}
@@ -449,18 +482,20 @@ void TensorPipeAgent::pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
Message&& rpcMessage,
std::vector<c10::DeviceIndex>&& devices,
std::shared_ptr<LazyStreamContext> ctx,
std::function<void(const tensorpipe::Error&)> fn) noexcept {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers tpBuffers;

std::tie(tpMessage, tpBuffers) =
tensorpipeSerialize(std::move(rpcMessage), std::move(devices));
tensorpipeSerialize(std::move(rpcMessage), std::move(devices), ctx);

pipe->write(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error, tensorpipe::Message /* unused */) {
fn(error);
});
@@ -469,7 +504,8 @@
void TensorPipeAgent::sendCompletedResponseMessage(
std::shared_ptr<tensorpipe::Pipe>& pipe,
std::shared_ptr<JitFuture>& futureResponseMessage,
uint64_t messageId) {
uint64_t messageId,
std::shared_ptr<LazyStreamContext> ctx) {
if (!rpcAgentRunning_.load()) {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " won't send response to request #" << messageId << " to "
@@ -496,6 +532,7 @@
pipe,
std::move(responseMessage),
std::move(devices),
std::move(ctx),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
@@ -515,7 +552,8 @@
pipe,
createExceptionResponse(
futureResponseMessage->tryRetrieveErrorMessage(), messageId),
{},
/* devices */ {},
std::move(ctx),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
@@ -537,7 +575,9 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
pipeRead(
pipe,
[this, pipe](
const tensorpipe::Error& error, Message&& requestMessage) mutable {
const tensorpipe::Error& error,
Message&& requestMessage,
std::shared_ptr<LazyStreamContext> ctx) mutable {
if (error) {
// FIXME This is not a correct way to check whether this error was
// "intentionally" caused by the remote end shutting down. We should
@@ -570,7 +610,10 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
threadPool_.run([this,
pipe,
messageId,
requestMessage{std::move(requestMessage)}]() mutable {
requestMessage{std::move(requestMessage)},
ctx{std::move(ctx)}]() mutable {
// create guards again as this function runs on a different thread
MultiStreamGuard guard(ctx);
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " is running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
@@ -588,17 +631,20 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
if (futureResponseMessage->completed()) {
decreaseCallCount(serverActiveCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
pipe, futureResponseMessage, messageId, std::move(ctx));
} else {
// Not complete yet
increaseCallCount(serverActiveAsyncCalls_);
futureResponseMessage->addCallback(
[this, pipe, futureResponseMessage, messageId]() mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
});
futureResponseMessage->addCallback([this,
pipe,
futureResponseMessage,
messageId,
ctx{std::move(ctx)}]() mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId, std::move(ctx));
});
}

VLOG(1) << "RPC agent for " << workerInfo_.name_
@@ -641,7 +687,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
ClientPipe& clientPipe = it->second;
auto& pendingResponseMessage = clientPipe.pendingResponseMessage_;

auto futureResponseMessage = std::make_shared<AtomicJitFuture>();
auto futureResponseMessage = std::make_shared<AtomicJitFuture>(
reverseDeviceMaps_.empty() && opts_.deviceMaps.empty());
uint64_t messageId = nextMessageID_++;
requestMessage.setId(messageId);
pendingResponseMessage[messageId] = futureResponseMessage;
@@ -686,10 +733,13 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
<< messageId << " to " << clientPipe.pipe_->getRemoteName();

auto ctx = createLazyStreamContext();
ctx->waitForCurrentStreams(requestMessage.tensors());
pipeWrite(
clientPipe.pipe_,
std::move(requestMessage),
std::move(devices),
std::move(ctx),
[this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
@@ -716,7 +766,9 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
pipeRead(
clientPipe.pipe_,
[this, &clientPipe](
const tensorpipe::Error& error, Message&& responseMessage) {
const tensorpipe::Error& error,
Message&& responseMessage,
std::shared_ptr<LazyStreamContext> ctx) {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
!rpcAgentRunning_.load()) {
@@ -777,7 +829,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
} else {
markFutureAsComplete(
std::move(futureResponseMessage),
std::move(responseMessage));
std::move(responseMessage),
std::move(ctx));
}
});
});
@@ -1029,14 +1082,17 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) {

void TensorPipeAgent::markFutureAsComplete(
std::shared_ptr<AtomicJitFuture> atomicFuture,
Message message) {
Message message,
std::shared_ptr<LazyStreamContext> ctx) {
if (!atomicFuture->isComplete.test_and_set()) {
// Completing the future will run its callbacks, which could execute
// arbitrary user code. To prevent blocking or stalling the TensorPipe event
// loops, we defer this to a worker thread.
threadPool_.run([this,
atomicFuture{std::move(atomicFuture)},
message{std::move(message)}]() mutable {
message{std::move(message)},
ctx{std::move(ctx)}]() mutable {
MultiStreamGuard guard(ctx);
atomicFuture->jitFuture->markCompleted(
IValue(c10::make_intrusive<Message>(std::move(message))));
// The future's callbacks may schedule further RPCs, increasing the count.
@@ -1096,6 +1152,7 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
std::vector<c10::DeviceIndex> deviceIndices;
deviceIndices.reserve(message.tensors().size());
const auto& deviceMap = iter->second;
bool hasCudaTensor = false;
for (const auto& t : message.tensors()) {
if (t.device().is_cpu()) {
deviceIndices.push_back(-1);
@@ -1108,8 +1165,12 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
t.device(),
" but received a tensor on that device.");
deviceIndices.push_back(deviceIter->second);
hasCudaTensor = true;
}
}
if (!hasCudaTensor) {
deviceIndices.clear();
}
return deviceIndices;
}
}