Add multi-GPU support to FutureNCCL #48500

Closed
wants to merge 6 commits into from

Changes from 2 commits
14 changes: 7 additions & 7 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1008,15 +1008,15 @@ std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() {

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
getFuture() {
TORCH_INTERNAL_ASSERT(
outputs_->size() == 1,
"WorkNCCL's getFuture API is only supported for single-process single-device mode.");
auto deviceIndex = (*outputs_)[0].device().index();
// Create a new FutureNCCL object after checking for single-process
// single-device mode.
std::vector<c10::DeviceIndex> deviceIndices;
for (const c10::Device& device : devices_) {
TORCH_INTERNAL_ASSERT(device.is_cuda());
deviceIndices.push_back(device.index());
}

return c10::make_intrusive<FutureNCCL>(
at::IValue(*outputs_),
deviceIndex,
std::move(deviceIndices),
cudaEvents_);
}
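A rough usage sketch (hypothetical; `work` and the call sites are illustrative, not part of this diff): after a multi-GPU collective, the future returned by getFuture now carries one CUDA event per output device instead of assuming a single device.

// Sketch only: `work` stands for a WorkNCCL returned by a multi-GPU
// collective; exact types and call sites may differ.
c10::intrusive_ptr<c10::ivalue::Future> fut = work->getFuture();
fut->wait();  // blocks each device's current stream on its recorded event
std::vector<at::Tensor> outputs = fut->value().toTensorVector();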

107 changes: 73 additions & 34 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -13,6 +13,7 @@

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
@@ -196,10 +197,8 @@ class ProcessGroupNCCL : public ProcessGroup {
// or NCCL's barrier().
//
// If created by WorkNCCL's getFuture API, FutureNCCL has a reference to
// WorkNCCL's cudaEvents, NCCL collective's outputs, and the device index of
// outputs' device. Its value is NCCL collective's
// outputs. FutureNCCL only supports single-process single-device mode where
// the size of outputs is equal to 1.
// WorkNCCL's cudaEvents, NCCL collective's outputs, and the device indices of
// outputs' devices. Its value is NCCL collective's outputs.
//
// If created by FutureNCCL's then callback, its value becomes the value of
// callback() and its cudaEvents will record the NCCL stream that runs that
@@ -212,28 +211,37 @@ class ProcessGroupNCCL : public ProcessGroup {
public:
explicit FutureNCCL(
at::IValue value,
c10::DeviceIndex deviceIndex,
std::vector<c10::DeviceIndex> deviceIndices,
Contributor:
Do we expect devices to be distinct?

Contributor Author:
We do, although I think it should be fine if there are duplicates. When the CUDAFuture itself determines the set of devices (inside markCompleted), it explicitly deduplicates them. Let me add a check to the other constructor (the one invoked by ProcessGroupNCCL) to ensure they are distinct.

std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
value_(std::move(value)),
deviceIndex_(deviceIndex),
deviceIndices_(std::move(deviceIndices)),
cudaEvents_(std::move(cudaEvents)) {
TORCH_INTERNAL_ASSERT(
cudaEvents_->size() == 1,
"FutureNCCL only supports single-process single-device mode.");
cudaEvents_->size() == deviceIndices_.size(),
"The device indices and the events must be paired up. Got ",
deviceIndices_.size(), " devices and ", cudaEvents_->size(),
" events.");
for (const at::cuda::CUDAEvent& event : *cudaEvents_) {
TORCH_INTERNAL_ASSERT(event.isCreated());
TORCH_INTERNAL_ASSERT(event.device_index() == deviceIndex_);
TORCH_INTERNAL_ASSERT(
std::find(
deviceIndices_.begin(),
deviceIndices_.end(),
event.device_index()) != deviceIndices_.end());
}
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
TORCH_INTERNAL_ASSERT(data_ptr.device().index() == deviceIndex_);
TORCH_INTERNAL_ASSERT(
std::find(
deviceIndices_.begin(),
deviceIndices_.end(),
data_ptr.device().index()) != deviceIndices_.end());
}
}
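Following up on the review thread above, a minimal sketch of such a distinctness check (hypothetical; it assumes <algorithm> is available and is not necessarily the exact assertion that landed in the PR):

// Sketch: verify that deviceIndices contains no duplicates before
// storing it in the FutureNCCL.
std::vector<c10::DeviceIndex> sortedIndices = deviceIndices;
std::sort(sortedIndices.begin(), sortedIndices.end());
TORCH_INTERNAL_ASSERT(
    std::adjacent_find(sortedIndices.begin(), sortedIndices.end()) ==
        sortedIndices.end(),
    "Expected the device indices to be distinct.");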

private:
explicit FutureNCCL(c10::DeviceIndex deviceIndex)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
deviceIndex_(deviceIndex) {}
FutureNCCL()
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())) {}
// We need this because it will be the ::make() static method that actually
// creates the instance. This is a brittle approach and the passkey idiom
// would be a more robust solution. However, this will go away in #48505.
@@ -248,11 +256,17 @@ class ProcessGroupNCCL : public ProcessGroup {
if (error_) {
throw *error_;
}
auto stream = at::cuda::getCurrentCUDAStream(deviceIndex_);
(*cudaEvents_)[0].block(stream);

for (int i = 0; i < deviceIndices_.size(); i++) {
(*cudaEvents_)[i].block(
at::cuda::getCurrentCUDAStream(deviceIndices_[i]));
}

for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
c10::cuda::CUDACachingAllocator::recordStream(data_ptr, stream);
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
}
}
}

@@ -265,18 +279,25 @@ class ProcessGroupNCCL : public ProcessGroup {
"Attempting to set value of a FutureNCCL which has a value."
"FutureNCCL's value was internally set to NCCL collective's "
"outputs or the return value of the callback.");
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
TORCH_INTERNAL_ASSERT(data_ptr.device().index() == deviceIndex_);
}
value_ = std::move(value);

TORCH_INTERNAL_ASSERT(cudaEvents_ == nullptr);
// Create a new cudaEvents object of size 1 that will record the current
// stream after callback and will be passed to the new FutureNCCL.
cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>(1);
// In case of chained then callback calls, cudaEvents
// records callback's stream.
(*cudaEvents_)[0].record(at::cuda::getCurrentCUDAStream(deviceIndex_));
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
}

cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
deviceIndices_.push_back(idx);
(*cudaEvents_).push_back(std::move(cudaEvent));
}
}
}
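For example (illustrative): if the callback's return value holds tensors on cuda:0 and cuda:2 only, isCudaDeviceUsed marks indices 0 and 2, so exactly two events are recorded, each on its device's current stream, and deviceIndices_ becomes {0, 2}.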

// Just returns FutureNCCL's value after wait returns.
@@ -297,19 +318,37 @@ class ProcessGroupNCCL : public ProcessGroup {
// this callback. This new FutureNCCL's cudaEvents will record the
// callback's stream and will have the result value of the callback.
void addCallback(std::function<void(void)> callback) override {
// FIXME Should we find a way to allow to change the priority of streams?
at::cuda::CUDAStream stream =
at::cuda::getStreamFromPool(/*isHighPriority=*/false, deviceIndex_);
// We'd love to get a stream for all devices, even those that are not used
// by the value, because the callback could use those other devices, but
// unfortunately this could cause a deadlock with NCCL. See
// https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414
// In general, if some devices haven't been used yet, by getting a stream
// for them we'd initialize them, and in addition to causing NCCL to
// misbehave this also ends up using memory on those devices, which the
// user might not want.
std::vector<at::cuda::CUDAStream> streams;
for (int i = 0; i < deviceIndices_.size(); i++) {
c10::DeviceIndex idx = deviceIndices_[i];
// FIXME Should we find a way to allow to change the priority of
// streams?

Contributor:
When do we need high-priority streams?

Contributor Author:
I actually have no idea. Do you know when high- vs. low-priority streams were used in ProcessGroupNCCL? What was the reason? Does it still apply here?

Contributor (@mrshenli, Dec 2, 2020):
I have never seen a user try to configure stream priority for NCCL ops, and I don't think it's possible with the init_process_group API. Users would probably have to use the undocumented ProcessGroupNCCL ctor API.

Besides, I am also not sure how much impact the stream priority has on the schedule, and how visible that is to the e2e perf.

If there is no specific use case for now, we can probably keep it simple. It would be interesting to run some experiments to quantify its impact on perf in the future.
at::cuda::CUDAStream stream =
at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx);
(*cudaEvents_)[i].block(stream);
streams.push_back(stream);
}

// Use the dedicated callback stream to run callback.
at::cuda::CUDAMultiStreamGuard streamGuard(streams);

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
c10::cuda::CUDACachingAllocator::recordStream(data_ptr, stream);
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
}
}

(*cudaEvents_)[0].block(stream);
// Use the dedicated callback stream to run callback.
c10::StreamGuard streamGuard{stream};
callback();
}
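A hedged sketch of chaining from the caller's side (illustrative names; `fut` stands for a FutureNCCL obtained from getFuture, and this is not code from the PR): the callback runs under CUDAMultiStreamGuard, so work it launches lands on the dedicated per-device callback streams, and the chained future's events record those streams.

// Sketch only: chain a callback that post-processes the collective's
// outputs on the callback streams.
auto chained = fut->then(
    [fut]() -> at::IValue {
      // fut has completed here, so value() is safe to read.
      auto tensors = fut->value().toTensorVector();
      for (auto& t : tensors) {
        t.add_(1);  // example async work enqueued on the callback streams
      }
      return at::IValue(tensors);
    },
    c10::ListType::create(c10::TensorType::get()));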

@@ -319,7 +358,7 @@ class ProcessGroupNCCL : public ProcessGroup {
c10::intrusive_ptr<Future> then(
std::function<at::IValue(void)> callback,
at::TypePtr /* unused */) override {
auto fut = c10::make_intrusive<FutureNCCL>(deviceIndex_);
auto fut = c10::make_intrusive<FutureNCCL>();
// The new future needs the DataPtr extractor when it gets marked complete
// but this might happen immediately inline or in parallel by another
// thread. In both these cases this would/might happen before the user has
@@ -361,7 +400,7 @@ class ProcessGroupNCCL : public ProcessGroup {

private:
at::IValue value_;
c10::DeviceIndex deviceIndex_;
std::vector<c10::DeviceIndex> deviceIndices_;
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
DataPtrExtractor dataPtrExtractor_;
c10::optional<FutureError> error_;