Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,15 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {

friend c10::intrusive_ptr<Future>;

struct FutureCallback {
std::function<void(Future&)> callback;
bool uses_future; // whether the Future& passed in is actually used

template <typename T>
FutureCallback(T callback, bool uses_future)
: callback(std::move(callback)), uses_future(uses_future) {}
};

public:
Future(const Future&) = delete;
Future(Future&&) = delete;
Expand Down Expand Up @@ -942,13 +951,13 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
events_.push_back(std::move(event));
}

std::vector<std::function<void(Future&)>> cbs;
std::vector<FutureCallback> cbs;
cbs.swap(callbacks_);
lock.unlock();

finished_cv_.notify_all();
for (auto& callback : cbs) {
invokeCallback(std::move(callback));
invokeCallback(std::move(callback.callback), callback.uses_future);
}
}

Expand Down Expand Up @@ -1023,19 +1032,20 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
* this function will execute the callback immediately.
*/
template <typename T>
void addCallback(T callback) {
void addCallback(T callback, bool uses_future = true) {
#if __cpp_lib_is_invocable >= 201703
static_assert(
std::is_invocable_r<void, T, Future&>::value,
"The callback must have signature void(Future&)");
#endif

std::unique_lock<std::mutex> lock(mutex_);
if (completed()) {
lock.unlock();
invokeCallback(std::move(callback));
invokeCallback(std::move(callback), uses_future);
return;
}
callbacks_.emplace_back(std::move(callback));
callbacks_.emplace_back(std::move(callback), uses_future);
}

/**
Expand Down Expand Up @@ -1153,24 +1163,29 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
// set up before running the callback, as in, it will set up the CUDA streams,
// synchronize them with the value, and so on (if needed).
template<typename T>
void invokeCallback(T callback) {
void invokeCallback(T callback, bool uses_future) {
#if __cpp_lib_is_invocable >= 201703
static_assert(
std::is_invocable_r<void, T, Future&>::value,
"The callback must have signature void(Future&)");
#endif

c10::OptionalDeviceGuard deviceGuard(currentDevice_);
// The synchronization performed below shouldn't be needed when the future
// is not used by the callback.
if (uses_future) {
c10::OptionalDeviceGuard deviceGuard(currentDevice_);

std::vector<c10::Stream> streams;
streams.reserve(devices_.size());
for (const c10::Device& device : devices_) {
streams.push_back(impl_.getStreamFromGlobalPool(device));
std::vector<c10::Stream> streams;
streams.reserve(devices_.size());
for (const c10::Device& device : devices_) {
streams.push_back(impl_.getStreamFromGlobalPool(device));
}
c10::MultiStreamGuard streamGuard(streams);
synchronizeWithCurrentStreams();
callback(*this);
} else {
callback(*this);
}
c10::MultiStreamGuard streamGuard(streams);
synchronizeWithCurrentStreams();

callback(*this);
}

// This method should be called before this future's value is used, as it
Expand Down Expand Up @@ -1206,13 +1221,13 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
completed_ = true;
eptr_ = std::move(eptr);

std::vector<std::function<void(Future&)>> cbs;
std::vector<FutureCallback> cbs;
cbs.swap(callbacks_);
lock.unlock();

finished_cv_.notify_all();
for (auto& callback : cbs) {
invokeCallback(std::move(callback));
invokeCallback(std::move(callback.callback), callback.uses_future);
}
}

Expand Down Expand Up @@ -1353,7 +1368,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {

IValue value_; // when finished the value
TypePtr type_;
std::vector<std::function<void(Future&)>> callbacks_;
std::vector<FutureCallback> callbacks_;
std::exception_ptr eptr_;

// An upcast pointer to a virtual class which allows us to manipulate events,
Expand Down
22 changes: 16 additions & 6 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1819,9 +1819,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// future blocks the stream this callback runs on the corresponding
// ncclEndEvents_ ensuring appropriate synchronization.
if (work->recordFunctionEndCallback_) {
work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
work->recordFunctionEndCallback_();
});
work->future_->addCallback(
[work](at::ivalue::Future& /* unused */) {
work->recordFunctionEndCallback_();
},
// uses_future = false allows us to skip synchronization in
// ivalue::Future, but is only valid as long as the lambda doesn't use
// the "Future" argument.
/*uses_future=*/false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to directly set this to False, or we want to make it configurable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that, since the callback doesn't use the future, there's no point in synchronizing when the callback is invoked.

But if there's a reason to, LMK and I can make it configurable

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that mostly is for gating and rolliout purpose. This PR in general looks good to me, but have we battle tested in all FSDP workload? If not, we might just want to gate it for now so that instead of reverting the PR, we can iterate on top of this change? WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think since we control the callback, this change seems reasonable to me without gating (so we know that we do not use the future). I feel that in my experience with FSDP, gating ends up costing more in terms of maintenance than helping.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we want to add some comments and say something like if the future is non empty, the uses_future must be set to true. Or is it possible to add a check if the lambda func is taking a unnamed argument (maybe not lol)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comments

will do

add a check if the lambda func is taking a unnamed argument

I tried to do this https://github.com/pytorch/pytorch/pull/109933/files/b3ce2ec4fab56f5173fdbcb681759d7887832da2
but couldn't get the windows builds to pass

}
work->future_->markCompleted(at::IValue(*work->outputs_));
}
Expand Down Expand Up @@ -1999,9 +2004,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// future blocks the stream this callback runs on the corresponding
// ncclEndEvents_ ensuring appropriate synchronization.
if (work->recordFunctionEndCallback_) {
work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
work->recordFunctionEndCallback_();
});
work->future_->addCallback(
[work](at::ivalue::Future& /* unused */) {
work->recordFunctionEndCallback_();
},
// uses_future = false allows us to skip synchronization in
// ivalue::Future, but is only valid as long as the lambda doesn't use
// the "Future" argument.
/*uses_future=*/false);
}

return work;
Expand Down