Skip to content

Commit

Permalink
Merge common parts of FutureNCCL into at::ivalue::Future
Browse files Browse the repository at this point in the history
Pull Request resolved: #48505

This commit is part of a stack that reworks FutureNCCL in order to extract a generic CUDA-aware Future subclass. The stack deliberately breaks up this transition into elementary changes, to make it easier to verify that the behavior is preserved (or to highlight how it gets changed).

---

FutureNCCL isn't just adding CUDA support to ivalue::Future, it's also reimplementing a lot of the latter's logic (by overriding plenty of its methods). That's brittle, as whenever a new method is added to ivalue::Future there's a risk of forgetting to add it to FutureNCCL, and in such a case calling this method on FutureNCCL would defer to the base class and give inconsistent results (e.g., future not being completed when it actually is). This _is already happening_, for example with the waitAndThrow or hasError methods, which are not implemented by FutureNCCL. In addition, this creates duplication between the two classes, which could lead to inconsistencies of behavior, bugs, and missing features.

The best solution would be to keep the core future logic in ivalue::Future, and have _only_ the CUDA additions in FutureNCCL. That's what we're going to do, in two steps. In the previous commit, I split the CUDA features into separate hooks, which are called by FutureNCCL's other methods. In this commit, I'm removing those latter methods and invoking the hooks directly from ivalue::Future.
ghstack-source-id: 117439437

