Split out reusable CUDAFuture from FutureNCCL #48506

Closed
wants to merge 10 commits
153 changes: 153 additions & 0 deletions aten/src/ATen/cuda/CUDAFuture.h
@@ -0,0 +1,153 @@
#pragma once

#include <functional>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/ivalue_inl.h>
#include <ATen/core/jit_type.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Export.h>
#include <c10/util/intrusive_ptr.h>

namespace at { namespace cuda {

struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
public:
using at::ivalue::Future::Future;

void setDataPtrExtractor(DataPtrExtractor dataPtrExtractor) override {
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_);
dataPtrExtractor_ = std::move(dataPtrExtractor);
}

protected:
c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override {
Contributor:

If this is the only way we want to create an instance, can we delete the default constructor? Also, where's the PR that has createInstance for ivalue::Future?

Contributor Author:

This is the only way I found to have at::ivalue::Future::then() be able to create an instance of at::cuda::CUDAFuture. In short, a method of the superclass needs to create an instance of the subclass. It's not supposed to be used by external users, hence it's protected.

The superclass implementation of createInstance is here: https://github.com/pytorch/pytorch/pull/48505/files#diff-2f833078d12338d0ac920ab654b5c791cb1219729d6e7c97a08393b32a46d173R491
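
To illustrate the pattern under discussion, here is a minimal standalone sketch (hypothetical names and simplified signatures, not the actual ivalue::Future API): then() on the base class builds the child future through a virtual factory, so calling then() on a subclass instance yields a child of that same subclass.

#include <memory>

struct BaseFuture {
  virtual ~BaseFuture() = default;

  std::shared_ptr<BaseFuture> then(/* callback elided */) {
    // Dispatches to the most-derived createInstance().
    return createInstance();
  }

 protected:
  virtual std::shared_ptr<BaseFuture> createInstance() {
    return std::make_shared<BaseFuture>();
  }
};

struct CudaAwareFuture : BaseFuture {
 protected:
  std::shared_ptr<BaseFuture> createInstance() override {
    auto child = std::make_shared<CudaAwareFuture>();
    // Copy per-instance configuration (e.g. a DataPtr extractor) into the
    // child before anything can mark it complete.
    return child;
  }
};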

auto fut = c10::make_intrusive<CUDAFuture>(std::move(type));
// The new future needs the DataPtr extractor when it gets marked complete
// but this might happen immediately inline or in parallel by another
// thread. In both these cases this would/might happen before the user has
// time to set their own DataPtr extractor, which might lead to failures
// if the default extractor can't handle some of the user's types.
// Therefore we propagate our extractor.
fut->setDataPtrExtractor(dataPtrExtractor_);
return fut;
}

void postMarkCompletedHook(const at::IValue& value) override {
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
}

cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
}
}
}

std::function<void(void)> wrapCallback(
std::function<void(void)> callback) override {
return [this, callback{std::move(callback)}]() {
Contributor:

What's the difference between this and capturing by reference? Are the callbacks going to be destructed at some point while we still want them alive in the lambda?

Contributor Author:

std::function, AFAIK, is an RAII wrapper, so it might hold some state which gets destroyed once the function goes out of scope. If we captured a reference to the std::function, we would thus risk a dangling reference.
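
A tiny self-contained example of that risk (illustrative only, not the actual call site):

#include <functional>
#include <iostream>

// Capturing the std::function by value copies it (and any state it owns)
// into the lambda, so the wrapper stays valid even after the original
// object is destroyed.
std::function<void()> wrapByValue(std::function<void()> cb) {
  return [cb]() { cb(); };   // safe: the lambda owns its own copy
}

std::function<void()> wrapByReference(std::function<void()>& cb) {
  return [&cb]() { cb(); };  // dangles if `cb` goes out of scope first
}

int main() {
  std::function<void()> wrapped;
  {
    std::function<void()> cb = [] { std::cout << "callback\n"; };
    wrapped = wrapByValue(cb);
    // wrapped = wrapByReference(cb);  // would hold a soon-dangling reference
  }            // cb is destroyed here
  wrapped();   // still fine, because the by-value copy is alive
  return 0;
}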

// We'd love to get a stream for all devices, even those that are not used
// by the value, because the callback could use those other devices, but
// unfortunately this could cause a deadlock with NCCL. See
// https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414
// In general, if some devices haven't been used yet, by getting a stream
// for them we'd initialize them, and in addition to causing NCCL to
// misbehave, this also ends up using memory on those devices, which the
// user might not want.
std::vector<at::cuda::CUDAStream> streams;
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
c10::DeviceIndex idx = cudaEvent.device_index();
// FIXME Should we find a way to allow changing the priority of
// streams?
at::cuda::CUDAStream stream =
at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx);
cudaEvent.block(stream);
streams.push_back(stream);
}

// Use the dedicated callback stream to run callback.
at::cuda::CUDAMultiStreamGuard streamGuard(streams);

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(constValue())) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
}
}

callback();
};
}

void postWaitHook(const at::IValue& value) override {
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
Contributor:

hmm, this means users cannot call wait on CUDAFuture before it is marked as completed?

Contributor Author:

Not sure I follow: the postWaitHook is called by ivalue::Future once a wait has finished, and that only happens when the future is marked complete. Thus one can certainly start waiting before the future is complete, in which case one will block inside ivalue::Future (waiting on a condition variable). Since that only unblocks after the future is complete, we're guaranteed that postWaitHook will be called after postMarkCompletedHook. Is this what you were wondering about?
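
As a simplified model of that ordering (assuming, purely for illustration, a mutex plus condition variable inside the base future; this is not the actual ivalue::Future code):

#include <condition_variable>
#include <mutex>

// wait() blocks until markCompleted() has run, so postWaitHook() can only
// ever execute after postMarkCompletedHook().
struct ToyFuture {
  virtual ~ToyFuture() = default;

  void markCompleted() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      completed_ = true;
      postMarkCompletedHook();  // e.g. record CUDA events
    }
    cv_.notify_all();
  }

  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return completed_; });  // blocks until complete
    postWaitHook();  // e.g. block the current streams on the recorded events
  }

 protected:
  virtual void postMarkCompletedHook() {}
  virtual void postWaitHook() {}

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool completed_ = false;
};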

cudaEvent.block(
at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
}

for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
}
}
}

// FIXME This field is protected (rather than private) and wrapped in a
// shared_ptr in order to support the FutureNCCL subclass, which wants to set
// the events on its own in order to use the same ones as its WorkNCCL class.
// Once WorkNCCL is gone (as part of the Future and Work merge) this should be
// fixed.
protected:
// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
// distinct device on which the value's tensors reside.
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;

private:
DataPtrExtractor dataPtrExtractor_;
std::mutex dataPtrExtractorMutex_;

// FIXME This too is protected so that it can be used by FutureNCCL. Please
// undo that once FutureNCCL is dropped in favor of a "vanilla" CUDAFuture.
protected:
std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_);
std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
if (dataPtrExtractor_ != nullptr) {
// If a Python communication hook is used, dataPtrExtractor_ will be
// set in torch/csrc/jit/python/pybind_utils.h, which allows Python
// dependency to be imported.
data_ptrs = dataPtrExtractor_(value);
} else {
// If a C++ communication hook is used, use the default extractor.
data_ptrs = at::ivalue::Future::defaultDataPtrExtractor(value);
}
return data_ptrs;
}
};

} // namespace cuda
} // namespace at
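
For orientation, here is a hedged sketch of how this new class might be exercised end to end. The markCompleted(), wait() and addCallback() calls are assumed from the ivalue::Future base API of the time; this snippet is not part of the PR.

#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <ATen/cuda/CUDAFuture.h>
#include <c10/util/intrusive_ptr.h>

// Hedged usage sketch, not taken from this diff.
void exampleUsage(at::Tensor produced) {
  auto fut = c10::make_intrusive<at::cuda::CUDAFuture>(
      c10::ListType::create(c10::TensorType::get()));

  // Producer side: once the CUDA kernels writing `produced` are enqueued on
  // the current stream, mark the future complete. postMarkCompletedHook()
  // records one CUDAEvent per device that the value's tensors live on.
  fut->markCompleted(c10::List<at::Tensor>({produced}));

  // Consumer side: wait() returns once the future is complete, and
  // postWaitHook() makes the current streams block on the recorded events,
  // so later kernels on those streams see the produced data.
  fut->wait();

  // Alternatively, attach a callback: wrapCallback() runs it on fresh pool
  // streams synchronized with the events, and records the value's storages
  // on those streams so the caching allocator keeps them alive long enough.
  fut->addCallback([]() { /* consume the value here */ });
}
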
131 changes: 12 additions & 119 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -13,6 +13,7 @@

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAFuture.h>
#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
@@ -207,146 +208,38 @@ class ProcessGroupNCCL : public ProcessGroup {
// enables synchronizing the appropriate streams and avoids stalling PyTorch's
// default stream while running the callback. In case of multiple then
// callbacks, each will be executed on its own fresh stream.
struct FutureNCCL : at::ivalue::Future {
struct FutureNCCL : at::cuda::CUDAFuture {
public:
explicit FutureNCCL(
FutureNCCL(
at::IValue value,
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
cudaEvents_(std::move(cudaEvents)) {
: at::cuda::CUDAFuture(c10::ListType::create(c10::TensorType::get())) {
// Check that the device indices are distinct
std::unordered_set<c10::DeviceIndex> uniqueDeviceIndices;
for (const at::cuda::CUDAEvent& event : *cudaEvents_) {
for (const at::cuda::CUDAEvent& event : *cudaEvents) {
TORCH_INTERNAL_ASSERT(event.isCreated());
uniqueDeviceIndices.insert(event.device_index());
}
TORCH_INTERNAL_ASSERT(
cudaEvents_->size() == uniqueDeviceIndices.size(),
"Got ", cudaEvents_->size(), " events, but only ",
cudaEvents->size() == uniqueDeviceIndices.size(),
"Got ", cudaEvents->size(), " events, but only ",
uniqueDeviceIndices.size(), " distinct devices");
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
TORCH_INTERNAL_ASSERT(
std::find_if(
cudaEvents_->begin(),
cudaEvents_->end(),
cudaEvents->begin(),
cudaEvents->end(),
[&](const at::cuda::CUDAEvent& ev) {
return ev.device_index() == data_ptr.device().index();
}) != cudaEvents_->end());
}) != cudaEvents->end());
}
cudaEvents_ = std::move(cudaEvents);
Contributor:

Are these events used anywhere? Will the markCompleted call below this line immediately override these cudaEvents_?

Contributor Author:

These events will not be overridden when the future is marked complete, because FutureNCCL provides its own postMarkCompletedHook which does nothing (and which overrides the one from CUDAFuture, which is where we store the events).

However, these events will still be used by the postWaitHook, by the wrapCallback, etc.

It's somewhat hacky, but it works, and it's a temporary solution only until we finish merging Future and Work (at which point ProcessGroupNCCL should be able to use CUDAFuture unmodified).
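
A tiny standalone illustration of that arrangement (hypothetical names, just to show the mechanics):

// The subclass fills in its own state, marks itself complete from its
// constructor, and overrides the completion hook with a no-op so the base
// class does not replace that state.
struct BaseFut {
  virtual ~BaseFut() = default;
  void markCompleted() { postMarkCompletedHook(); }

 protected:
  virtual void postMarkCompletedHook() { /* e.g. record fresh CUDA events */ }
};

struct EventSharingFut : BaseFut {
  explicit EventSharingFut(int sharedEvents) : events_(sharedEvents) {
    // By the time this body runs, the dynamic type is EventSharingFut, so
    // the no-op override below is the one that gets invoked.
    markCompleted();
  }

 protected:
  void postMarkCompletedHook() override {
    // Deliberately empty: keep the events supplied to the constructor.
  }

 private:
  int events_;
};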

markCompleted(std::move(value));
}

using at::ivalue::Future::Future;

void setDataPtrExtractor(DataPtrExtractor dataPtrExtractor) override {
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_);
dataPtrExtractor_ = std::move(dataPtrExtractor);
}

protected:
c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override {
auto fut = c10::make_intrusive<FutureNCCL>(std::move(type));
// The new future needs the DataPtr extractor when it gets marked complete
// but this might happen immediately inline or in parallel by another
// thread. In both these cases this would/might happen before the user has
// time to set their own DataPtr extractor, which might lead to failures
// if the default extractor can't handle some of the user's types.
// Therefore we propagate our extractor.
fut->setDataPtrExtractor(dataPtrExtractor_);
return fut;
}

void postMarkCompletedHook(const at::IValue& value) override {
// Check whether the first or second constructor created this instance.
if (cudaEvents_ == nullptr) {
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
}

cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
}
}
}
}

std::function<void(void)> wrapCallback(std::function<void(void)> callback) override {
return [this, callback{std::move(callback)}]() {
// We'd love to get a stream for all devices, even those that are not used
// by the value, because the callback could use those other devices, but
// unfortunately this could cause a deadlock with NCCL. See
// https://github.com/pytorch/pytorch/pull/48500#issuecomment-735395414
// In general, if some devices haven't been used yet, by getting a stream
// for them we'd initialize them, and in addition to causing NCCL to
// misbehaving this also ends up using memory on those devices, which the
// user might not want.
std::vector<at::cuda::CUDAStream> streams;
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
c10::DeviceIndex idx = cudaEvent.device_index();
// FIXME Should we find a way to allow to change the priority of
// streams?
at::cuda::CUDAStream stream =
at::cuda::getStreamFromPool(/*isHighPriority=*/false, idx);
cudaEvent.block(stream);
streams.push_back(stream);
}

// Use the dedicated callback stream to run callback.
at::cuda::CUDAMultiStreamGuard streamGuard(streams);

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(constValue())) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
}
}

callback();
};
}

void postWaitHook(const at::IValue& value) override {
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(
at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
}

for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
}
}
}

private:
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
DataPtrExtractor dataPtrExtractor_;
std::mutex dataPtrExtractorMutex_;

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_);
std::vector<std::reference_wrapper<const at::DataPtr>> data_ptrs;
if (dataPtrExtractor_ != nullptr) {
// If a Python communication hook is used, dataPtrExtractor_ will be
// set in torch/csrc/jit/python/pybind_utils.h, which allows Python
// dependency to be imported.
data_ptrs = dataPtrExtractor_(value);
} else {
// If a C++ communication hook is used, use the default extractor.
data_ptrs = at::ivalue::Future::defaultDataPtrExtractor(value);
}
return data_ptrs;
// Do nothing because the constructor already stored the events.
}
};
