-
Notifications
You must be signed in to change notification settings - Fork 21.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merge common parts of FutureNCCL into at::ivalue::Future #48505
Changes from 1 commit
226f08f
036b219
6ec38ac
ec882f4
c71f627
ae725e2
00eb48d
5a98c31
25b81ff
d667d5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -290,23 +290,27 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
/** | ||
* Wait on the future until it completes. | ||
*/ | ||
virtual void wait() { | ||
void wait() { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
while (!completed_) { | ||
finished_cv_.wait(lock); | ||
} | ||
|
||
postWaitHook(); | ||
} | ||
|
||
/** | ||
* Wait on the future until it completes and throw an | ||
* exception if an error exists. | ||
*/ | ||
virtual void waitAndThrow() { | ||
void waitAndThrow() { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
while (!completed_) { | ||
finished_cv_.wait(lock); | ||
} | ||
|
||
postWaitHook(); | ||
|
||
if (eptr_) { | ||
std::rethrow_exception(eptr_); | ||
} | ||
|
@@ -315,7 +319,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
/** | ||
* Explicitly mark the future as completed with the output value. | ||
*/ | ||
virtual void markCompleted(IValue value) { | ||
void markCompleted(IValue value) { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
TORCH_CHECK( | ||
!completed(), | ||
|
@@ -324,6 +328,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
completed_ = true; | ||
value_ = std::move(value); | ||
|
||
postMarkCompletedHook(value_); | ||
|
||
std::vector<std::function<void(void)>> cbs; | ||
cbs.swap(callbacks_); | ||
lock.unlock(); | ||
|
@@ -359,7 +365,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
} | ||
|
||
// Get the result of the current future. | ||
virtual IValue value() { | ||
IValue value() { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
AT_ASSERT(completed()); | ||
if (eptr_) { | ||
|
@@ -370,7 +376,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
|
||
// This accessor should only be used if we know that the future is | ||
// completed() with no error. | ||
virtual const IValue& constValue() { | ||
const IValue& constValue() { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
AT_ASSERT(completed()); | ||
AT_ASSERT(!eptr_); | ||
|
@@ -383,8 +389,9 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
* If the future has already completed, | ||
* this function will execute the callback immediately. | ||
*/ | ||
virtual void addCallback(std::function<void(void)> callback) { | ||
void addCallback(std::function<void(void)> callback) { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
callback = wrapCallback(std::move(callback)); | ||
if (completed()) { | ||
lock.unlock(); | ||
callback(); | ||
|
@@ -398,22 +405,18 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
* value of the callback. This is necessary when the callback provider needs | ||
* to know for sure when the callback has finished. | ||
*/ | ||
virtual c10::intrusive_ptr<Future> then( | ||
c10::intrusive_ptr<Future> then( | ||
std::function<IValue(void)> callback, | ||
TypePtr type) { | ||
auto fut = c10::make_intrusive<Future>(type); | ||
// Cannot move capture std::function in lambda, because it cannot deduce | ||
// the template type for std::function. Hence use std::bind to explicitly | ||
// specify types. | ||
addCallback(std::bind( | ||
[fut](std::function<IValue(void)> cb) { | ||
auto fut = createInstance(std::move(type)); | ||
addCallback( | ||
[fut, cb = std::move(callback)]() { | ||
try { | ||
fut->markCompleted(cb()); | ||
} catch (std::exception& e) { | ||
fut->setError(std::current_exception()); | ||
} | ||
}, | ||
std::move(callback))); | ||
}); | ||
return fut; | ||
} | ||
|
||
|
@@ -452,11 +455,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
} | ||
|
||
// Check if the current future has completed | ||
virtual bool completed() const { | ||
bool completed() const { | ||
return completed_; | ||
} | ||
|
||
virtual bool hasValue() const { | ||
bool hasValue() const { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
return completed_ && !eptr_; | ||
} | ||
|
@@ -479,6 +482,17 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
return type_; | ||
} | ||
|
||
protected: | ||
virtual c10::intrusive_ptr<Future> createInstance(at::TypePtr type) { | ||
return c10::make_intrusive<Future>(type); | ||
} | ||
|
||
virtual void postMarkCompletedHook(const at::IValue& value) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you briefly document how these there apis are used, probably only FutureNCCL is using them now, but if there're other future derived type there, might be a good reference. |
||
|
||
virtual std::function<void(void)> wrapCallback(std::function<void(void)> callback) { return callback; } | ||
|
||
virtual void postWaitHook() {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we add some comments to these functions and explain what derived classes need to do when implementing them? |
||
|
||
private: | ||
void setErrorInternal( | ||
std::exception_ptr eptr, | ||
|
@@ -487,6 +501,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { | |
completed_ = true; | ||
eptr_ = std::move(eptr); | ||
|
||
// Do not call postMarkCompletedHook() here as there isn't any value. | ||
|
||
std::vector<std::function<void(void)>> cbs; | ||
cbs.swap(callbacks_); | ||
lock.unlock(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was curious to see what the reason for this problem was so I tried to undo this fix... and it seems to work? Maybe the comment was outdated?