Enable GPU-to-GPU comm in TensorPipeAgent #44418

Closed
wants to merge 63 commits into from

Changes from 59 commits
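For context, this PR lets the TensorPipe RPC agent ship CUDA tensors directly from a caller GPU to a callee GPU when device maps are configured on the backend options (the user-facing configuration lives in TensorPipeRpcBackendOptions and is not part of this diff). The short sketch below is an illustration, not code from this PR: it mirrors the device-map resolution that getDevicesForTensors performs at the bottom of the tensorpipe_agent.cpp diff, where CPU tensors resolve to -1 and CUDA tensors resolve through the caller-to-callee table. The type alias and function name here are made up for the example.

// Illustrative only: how a caller->callee device map such as {0: 1} is
// interpreted (mirrors getDevicesForTensors further down in this diff).
#include <cstdint>
#include <map>
#include <vector>

using DeviceIndex = int16_t; // stands in for c10::DeviceIndex

std::vector<DeviceIndex> resolveDevices(
    const std::vector<DeviceIndex>& tensorDevices, // -1 means CPU
    const std::map<DeviceIndex, DeviceIndex>& deviceMap) {
  std::vector<DeviceIndex> out;
  bool hasCudaTensor = false;
  for (DeviceIndex d : tensorDevices) {
    if (d < 0) {
      out.push_back(-1); // CPU tensor: nothing to map
    } else {
      out.push_back(deviceMap.at(d)); // caller GPU -> callee GPU
      hasCudaTensor = true;
    }
  }
  if (!hasCudaTensor) {
    out.clear(); // all-CPU messages keep taking the existing CPU path
  }
  return out;
}

The all-CPU short-circuit at the end matches the hasCudaTensor check added in the last hunk of the diff, so CPU-only workloads are unaffected by the new GPU path.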
Commits (63)
b0e83ac
Use streams from pool on RPC callees
mrshenli Sep 9, 2020
6ae4c16
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 10, 2020
0a23d4d
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 10, 2020
c6dec67
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
7b073dc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
7489147
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
2dc32dc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 11, 2020
af38592
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 18, 2020
77b06e7
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 20, 2020
2b84077
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 21, 2020
b445a87
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 21, 2020
3115ad3
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 21, 2020
7fab7fc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 22, 2020
5b33064
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 22, 2020
20ce42b
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 22, 2020
0314cfc
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
a25d60a
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
f642020
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
5d74596
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Sep 23, 2020
31a1c64
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 16, 2020
c215331
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 18, 2020
dc846b5
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
0ff581a
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
209010c
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
402fa03
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 19, 2020
eedec85
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 20, 2020
0494738
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 20, 2020
9a22ec3
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 23, 2020
0b2cb77
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 23, 2020
c87efe8
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 23, 2020
27b1a0d
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 25, 2020
46c49e5
Update on "[WIP] Use streams from pool on RPC callees"
mrshenli Oct 26, 2020
124d160
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
ced8292
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
6e86c60
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
6988c90
Update on "Use streams from pool on RPC callees"
mrshenli Dec 8, 2020
5960d23
Update on "Use streams from pool on RPC callees"
mrshenli Dec 9, 2020
4ab8e85
Update on "Use streams from pool on RPC callees"
mrshenli Dec 9, 2020
5a90106
Update on "Use streams from pool on RPC callees"
mrshenli Dec 18, 2020
bb94e5f
Update on "Use streams from pool on RPC callees"
mrshenli Jan 10, 2021
6834072
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
ed0c81e
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
84ecc25
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
59d95ba
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
e9133d0
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
a1dcfe2
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
9c98bc9
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 11, 2021
b90a8af
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 12, 2021
3a0651f
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 12, 2021
43349e4
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
e2b9581
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
0e5e5b6
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
fb6e4ea
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
8c935e4
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
9b5afb1
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
20294df
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
0994ff8
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
4d6ce07
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
2880a01
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
f27608f
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
36fc05b
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 13, 2021
28209c4
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 14, 2021
b302be9
Update on "Enable GPU-to-GPU comm in TensorPipeAgent"
mrshenli Jan 14, 2021
34 changes: 17 additions & 17 deletions aten/src/ATen/cuda/CUDAFuture.h
@@ -21,7 +21,7 @@

namespace at { namespace cuda {

struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
public:
using at::ivalue::Future::Future;

@@ -106,22 +106,7 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
}
}

private:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
// distinct device on which the value's tensors reside.
std::vector<at::cuda::CUDAEvent> cudaEvents_;

// A cached version of the data ptrs extracted from the value when the future
// is first marked completed.
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
virtual std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
at::IValue::HashAliasedIValues sub_values;
// Prefer getSubValues() over visit() as the latter is a silent no-op for
@@ -136,6 +121,21 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
}
return data_ptrs;
}

private:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
// distinct device on which the value's tensors reside.
std::vector<at::cuda::CUDAEvent> cudaEvents_;

// A cached version of the data ptrs extracted from the value when the future
// is first marked completed.
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;
};

} // namespace cuda
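The CUDAFuture.h change above drops final from the class and makes extractDataPtrs a virtual, non-private member, so a subclass can customize which DataPtrs (and hence which CUDA events and streams) a future tracks. A hypothetical sketch of such a subclass follows; the class name and override body are illustrative and not part of this PR.

// Illustrative only; not code from this PR. Shows what the relaxed
// CUDAFuture interface (non-final class, virtual extractDataPtrs) permits.
#include <ATen/cuda/CUDAFuture.h>
#include <functional>
#include <vector>

struct MessageAwareCUDAFuture : at::cuda::CUDAFuture {
  using at::cuda::CUDAFuture::CUDAFuture;

  std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
      const at::IValue& value) override {
    // A real override could restrict extraction to the tensors it knows the
    // value carries (e.g. an RPC Message) instead of walking every sub-value.
    // This sketch simply defers to the base implementation.
    return at::cuda::CUDAFuture::extractDataPtrs(value);
  }
};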
5 changes: 5 additions & 0 deletions torch/csrc/distributed/rpc/macros.h
@@ -0,0 +1,5 @@
#pragma once

#if defined(USE_CUDA) && !defined(__HIP_PLATFORM_HCC__)
#define USE_CUDA_NOT_ROCM
#endif
115 changes: 90 additions & 25 deletions torch/csrc/distributed/rpc/tensorpipe_agent.cpp
@@ -10,6 +10,10 @@
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>

#ifdef USE_CUDA_NOT_ROCM
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#endif

namespace torch {
namespace distributed {
namespace rpc {
@@ -201,6 +205,30 @@ C10_REGISTER_CREATOR(

} // namespace

namespace {

// This is a wrapper of CUDAMultiStreamGuard to run in both CUDA-enabled and
// CPU-only environments. When CUDA is not available, all methods are no-ops.
struct MultiStreamGuard {
MultiStreamGuard(const MultiStreamGuard& other) = delete;
MultiStreamGuard(MultiStreamGuard&& other) = delete;
MultiStreamGuard& operator=(const MultiStreamGuard& rhs) = delete;
MultiStreamGuard& operator=(MultiStreamGuard&& rhs) = delete;

#ifndef USE_CUDA_NOT_ROCM
explicit MultiStreamGuard(
const std::shared_ptr<LazyStreamContext>& /* unused */) {}
#else
explicit MultiStreamGuard(const std::shared_ptr<LazyStreamContext>& ctx)
: guard(ctx->getReservedStreams()) {}

private:
at::cuda::CUDAMultiStreamGuard guard;
#endif
};

} // namespace

////////////////////////// MetricsTracker /////////////////////////////////

TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
@@ -412,35 +440,43 @@ void TensorPipeAgent::onListenerAccepted(

void TensorPipeAgent::pipeRead(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
std::function<void(const tensorpipe::Error&, Message&&)> fn) noexcept {
std::function<void(
const tensorpipe::Error&,
Message&&,
std::shared_ptr<LazyStreamContext>)> fn) noexcept {
pipe->readDescriptor([fn{std::move(fn)}, pipe](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
if (error) {
fn(error, Message());
fn(error, Message(), nullptr);
return;
}

TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage);
auto ctx = createLazyStreamContext();
TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage, ctx);

pipe->read(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error,
tensorpipe::Message tpMessage) mutable {
if (error) {
fn(error, Message());
fn(error, Message(), nullptr);
return;
}

// make sure ops on current streams won't access the tensors before
// communication is done.
ctx->blockCurrentStreams();
// FIXME This does some unpickling, which could be a bit expensive:
// perhaps it would be best to perform it inside the worker threads?
Message rpcMessage = tensorpipeDeserialize(
std::move(tpMessage), std::move(*tpBuffers));

fn(error, std::move(rpcMessage));
fn(error, std::move(rpcMessage), std::move(ctx));
});
});
}
@@ -449,18 +485,20 @@ void TensorPipeAgent::pipeWrite(
const std::shared_ptr<tensorpipe::Pipe>& pipe,
Message&& rpcMessage,
std::vector<c10::DeviceIndex>&& devices,
std::shared_ptr<LazyStreamContext> ctx,
std::function<void(const tensorpipe::Error&)> fn) noexcept {
tensorpipe::Message tpMessage;
TensorpipeWriteBuffers tpBuffers;

std::tie(tpMessage, tpBuffers) =
tensorpipeSerialize(std::move(rpcMessage), std::move(devices));
tensorpipeSerialize(std::move(rpcMessage), std::move(devices), ctx);

pipe->write(
std::move(tpMessage),
[tpBuffers{
std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
fn{std::move(fn)}](
fn{std::move(fn)},
ctx{std::move(ctx)}](
const tensorpipe::Error& error, tensorpipe::Message /* unused */) {
fn(error);
});
@@ -469,7 +507,8 @@
void TensorPipeAgent::sendCompletedResponseMessage(
std::shared_ptr<tensorpipe::Pipe>& pipe,
std::shared_ptr<JitFuture>& futureResponseMessage,
uint64_t messageId) {
uint64_t messageId,
std::shared_ptr<LazyStreamContext> ctx) {
if (!rpcAgentRunning_.load()) {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " won't send response to request #" << messageId << " to "
@@ -496,6 +535,7 @@
pipe,
std::move(responseMessage),
std::move(devices),
std::move(ctx),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
@@ -515,7 +555,8 @@
pipe,
createExceptionResponse(
futureResponseMessage->tryRetrieveErrorMessage(), messageId),
{},
/* devices */ {},
std::move(ctx),
[this, pipe, messageId](const tensorpipe::Error& error) {
if (error) {
LOG(WARNING)
@@ -537,7 +578,9 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
pipeRead(
pipe,
[this, pipe](
const tensorpipe::Error& error, Message&& requestMessage) mutable {
const tensorpipe::Error& error,
Message&& requestMessage,
std::shared_ptr<LazyStreamContext> ctx) mutable {
if (error) {
// FIXME This is not a correct way to check whether this error was
// "intentionally" caused by the remote end shutting down. We should
@@ -570,7 +613,10 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
threadPool_.run([this,
pipe,
messageId,
requestMessage{std::move(requestMessage)}]() mutable {
requestMessage{std::move(requestMessage)},
ctx{std::move(ctx)}]() mutable {
// create guards again as this function runs on a different thread
MultiStreamGuard guard(ctx);
VLOG(1) << "RPC agent for " << workerInfo_.name_
<< " is running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
@@ -588,17 +634,20 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
if (futureResponseMessage->completed()) {
decreaseCallCount(serverActiveCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
pipe, futureResponseMessage, messageId, std::move(ctx));
} else {
// Not complete yet
increaseCallCount(serverActiveAsyncCalls_);
futureResponseMessage->addCallback(
[this, pipe, futureResponseMessage, messageId]() mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId);
});
futureResponseMessage->addCallback([this,
pipe,
futureResponseMessage,
messageId,
ctx{std::move(ctx)}]() mutable {
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
sendCompletedResponseMessage(
pipe, futureResponseMessage, messageId, std::move(ctx));
});
}

VLOG(1) << "RPC agent for " << workerInfo_.name_
@@ -641,7 +690,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
ClientPipe& clientPipe = it->second;
auto& pendingResponseMessage = clientPipe.pendingResponseMessage_;

auto futureResponseMessage = std::make_shared<AtomicJitFuture>();
auto futureResponseMessage = std::make_shared<AtomicJitFuture>(
reverseDeviceMaps_.empty() && opts_.deviceMaps.empty());
uint64_t messageId = nextMessageID_++;
requestMessage.setId(messageId);
pendingResponseMessage[messageId] = futureResponseMessage;
@@ -686,10 +736,13 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
<< messageId << " to " << clientPipe.pipe_->getRemoteName();

auto ctx = createLazyStreamContext();
ctx->waitForCurrentStreams(requestMessage.tensors());
pipeWrite(
clientPipe.pipe_,
std::move(requestMessage),
std::move(devices),
std::move(ctx),
[this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
@@ -716,7 +769,10 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
pipeRead(
clientPipe.pipe_,
[this, &clientPipe](
const tensorpipe::Error& error, Message&& responseMessage) {
const tensorpipe::Error& error,
Message&& responseMessage,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::shared_ptr<LazyStreamContext> ctx) {
if (error) {
if (error.isOfType<tensorpipe::PipeClosedError>() &&
!rpcAgentRunning_.load()) {
@@ -777,7 +833,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
} else {
markFutureAsComplete(
std::move(futureResponseMessage),
std::move(responseMessage));
std::move(responseMessage),
std::move(ctx));
}
});
});
@@ -1029,14 +1086,17 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) {

void TensorPipeAgent::markFutureAsComplete(
std::shared_ptr<AtomicJitFuture> atomicFuture,
Message message) {
Message message,
std::shared_ptr<LazyStreamContext> ctx) {
if (!atomicFuture->isComplete.test_and_set()) {
// Completing the future will run its callbacks, which could execute
// arbitrary user code. To prevent blocking or stalling the TensorPipe event
// loops, we defer this to a worker thread.
threadPool_.run([this,
atomicFuture{std::move(atomicFuture)},
message{std::move(message)}]() mutable {
message{std::move(message)},
ctx{std::move(ctx)}]() mutable {
MultiStreamGuard guard(ctx);
atomicFuture->jitFuture->markCompleted(
IValue(c10::make_intrusive<Message>(std::move(message))));
// The future's callbacks may schedule further RPCs, increasing the count.
@@ -1096,6 +1156,7 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
std::vector<c10::DeviceIndex> deviceIndices;
deviceIndices.reserve(message.tensors().size());
const auto& deviceMap = iter->second;
bool hasCudaTensor = false;
for (const auto& t : message.tensors()) {
if (t.device().is_cpu()) {
deviceIndices.push_back(-1);
@@ -1108,8 +1169,12 @@
t.device(),
" but received a tensor on that device.");
deviceIndices.push_back(deviceIter->second);
hasCudaTensor = true;
}
}
if (!hasCudaTensor) {
deviceIndices.clear();
}
return deviceIndices;
}
}
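The LazyStreamContext used throughout the agent changes (createLazyStreamContext, getReservedStreams, waitForCurrentStreams, blockCurrentStreams) comes from companion changes to tensorpipe_utils that are not shown in this diff. The sketch below is a rough, assumption-laden reconstruction of the idea, written against standard ATen/c10 stream and event APIs; the class name and method set are assumptions, not the PR's actual implementation.

// Rough sketch of a lazy per-device stream context; not the code used by
// this PR. One pool stream is reserved per device the first time it is
// touched, and blockCurrentStreams() makes each device's current stream wait
// for the work enqueued on its reserved stream, so later ops cannot read
// tensors before the RPC transfer has finished.
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <map>
#include <vector>

class LazyStreamContextSketch {
 public:
  // Lazily reserve one stream from the ATen pool for the given device.
  at::cuda::CUDAStream getStream(c10::DeviceIndex index) {
    auto it = streams_.find(index);
    if (it == streams_.end()) {
      it = streams_
               .emplace(
                   index,
                   at::cuda::getStreamFromPool(
                       /*isHighPriority=*/false, index))
               .first;
    }
    return it->second;
  }

  std::vector<at::cuda::CUDAStream> getReservedStreams() const {
    std::vector<at::cuda::CUDAStream> result;
    for (const auto& kv : streams_) {
      result.push_back(kv.second);
    }
    return result;
  }

  // Make the current stream on each touched device wait for its reserved
  // stream before any further work runs on it.
  void blockCurrentStreams() {
    for (const auto& kv : streams_) {
      at::cuda::CUDAEvent event;
      event.record(kv.second);
      event.block(at::cuda::getCurrentCUDAStream(kv.first));
    }
  }

 private:
  std::map<c10::DeviceIndex, at::cuda::CUDAStream> streams_;
};

The waitForCurrentStreams call used on the send path would presumably do the mirror-image synchronization (record events on the streams carrying the argument tensors and block the reserved streams on them) before the write is handed to TensorPipe.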