Skip to content

Commit

Permalink
Merge common parts of FutureNCCL into at::ivalue::Future (#48505)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #48505

This commit is part of a stack that reworks FutureNCCL in order to extract a generic CUDA-aware Future subclass. The stack deliberately breaks up this transition into elementary changes, to make it easier to verify that the behavior is preserved (or to highlight how it gets changed).

 ---

FutureNCCL isn't just adding CUDA support to ivalue::Future, it's also reimplementing a lot of the latter's logic (by overriding plenty of its methods). That's brittle, as whenever a new method is added to ivalue::Future there's a risk of forgetting to add it to FutureNCCL, and in such a case calling this method on FutureNCCL would defer to the base class and give inconsistent results (e.g., future not being completed when it actually is). This _is already happening_, for example with the waitAndThrow or hasError, which are not implemented by FutureNCCL. In addition, this creates duplication between the two classes, which could lead to inconsistencies of behavior, bugs, missing features, ...

The best solution would be to keep the core future logic in ivalue::Future, and have _only_ the CUDA additions in FutureNCCL. That's what we're going to do, in two steps. In the previous commit, I split the CUDA features into separate hooks, which are called by FutureNCCL's other methods. In this commit, I'm removing these latter methods, and invoke the hooks directly from ivalue::Future.
ghstack-source-id: 118180032

Test Plan: Unit tests

Reviewed By: wanchaol

Differential Revision: D25180535

fbshipit-source-id: 19181fe133152044eb677062a9e31e5e4ad3c03c
  • Loading branch information
lw authored and facebook-github-bot committed Dec 10, 2020
1 parent 9078088 commit 4c425e8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 126 deletions.
78 changes: 61 additions & 17 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,18 +290,22 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
/**
* Wait on the future until it completes.
*/
virtual void wait() {
void wait() {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
finished_cv_.wait(lock);
}

if (!eptr_) {
postWaitHook(value_);
}
}

/**
* Wait on the future until it completes and throw an
* exception if an error exists.
*/
virtual void waitAndThrow() {
void waitAndThrow() {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
finished_cv_.wait(lock);
Expand All @@ -310,12 +314,14 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
if (eptr_) {
std::rethrow_exception(eptr_);
}

postWaitHook(value_);
}

/**
* Explicitly mark the future as completed with the output value.
*/
virtual void markCompleted(IValue value) {
void markCompleted(IValue value) {
std::unique_lock<std::mutex> lock(mutex_);
TORCH_CHECK(
!completed(),
Expand All @@ -324,6 +330,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
completed_ = true;
value_ = std::move(value);

postMarkCompletedHook(value_);

std::vector<std::function<void(void)>> cbs;
cbs.swap(callbacks_);
lock.unlock();
Expand Down Expand Up @@ -359,7 +367,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
}

// Get the result of the current future.
virtual IValue value() {
IValue value() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
if (eptr_) {
Expand All @@ -370,7 +378,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {

// This accessor should only be used if we know that the future is
// completed() with no error.
virtual const IValue& constValue() {
const IValue& constValue() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
AT_ASSERT(!eptr_);
Expand All @@ -383,8 +391,9 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
* If the future has already completed,
* this function will execute the callback immediately.
*/
virtual void addCallback(std::function<void(void)> callback) {
void addCallback(std::function<void(void)> callback) {
std::unique_lock<std::mutex> lock(mutex_);
callback = wrapCallback(std::move(callback));
if (completed()) {
lock.unlock();
callback();
Expand All @@ -398,22 +407,18 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
* value of the callback. This is necessary when the callback provider needs
* to know for sure when the callback has finished.
*/
virtual c10::intrusive_ptr<Future> then(
c10::intrusive_ptr<Future> then(
std::function<IValue(void)> callback,
TypePtr type) {
auto fut = c10::make_intrusive<Future>(type);
// Cannot move capture std::function in lambda, because it cannot deduce
// the template type for std::function. Hence use std::bind to explicitly
// specify types.
addCallback(std::bind(
[fut](std::function<IValue(void)> cb) {
auto fut = createInstance(std::move(type));
addCallback(
[fut, cb = std::move(callback)]() {
try {
fut->markCompleted(cb());
} catch (std::exception&) {
fut->setError(std::current_exception());
}
},
std::move(callback)));
});
return fut;
}

Expand Down Expand Up @@ -452,11 +457,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
}

// Check if the current future has completed
virtual bool completed() const {
bool completed() const {
return completed_;
}

virtual bool hasValue() const {
bool hasValue() const {
std::unique_lock<std::mutex> lock(mutex_);
return completed_ && !eptr_;
}
Expand All @@ -479,6 +484,43 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
return type_;
}

protected:
// This hook is called by this class's then() method when it prepares the
// instance it returns to the caller. It should be overridden by subclasses so
// that they can produce an instace of their own type.
virtual c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
return c10::make_intrusive<Future>(type);
}

// This hook will be called by this class (the superclass) when the future is
// marked completed _with a value_ (hence not in case of error). This is done
// right away, while the mutex is still held, before any callbacks are run.
// It allows subclasses to further update their state if they so need. For
// example the CUDAFuture subclass uses it to determine what devices the value
// resides on and record an event in those devices' current streams.
virtual void postMarkCompletedHook(const at::IValue& value) {}

// This hook will be called by the addCallback() and the then() methods before
// storing the callback for later execution (or before running it inline if
// the future is already complete). Note that this method could thus be called
// while the future is _not_ yet complete. By default this method does nothing
// but subclasses can override this method to add functionality. For example
// the CUDAFuture subclass ensures the callback runs with CUDA streams which
// are synchronized with the events recorded in the I/O streams.
virtual std::function<void(void)> wrapCallback(
std::function<void(void)> callback) {
return callback;
}

// This hook will be called by this class after a user thread has completed
// waiting on a successful future. It will thus not be called if the future
// completes with an error. It will also not be called if the user accesses
// the future's value without synchronization. Subclasses can override this
// to add some synchronization to the wait. For example, the CUDAFuture
// subclass ensures the user's current CUDA streams synchronize with the I/O
// events stored by the future.
virtual void postWaitHook(const at::IValue& value) {}

private:
void setErrorInternal(
std::exception_ptr eptr,
Expand All @@ -487,6 +529,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
completed_ = true;
eptr_ = std::move(eptr);

// Do not call postMarkCompletedHook() here as there isn't any value.

std::vector<std::function<void(void)>> cbs;
cbs.swap(callbacks_);
lock.unlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def compute_q(fut):
return [
dist.all_reduce(q, group=group_to_use, async_op=True)
.get_future()
.value()[0]
.wait()[0]
]

def decompress(fut):
Expand Down
135 changes: 27 additions & 108 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ class ProcessGroupNCCL : public ProcessGroup {
at::IValue value,
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
value_(std::move(value)),
cudaEvents_(std::move(cudaEvents)) {
// Check that the device indices are distinct
std::unordered_set<c10::DeviceIndex> uniqueDeviceIndices;
Expand All @@ -225,7 +224,7 @@ class ProcessGroupNCCL : public ProcessGroup {
cudaEvents_->size() == uniqueDeviceIndices.size(),
"Got ", cudaEvents_->size(), " events, but only ",
uniqueDeviceIndices.size(), " distinct devices");
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
TORCH_INTERNAL_ASSERT(
std::find_if(
cudaEvents_->begin(),
Expand All @@ -234,71 +233,18 @@ class ProcessGroupNCCL : public ProcessGroup {
return ev.device_index() == data_ptr.device().index();
}) != cudaEvents_->end());
}
markCompleted(std::move(value));
}

private:
FutureNCCL(at::TypePtr type) : at::ivalue::Future(std::move(type)) {}
// We need this because it will be the ::make() static method that actually
// creates the instance. This is a brittle approach and the passkey idiom
// would be a more robust solution. However, this will go away in #48505.
friend c10::intrusive_ptr<FutureNCCL>;
using at::ivalue::Future::Future;

public:
// Gets the current stream of the device and synchronizes recorded streams
// with that. It will return after synchronizing the correct GPU streams to
// ensure we can have async CUDA execution and it does not wait for the
// entire operation to complete on GPU.
void wait() override {
if (error_) {
throw *error_;
}

postWaitHook();
}

// If FutureNCCL was created by FutureNCCL::then, its value would be empty
// initially. FutureNCCL::then will later use this method to set its value
// to the return value of the callback.
void markCompleted(at::IValue value) override {
TORCH_INTERNAL_ASSERT(
value_.isNone(),
"Attempting to set value of a FutureNCCL which has a value."
"FutureNCCL's value was internally set to NCCL collective's "
"outputs or the return value of the callback.");
value_ = std::move(value);

postMarkCompletedHook();
}

// Just returns FutureNCCL's value after wait returns.
at::IValue value() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}

const at::IValue& constValue() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}

// Adds a callback to FutureNCCL. It invokes the callback inline after
// synchronizing FutureNCCL's own cudaEvents with the stream that runs
// this callback. This new FutureNCCL's cudaEvents will record the
// callback's stream and will have the result value of the callback.
void addCallback(std::function<void(void)> callback) override {
std::function<void(void)> wrappedCallback =
wrapCallback(std::move(callback));
wrappedCallback();
void setDataPtrExtractor(DataPtrExtractor dataPtrExtractor) override {
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_);
dataPtrExtractor_ = std::move(dataPtrExtractor);
}

// Adds a callback to FutureNCCL, and returns another FutureNCCL to hold
// the return value of the callback and new cudaEvents that recorded the
// stream that runs this callback.
c10::intrusive_ptr<Future> then(
std::function<at::IValue(void)> callback,
at::TypePtr type) override {
protected:
c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override {
auto fut = c10::make_intrusive<FutureNCCL>(std::move(type));
// The new future needs the DataPtr extractor when it gets marked complete
// but this might happen immediately inline or in parallel by another
Expand All @@ -307,56 +253,31 @@ class ProcessGroupNCCL : public ProcessGroup {
// if the default extractor can't handle some of the user's types.
// Therefore we propagate our extractor.
fut->setDataPtrExtractor(dataPtrExtractor_);

// Cannot move capture std::function in lambda, because it cannot deduce
// the template type for std::function. Hence use std::bind to explicitly
// specify types.
addCallback(std::bind(
[&](std::function<at::IValue(void)> cb) {
try {
fut->markCompleted(at::IValue(cb()));
} catch (const std::exception& e) {
fut->setError(std::current_exception());
}
},
std::move(callback)));
return fut;
}

bool completed() const override {
return true;
}

bool hasValue() const override {
return !value_.isNone();
}

void setDataPtrExtractor(DataPtrExtractor dataPtrExtractor) override {
std::unique_lock<std::mutex> lock(dataPtrExtractorMutex_);
dataPtrExtractor_ = std::move(dataPtrExtractor);
}

protected:
void postMarkCompletedHook() {
TORCH_INTERNAL_ASSERT(cudaEvents_ == nullptr);
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
void postMarkCompletedHook(const at::IValue& value) override {
// Check whether the first or second constructor created this instance.
if (cudaEvents_ == nullptr) {
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
}
}

cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
cudaEvents_ = std::make_shared<std::vector<at::cuda::CUDAEvent>>();
for (c10::DeviceIndex idx = 0; idx < isCudaDeviceUsed.size(); idx++) {
if (isCudaDeviceUsed[idx]) {
at::cuda::CUDAEvent cudaEvent;
cudaEvent.record(at::cuda::getCurrentCUDAStream(idx));
(*cudaEvents_).push_back(std::move(cudaEvent));
}
}
}
}

std::function<void(void)> wrapCallback(std::function<void(void)> callback) {
std::function<void(void)> wrapCallback(std::function<void(void)> callback) override {
return [this, callback{std::move(callback)}]() {
// We'd love to get a stream for all devices, even those that are not used
// by the value, because the callback could use those other devices, but
Expand All @@ -382,7 +303,7 @@ class ProcessGroupNCCL : public ProcessGroup {

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(constValue())) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
Expand All @@ -393,13 +314,13 @@ class ProcessGroupNCCL : public ProcessGroup {
};
}

void postWaitHook() {
void postWaitHook(const at::IValue& value) override {
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(
at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
}

for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, at::cuda::getCurrentCUDAStream(data_ptr.device().index()));
Expand All @@ -408,11 +329,9 @@ class ProcessGroupNCCL : public ProcessGroup {
}

private:
at::IValue value_;
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
DataPtrExtractor dataPtrExtractor_;
std::mutex dataPtrExtractorMutex_;
c10::optional<FutureError> error_;

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
Expand Down

0 comments on commit 4c425e8

Please sign in to comment.