Differential Revision: [D25180535](https://our.internmc.facebook.com/intern/diff/D25180535/)
  • Loading branch information
lw committed Nov 29, 2020
1 parent 2356eae commit f7cff2f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 127 deletions.
52 changes: 35 additions & 17 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,18 +290,22 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
/**
* Wait on the future until it completes.
*/
virtual void wait() {
void wait() {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
finished_cv_.wait(lock);
}

if (!eptr_) {
postWaitHook(value_);
}
}

/**
* Wait on the future until it completes and throw an
* exception if an error exists.
*/
virtual void waitAndThrow() {
void waitAndThrow() {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
finished_cv_.wait(lock);
Expand All @@ -310,12 +314,14 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
if (eptr_) {
std::rethrow_exception(eptr_);
}

postWaitHook(value_);
}

/**
* Explicitly mark the future as completed with the output value.
*/
virtual void markCompleted(IValue value) {
void markCompleted(IValue value) {
std::unique_lock<std::mutex> lock(mutex_);
TORCH_CHECK(
!completed(),
Expand All @@ -324,6 +330,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
completed_ = true;
value_ = std::move(value);

postMarkCompletedHook(value_);

std::vector<std::function<void(void)>> cbs;
cbs.swap(callbacks_);
lock.unlock();
Expand Down Expand Up @@ -359,7 +367,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
}

// Get the result of the current future.
virtual IValue value() {
IValue value() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
if (eptr_) {
Expand All @@ -370,7 +378,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {

// This accessor should only be used if we know that the future is
// completed() with no error.
virtual const IValue& constValue() {
const IValue& constValue() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
AT_ASSERT(!eptr_);
Expand All @@ -383,8 +391,9 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
* If the future has already completed,
* this function will execute the callback immediately.
*/
virtual void addCallback(std::function<void(void)> callback) {
void addCallback(std::function<void(void)> callback) {
std::unique_lock<std::mutex> lock(mutex_);
callback = wrapCallback(std::move(callback));
if (completed()) {
lock.unlock();
callback();
Expand All @@ -398,22 +407,18 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
* value of the callback. This is necessary when the callback provider needs
* to know for sure when the callback has finished.
*/
virtual c10::intrusive_ptr<Future> then(
c10::intrusive_ptr<Future> then(
std::function<IValue(void)> callback,
TypePtr type) {
auto fut = c10::make_intrusive<Future>(type);
// Cannot move capture std::function in lambda, because it cannot deduce
// the template type for std::function. Hence use std::bind to explicitly
// specify types.
addCallback(std::bind(
[fut](std::function<IValue(void)> cb) {
auto fut = createInstance(std::move(type));
addCallback(
[fut, cb = std::move(callback)]() {
try {
fut->markCompleted(cb());
} catch (std::exception& e) {
fut->setError(std::current_exception());
}
},
std::move(callback)));
});
return fut;
}

Expand Down Expand Up @@ -452,11 +457,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
}

// Check if the current future has completed
virtual bool completed() const {
bool completed() const {
return completed_;
}

virtual bool hasValue() const {
bool hasValue() const {
std::unique_lock<std::mutex> lock(mutex_);
return completed_ && !eptr_;
}
Expand All @@ -479,6 +484,17 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
return type_;
}

protected:
virtual c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
return c10::make_intrusive<Future>(type);
}

virtual void postMarkCompletedHook(const at::IValue& value) {}

virtual std::function<void(void)> wrapCallback(std::function<void(void)> callback) { return callback; }

virtual void postWaitHook(const at::IValue& value) {}

private:
void setErrorInternal(
std::exception_ptr eptr,
Expand All @@ -487,6 +503,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
completed_ = true;
eptr_ = std::move(eptr);

// Do not call postMarkCompletedHook() here as there isn't any value.

std::vector<std::function<void(void)>> cbs;
cbs.swap(callbacks_);
lock.unlock();
Expand Down
139 changes: 29 additions & 110 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,11 @@ class ProcessGroupNCCL : public ProcessGroup {
at::IValue value,
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
value_(std::move(value)),
cudaEvents_(std::move(cudaEvents)) {
for (const at::cuda::CUDAEvent& event : *cudaEvents_) {
TORCH_INTERNAL_ASSERT(event.isCreated());
}
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
TORCH_INTERNAL_ASSERT(
std::find_if(
cudaEvents_->begin(),
Expand All @@ -227,71 +226,21 @@ class ProcessGroupNCCL : public ProcessGroup {
return ev.device_index() == data_ptr.device().index();
}) != cudaEvents_->end());
}
markCompleted(std::move(value));
}

private:
FutureNCCL(at::TypePtr type) : at::ivalue::Future(std::move(type)) {}
// We need this because it will be the ::make() static method that actually
// creates the instance. This is a brittle approach and the passkey idiom
// would be a more robust solution. However, this will go away in #48505.
friend c10::intrusive_ptr<FutureNCCL>;
using at::ivalue::Future::Future;

public:
// Gets the current stream of the device and synchronizes recorded streams
// with that. It will return after synchronizing the correct GPU streams to
// ensure we can have async CUDA execution and it does not wait for the
// entire operation to complete on GPU.
void wait() override {
if (error_) {
throw *error_;
void setDataPtrExtractor(DataPtrExtractor data_ptr_extractor) override {
// To avoid races with other threads that may be using the extractor, we
// won't modify it after it's first set.
if (dataPtrExtractor_ == nullptr) {
dataPtrExtractor_ = std::move(data_ptr_extractor);
}

postWaitHook();
}

// If FutureNCCL was created by FutureNCCL::then, its value would be empty
// initially. FutureNCCL::then will later use this method to set its value
// to the return value of the callback.
void markCompleted(at::IValue value) override {
TORCH_INTERNAL_ASSERT(
value_.isNone(),
"Attempting to set value of a FutureNCCL which has a value."
"FutureNCCL's value was internally set to NCCL collective's "
"outputs or the return value of the callback.");
value_ = std::move(value);

postMarkCompletedHook();
}

// Just returns FutureNCCL's value after wait returns.
at::IValue value() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}

const at::IValue& constValue() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}

// Adds a callback to FutureNCCL. It invokes the callback inline after
// synchronizing FutureNCCL's own cudaEvents with the stream that runs
// this callback. This new FutureNCCL's cudaEvents will record the
// callback's stream and will have the result value of the callback.
void addCallback(std::function<void(void)> callback) override {
std::function<void(void)> wrappedCallback =
wrapCallback(std::move(callback));
wrappedCallback();
}

// Adds a callback to FutureNCCL, and returns another FutureNCCL to hold
// the return value of the callback and new cudaEvents that recorded the
// stream that runs this callback.
c10::intrusive_ptr<Future> then(
std::function<at::IValue(void)> callback,
at::TypePtr type) override {
protected:
c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override {
auto fut = c10::make_intrusive<FutureNCCL>(std::move(type));
// The new future needs the DataPtr extractor when it gets marked complete
// but this might happen immediately inline or in parallel by another
Expand All @@ -300,59 +249,31 @@ class ProcessGroupNCCL : public ProcessGroup {
// if the default extractor can't handle some of the user's types.
// Therefore we propagate our extractor.
fut->setDataPtrExtractor(dataPtrExtractor_);

// Cannot move capture std::function in lambda, because it cannot deduce
// the template type for std::function. Hence use std::bind to explicitly
// specify types.
addCallback(std::bind(
[&](std::function<at::IValue(void)> cb) {
try {
fut->markCompleted(at::IValue(cb()));
} catch (const std::exception& e) {
fut->setError(std::current_exception());
}
},
std::move(callback)));
return fut;
}

bool completed() const override {
return true;
}

bool hasValue() const override {
return !value_.isNone();
}

void setDataPtrExtractor(DataPtrExtractor data_ptr_extractor) override {
// To avoid races with other threads that may be using the extractor, we
// won't modify it after it's first set.
if (dataPtrExtractor_ == nullptr) {
dataPtrExtractor_ = std::move(data_ptr_extractor);
}
}

protected:
void postMarkCompletedHook() {
TORCH_INTERNAL_ASSERT(cudaEvents_ == nullptr);
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
void postMarkCompletedHook(const at::IValue& value) override {
// Check whether the first or second constructor created this instance.
if (cudaEvents_ == nullptr) {
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
}
}

cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
}
}
}
}

std::function<void(void)> wrapCallback(std::function<void(void)> callback) {
std::function<void(void)> wrapCallback(std::function<void(void)> callback) override {
return [this, callback{std::move(callback)}]() {
// We'd love to get a stream for all devices, even those that are not used
// by the value, because the callback could use those other devices, but
Expand All @@ -378,7 +299,7 @@ class ProcessGroupNCCL : public ProcessGroup {

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(constValue())) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
Expand All @@ -389,13 +310,13 @@ class ProcessGroupNCCL : public ProcessGroup {
};
}

void postWaitHook() {
void postWaitHook(const at::IValue& value) override {
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(
at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
}

for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
Expand All @@ -404,10 +325,8 @@ class ProcessGroupNCCL : public ProcessGroup {
}

private:
at::IValue value_;
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
DataPtrExtractor dataPtrExtractor_;
c10::optional<FutureError> error_;

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
Expand Down

0 comments on commit f7cff2f

Please sign in to comment.