Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge common parts of FutureNCCL into at::ivalue::Future #48505

Closed
wants to merge 10 commits into from
50 changes: 33 additions & 17 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,23 +290,27 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
/**
* Wait on the future until it completes.
*/
virtual void wait() {
void wait() {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
finished_cv_.wait(lock);
}

postWaitHook();
}

/**
* Wait on the future until it completes and throw an
* exception if an error exists.
*/
virtual void waitAndThrow() {
void waitAndThrow() {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
finished_cv_.wait(lock);
}

postWaitHook();

if (eptr_) {
std::rethrow_exception(eptr_);
}
Expand All @@ -315,7 +319,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
/**
* Explicitly mark the future as completed with the output value.
*/
virtual void markCompleted(IValue value) {
void markCompleted(IValue value) {
std::unique_lock<std::mutex> lock(mutex_);
TORCH_CHECK(
!completed(),
Expand All @@ -324,6 +328,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
completed_ = true;
value_ = std::move(value);

postMarkCompletedHook(value_);

std::vector<std::function<void(void)>> cbs;
cbs.swap(callbacks_);
lock.unlock();
Expand Down Expand Up @@ -359,7 +365,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
}

// Get the result of the current future.
virtual IValue value() {
IValue value() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
if (eptr_) {
Expand All @@ -370,7 +376,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {

// This accessor should only be used if we know that the future is
// completed() with no error.
virtual const IValue& constValue() {
const IValue& constValue() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
AT_ASSERT(!eptr_);
Expand All @@ -383,8 +389,9 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
* If the future has already completed,
* this function will execute the callback immediately.
*/
virtual void addCallback(std::function<void(void)> callback) {
void addCallback(std::function<void(void)> callback) {
std::unique_lock<std::mutex> lock(mutex_);
callback = wrapCallback(std::move(callback));
if (completed()) {
lock.unlock();
callback();
Expand All @@ -398,22 +405,18 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
* value of the callback. This is necessary when the callback provider needs
* to know for sure when the callback has finished.
*/
virtual c10::intrusive_ptr<Future> then(
c10::intrusive_ptr<Future> then(
std::function<IValue(void)> callback,
TypePtr type) {
auto fut = c10::make_intrusive<Future>(type);
// Cannot move capture std::function in lambda, because it cannot deduce
// the template type for std::function. Hence use std::bind to explicitly
// specify types.
Comment on lines -405 to -407
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was curious to see what the reason for this problem was so I tried to undo this fix... and it seems to work? Maybe the comment was outdated?

addCallback(std::bind(
[fut](std::function<IValue(void)> cb) {
auto fut = createInstance(std::move(type));
addCallback(
[fut, cb = std::move(callback)]() {
try {
fut->markCompleted(cb());
} catch (std::exception& e) {
fut->setError(std::current_exception());
}
},
std::move(callback)));
});
return fut;
}

Expand Down Expand Up @@ -452,11 +455,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
}

// Check if the current future has completed
virtual bool completed() const {
// Returns the current value of the completed_ flag. Unlike hasValue(), this
// accessor reads the flag without acquiring mutex_.
bool completed() const {
return completed_;
}

virtual bool hasValue() const {
// Tells whether the future finished successfully: it must be completed and
// must not hold an exception. The check is made while holding mutex_.
bool hasValue() const {
  std::lock_guard<std::mutex> guard(mutex_);
  if (!completed_) {
    return false;
  }
  return !eptr_;
}
Expand All @@ -479,6 +482,17 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
return type_;
}

protected:
// Factory hook used by then() to build the future it hands back to the
// caller. Subclasses override it so that then() returns an instance of their
// own type rather than a plain Future.
//
// `type` is the type of the value the new future will hold. It is taken by
// value and moved into the new instance to avoid an extra copy of the
// TypePtr.
virtual c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
  return c10::make_intrusive<Future>(std::move(type));
}

// Hook invoked by markCompleted() right after the value has been stored,
// while mutex_ is still held and before any registered callbacks are run.
// `value` is the value the future was completed with. The hook is NOT
// invoked when the future completes with an error, since there is no value
// in that case. The default implementation does nothing; subclasses may
// override it to update their own state from the value.
virtual void postMarkCompletedHook(const at::IValue& value) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you briefly document how these APIs are used? Probably only FutureNCCL is using them now, but if other Future-derived types are added later, this would be a good reference.


// Hook invoked by addCallback() (and therefore also by then(), which goes
// through addCallback()) on every callback, while mutex_ is held and before
// the callback is run inline or kept for later. It can thus be called while
// the future is not yet completed. The returned function is what actually
// gets executed, which lets subclasses intercept the callback's execution.
// The default implementation returns the callback unchanged.
virtual std::function<void(void)> wrapCallback(std::function<void(void)> callback) { return callback; }

// Hook invoked by wait() and waitAndThrow() once the future is observed to
// be completed, while mutex_ is still held. In waitAndThrow() it runs before
// a stored exception is rethrown, so overrides must cope with futures that
// completed with an error as well as with a value. The default
// implementation does nothing.
virtual void postWaitHook() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add some comments to these functions and explain what derived classes need to do when implementing them?


private:
void setErrorInternal(
std::exception_ptr eptr,
Expand All @@ -487,6 +501,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
completed_ = true;
eptr_ = std::move(eptr);

// Do not call postMarkCompletedHook() here as there isn't any value.

std::vector<std::function<void(void)>> cbs;
cbs.swap(callbacks_);
lock.unlock();
Expand Down
100 changes: 11 additions & 89 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,101 +213,25 @@ class ProcessGroupNCCL : public ProcessGroup {
at::IValue value,
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
value_(std::move(value)),
cudaEvents_(std::move(cudaEvents)) {
markCompleted(std::move(value));
}

FutureNCCL(at::TypePtr type) : at::ivalue::Future(std::move(type)) {}

// Gets the current stream of the device and synchronizes recorded streams
// with that. It will return after synchronizing the correct GPU streams to
// ensure we can have async CUDA execution and it does not wait for the
// entire operation to complete on GPU.
void wait() override {
if (error_) {
throw *error_;
}

postWaitHook();
}

// If FutureNCCL was created by FutureNCCL::then, its value would be empty
// initially. FutureNCCL::then will later use this method to set its value
// to the return value of the callback.
void markCompleted(at::IValue value) override {
TORCH_INTERNAL_ASSERT(
value_.isNone(),
"Attempting to set value of a FutureNCCL which has a value."
"FutureNCCL's value was internally set to NCCL collective's "
"outputs or the return value of the callback.");
value_ = std::move(value);

postMarkCompletedHook();
}

// Just returns FutureNCCL's value after wait returns.
at::IValue value() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}

const at::IValue& constValue() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}

// Adds a callback to FutureNCCL. It invokes the callback inline after
// synchronizing FutureNCCL's own cudaEvents with the stream that runs
// this callback. This new FutureNCCL's cudaEvents will record the
// callback's stream and will have the result value of the callback.
void addCallback(std::function<void(void)> callback) override {
std::function<void(void)> wrappedCallback =
wrapCallback(std::move(callback));
wrappedCallback();
}

// Adds a callback to FutureNCCL, and returns another FutureNCCL to hold
// the return value of the callback and new cudaEvents that recorded the
// stream that runs this callback.
c10::intrusive_ptr<Future> then(
std::function<at::IValue(void)> callback,
at::TypePtr type) override {
auto fut = c10::make_intrusive<FutureNCCL>(std::move(type));

// Cannot move capture std::function in lambda, because it cannot deduce
// the template type for std::function. Hence use std::bind to explicitly
// specify types.
addCallback(std::bind(
[&](std::function<at::IValue(void)> cb) {
try {
fut->markCompleted(at::IValue(cb()));
} catch (const std::exception& e) {
fut->setError(std::current_exception());
}
},
std::move(callback)));
return fut;
}

bool completed() const override {
return true;
}

bool hasValue() const override {
return !value_.isNone();
}
using at::ivalue::Future::Future;

// Stores the given extractor in dataPtrExtractor_, replacing any previously
// set one. The argument is taken by value and moved into the member.
// NOTE(review): presumably extractDataPtrs() consults this extractor to find
// the DataPtrs held by a value; its body is not visible here, so confirm.
void setDataPtrExtractor(DataPtrExtractor data_ptr_extractor) override {
dataPtrExtractor_ = std::move(data_ptr_extractor);
}

protected:
void postMarkCompletedHook() {
// Overrides the base-class factory so that a FutureNCCL (rather than a plain
// Future) is created for the given value type. The base class's then() uses
// createInstance() to build the future it returns, so chaining on a
// FutureNCCL yields another FutureNCCL.
c10::intrusive_ptr<Future> createInstance(at::TypePtr type) override {
return c10::make_intrusive<FutureNCCL>(std::move(type));
}

void postMarkCompletedHook(const at::IValue& value) override {
if (cudaEvents_ == nullptr) {
std::vector<bool> isCudaDeviceUsed(c10::cuda::device_count(), false);
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(value)) {
if (data_ptr.device().is_cuda()) {
isCudaDeviceUsed[data_ptr.device().index()] = true;
}
Expand All @@ -324,7 +248,7 @@ class ProcessGroupNCCL : public ProcessGroup {
}
}

std::function<void(void)> wrapCallback(std::function<void(void)> callback) {
std::function<void(void)> wrapCallback(std::function<void(void)> callback) override {
return [this, callback{std::move(callback)}]() {
// Get a stream for all devices, even those that are not used by the
// value, because the user's callback could use those other devices.
Expand All @@ -338,7 +262,7 @@ class ProcessGroupNCCL : public ProcessGroup {

// Do not free the underlying data storage of value_ before its
// usage on the stream finishes.
for (const at::DataPtr& data_ptr : extractDataPtrs(value_)) {
for (const at::DataPtr& data_ptr : extractDataPtrs(constValue())) {
if (data_ptr.device().is_cuda()) {
c10::cuda::CUDACachingAllocator::recordStream(
data_ptr, streams[data_ptr.device().index()]);
Expand All @@ -356,18 +280,16 @@ class ProcessGroupNCCL : public ProcessGroup {
};
}

void postWaitHook() {
void postWaitHook() override {
for (at::cuda::CUDAEvent& cudaEvent : *cudaEvents_) {
cudaEvent.block(
at::cuda::getCurrentCUDAStream(cudaEvent.device_index()));
}
}

private:
at::IValue value_;
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
DataPtrExtractor dataPtrExtractor_;
c10::optional<FutureError> error_;

std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
const at::IValue& value) {
Expand Down