XLA Asynchronous compilation #43034
Entry tmp;
VLOG(2) << "Starting asynchronous compilation of cluster "
        << function_name << '.';
(void)CompileStrict(&tmp, options, args, function_name, compile_fn);
I don't think we can ignore the error, we need to report it back to the user. IMO the right solution is to store the Status in entry.
I need a little help understanding a particular error:
TF_RETURN_IF_ERROR(
BroadcastXlaActivity(std::move(jit_compilation_activity)));
(What does this error mean?)
In the original code, this error can trigger after compilation. At that point the Entry has been updated, and the compilation results are already stored in the cache. When this error triggers, compilation is ignored for this call only. Subsequent compilations will retrieve the prior compilation result.
So even though this error is triggered, the compilation passed, and "in the future" the result can be used.
I mimicked that behaviour with this change. When the compilation fails, the Entry is populated with that information, as was done before. If the above error triggers, the compilation results have already been stored. Since this is an asynchronous compilation, the fallback path has already been chosen, which matches the original behavior.
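To make the "store the Status in the entry" idea concrete, here is a minimal sketch. It assumes simplified stand-ins (SimpleStatus, CacheEntry, CompileAndRecord, LookupStatus are all hypothetical names, not TensorFlow types) and only illustrates recording an asynchronous compile result so that later lookups can report it.

```cpp
#include <mutex>
#include <optional>
#include <string>

// Hypothetical stand-in for tensorflow::Status; not the real class.
struct SimpleStatus {
  bool ok = true;
  std::string message;
};

// Hypothetical cache entry: the compile outcome is recorded once and then
// reused by every later lookup of the same signature.
struct CacheEntry {
  std::mutex mu;
  bool compiled = false;
  SimpleStatus compilation_status;
};

// Runs compile_fn (possibly on a worker thread) and records its outcome in
// the entry instead of discarding it.
template <typename CompileFn>
void CompileAndRecord(CacheEntry* entry, CompileFn compile_fn) {
  SimpleStatus s = compile_fn();
  std::lock_guard<std::mutex> lock(entry->mu);
  entry->compiled = true;
  entry->compilation_status = s;
}

// Later callers see the stored status; std::nullopt means the compilation
// has not finished yet, so the caller stays on the fallback path.
std::optional<SimpleStatus> LookupStatus(CacheEntry* entry) {
  std::lock_guard<std::mutex> lock(entry->mu);
  if (!entry->compiled) return std::nullopt;
  return entry->compilation_status;
}
```

With this shape, a failed asynchronous compilation is not silently dropped: the first call takes the fallback path, and any subsequent call can surface the stored error.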
 // The number of times a lazy compilation must be requested for a specific
 // signature before we attempt to compile it.
-static constexpr int64 kDefaultCompilationThreshold = 2;
+static constexpr int64 kDefaultCompilationThreshold = 3;
This is a separate change right?
Yes. Should I leave this out? This was part of the change that changes the compilation heuristic (see commit comment).
Yes, let's leave this out. IMO we should do this only if async compilation is enabled, assuming that it makes sense only when async compilation is enabled.
I dropped the change that changes the compilation protocol.
@bas-aarts Can you please check @sanjoy's comments and keep us posted? Thanks!
@@ -166,8 +166,8 @@ static Status CompileToLocalExecutable(
     const XlaPlatformInfo& platform_info,
     absl::Span<const Tensor* const> inputs,
     absl::Span<VariableInfo const> variable_infos,
-    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
-    xla::LocalClient** client,
+    absl::Span<const int> constants, bool async, bool lazy,
Instead of passing around two bools, I'd prefer passing around a XlaCompilationCache::CompileMode (i.e. convert the pair of bools into XlaCompilationCache::CompileMode much earlier in the process).
done
// compile_fn can be called asynchronously. Make sure all required arguments
// are passed by value.
auto compile_fn = [&, compile_options, function](
                      XlaCompiler* compiler,
                      const std::vector<XlaCompiler::Argument>& args,
I think this is a bit too subtle, can you please create a struct with an operator() that explicitly captures all state?
@sanjoy, I'm not seeing how you want me to use the operator(). Please add some pseudocode to show the intent.
I meant writing this as:
struct CompileFunctor {
  CompileOptions compile_options;
  // All the other state that is needed
  Status operator()(<args>) { ... }
};
And create an instance of CompileFunctor and pass that to CompileImpl. The difference is that the state captured is now explicit. In the current version it will be easy for a later change to introduce the use of a local variable in the body of the lambda.
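A compilable version of that pseudocode might look like the sketch below. The types CompileOptions, Argument, and Status here are simplified stand-ins, and the members hypothetical_flag and function_name as well as the templated CompileImpl are assumptions made for illustration; none of them are the real TensorFlow signatures.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for the TensorFlow types named in the thread.
struct CompileOptions { int hypothetical_flag = 0; };
struct Argument { int value = 0; };
struct Status { bool ok = true; };

struct CompileFunctor {
  // Every piece of state the compile step may touch is a named member, so a
  // reviewer can check the lifetime of each one at the construction site.
  CompileOptions compile_options;
  std::string function_name;

  // The body can only reach the members above and its own parameters; it
  // cannot silently pick up a local variable of the enclosing function.
  Status operator()(const std::vector<Argument>& args, std::string* out) const {
    *out = function_name + ":" + std::to_string(args.size());
    return Status{};
  }
};

// Hypothetical stand-in for CompileImpl: accepts any callable with the
// signature used above and invokes it.
template <typename Fn>
Status CompileImpl(const std::vector<Argument>& args, Fn compile_fn,
                   std::string* out) {
  return compile_fn(args, out);
}
```

A lambda with an explicit capture list and no capture-default gives a similar guarantee; the functor simply makes the captured state visible as a type.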
Passing an additional argument to CompileImpl requires changing many additional functions as well. Capture by value is cleaner.
Added a POD type instead of a class with an operator() to explicitly capture the required variables, added additional comments for clarity, and most importantly, removed the capture-default to prevent additional variables from being introduced in the body of the lambda.
// compilation.
// Passing args by value as well. Doing this here only when an asynchronous
// compilation is performed, as copying many args incurs an overhead.
async_compilation_.compiler_threads.Schedule([=] {
Again, I think this is a bit too implicit. Can you instead create a struct that explicitly captures state and exposes an operator(), or some other way to capture the state explicitly?
Changed the comment here. Since everything is passed by value, no need for a struct to encapsulate anything (which would make it more error prone if new arguments would have to be passed by value)
Passing things by value does not guarantee stability though, since you could be passing a soon-to-be-stale pointer by value.
Creating a struct that has the state explicitly threaded through (like in CompileFunctor above) will make it easier to spot use-after-free bugs.
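The stale-pointer concern can be illustrated with a small sketch. ClusterInfo, AsyncCompileTask, and MakeDeferredTask are hypothetical names invented for this example; the point is only that each member's ownership is spelled out in the struct, which a [=] capture of a raw pointer would hide.

```cpp
#include <functional>
#include <memory>
#include <string>

// Hypothetical state that a deferred compilation needs.
struct ClusterInfo { std::string name; };

// Explicit state for the deferred task. Each member is either owned by value
// or held through shared ownership, so none of it can go stale.
struct AsyncCompileTask {
  std::shared_ptr<const ClusterInfo> cluster;  // keeps the object alive
  int num_args;                                // plain value copy

  std::string operator()() const {
    return cluster->name + "/" + std::to_string(num_args);
  }
};

// Builds the task while the caller's data is alive and returns it so that it
// can run later, after this function's locals are gone.
std::function<std::string()> MakeDeferredTask() {
  auto cluster = std::make_shared<const ClusterInfo>(ClusterInfo{"cluster_0"});
  AsyncCompileTask task{cluster, 3};
  // By contrast, a [=] lambda that used cluster.get() would copy only the
  // raw pointer, which would dangle once `cluster` goes out of scope.
  return task;
}
```

Calling the returned task after MakeDeferredTask has returned is safe here because the struct holds a shared_ptr rather than a borrowed pointer.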
Added a POD type instead of a class with an operator() to explicitly capture the required variables, added additional comments for clarity, and most importantly, removed the capture-default to prevent additional variables from being introduced in the body of the lambda.
string function_name = function.name();
string human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function_name;
VLOG(2) << "Signature: " << human_signature;
Let's put all of this under VLOG_IS_ON(2) to avoid unnecessary copies.
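As a rough sketch of that guard, the example below uses a hypothetical verbosity switch (g_verbosity and VerboseAt stand in for VLOG_IS_ON and are not TensorFlow's macro) and a hypothetical Signature type that counts how often its expensive string is built.

```cpp
#include <string>

// Hypothetical verbosity switch standing in for VLOG_IS_ON.
int g_verbosity = 0;
bool VerboseAt(int level) { return g_verbosity >= level; }

// Hypothetical signature type; the counter shows when the costly string is
// actually constructed.
struct Signature {
  std::string name;
  mutable int human_string_calls = 0;
  std::string HumanString() const {
    ++human_string_calls;
    return "sig(" + name + ")";
  }
};

// Returns the text that would be logged, or an empty string when logging at
// level 2 is off. No strings are copied or built unless the guard passes.
std::string DescribeForLog(const Signature& sig,
                           const std::string& function_name) {
  if (!VerboseAt(2)) return "";
  return "Signature: " + (VerboseAt(3) ? sig.HumanString() : function_name);
}
```

With the guard in place, the common case (verbosity below 2) pays for neither the function-name copy nor the signature string.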
done
@@ -519,6 +521,7 @@ def simpleTest(self, arg0, arg1, global_jit_level):

 class LazyCompilationTest(test.TestCase):

+  @unittest.skip("test too dependant on XLA compilation protocol")
This is by design -- the test is testing the XLA compilation protocol. :)
I think the test needs to be adjusted to adapt to whatever the new scheme is.
I removed the change that changes the compilation protocol.
@bas-aarts Any update on this PR, please? Thanks!
This PR is still being worked on.
a312c37 to 5111699
Rebased and addressed many comments
@@ -166,8 +166,8 @@ static Status CompileToLocalExecutable(
     const XlaPlatformInfo& platform_info,
     absl::Span<const Tensor* const> inputs,
     absl::Span<VariableInfo const> variable_infos,
-    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
-    xla::LocalClient** client,
+    absl::Span<const int> constants, XlaCompilationCache::CompileMode cmode,
Let's call this compile_mode.
done
// If must_compile_ is true, there is no fallback path and therefore
// async and lazy must be false. If must_compile_ is false, and async
// compilation is enabled, async is true, and lazy is false. Otherwise
// lazy compilation is true.
bool async = !must_compile_ &&
             GetXlaOpsCommonFlags().tf_xla_async_compilation;
// Possible future work:
// disable async for small clusters.
// disable async for cluster that have short compile time.
bool lazy = async ? false : !must_compile_;
XlaCompilationCache::CompileMode cmode =
    lazy ? XlaCompilationCache::CompileMode::kLazy :
    async ? XlaCompilationCache::CompileMode::kAsync :
            XlaCompilationCache::CompileMode::kStrict;
IMO this will be easier to read and self documenting if we write it as:
XlaCompilationCache::CompileMode compile_mode = [&] {
if (must_compile_) { return kStrict; }
return GetXlaOpsCommonFlags().tf_xla_async_compilation ? kAsync : kLazy;
}();
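For reference, a self-contained variant of this suggestion could be sketched as follows. It assumes a local CompileMode enum with the three values named in the thread and takes the two inputs as plain parameters (must_compile and async_flag are hypothetical names) rather than reading must_compile_ or GetXlaOpsCommonFlags().

```cpp
// Local stand-in for XlaCompilationCache::CompileMode.
enum class CompileMode { kLazy, kStrict, kAsync };

// must_compile and async_flag are plain parameters in this sketch; in the PR
// they would come from must_compile_ and the tf_xla_async_compilation flag.
CompileMode SelectCompileMode(bool must_compile, bool async_flag) {
  // Immediately invoked lambda: the variable can be const and each case is a
  // single readable return statement.
  const CompileMode compile_mode = [&] {
    if (must_compile) return CompileMode::kStrict;  // no fallback path
    return async_flag ? CompileMode::kAsync : CompileMode::kLazy;
  }();
  return compile_mode;
}
```

This yields the same mapping as the two-bool version: must_compile always selects kStrict, and otherwise the async flag chooses between kAsync and kLazy.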
done
@@ -142,7 +142,9 @@ Backend::Backend(se::Platform* platform, Compiler* compiler,
   }
 }

-Backend::~Backend() {}
+Backend::~Backend() {
+  CHECK_EQ(memory_allocator_.use_count(), 1);
Why is this true? What prevents a compilation from running concurrently with backend destruction? (Please add a comment.)
The check is not needed. The memory allocator is stored in a shared pointer, so even if Backend is destroyed, compilation will succeed.
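The lifetime argument can be sketched with stand-in classes. The Allocator and Backend below are simplified, hypothetical versions (the accessor name shared_memory_allocator is an assumption for this example), showing only that a holder of the shared pointer keeps the allocator usable after the Backend is gone.

```cpp
#include <memory>

// Hypothetical allocator; Allocate() returns a fixed value for the demo.
struct Allocator {
  int Allocate() const { return 42; }
};

// Simplified stand-in for the Backend class in the diff above.
class Backend {
 public:
  Backend() : memory_allocator_(std::make_shared<Allocator>()) {}
  std::shared_ptr<Allocator> shared_memory_allocator() const {
    return memory_allocator_;
  }

 private:
  std::shared_ptr<Allocator> memory_allocator_;
};

// Simulates a compilation that took a reference before the Backend went away.
int AllocateAfterBackendDestroyed() {
  std::shared_ptr<Allocator> held;
  {
    Backend backend;
    held = backend.shared_memory_allocator();  // use_count() is now 2
  }  // Backend is destroyed here; the allocator survives through `held`.
  return held->Allocate();
}
```

The same sketch shows why a use_count() == 1 check in the destructor would fire whenever a compilation still holds the allocator.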
Just noticed this change has not yet been merged. Any reason for this?
Not clear, @gbaned do you know what's going on here?
XLA Asynchronous compilation
1) add option to opt into asynchronous compilation
2) asynchronous compilation uses a dedicated number of threads
to start a cluster instance compilation while the fallback path
is executed
3) limit number of ongoing compilations to a fixed threshold
change some VLOG levels to make level 2 less verbose
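
Item 3 of the commit message (a fixed cap on ongoing compilations) could be realized in several ways. One minimal sketch, using a hypothetical CompilationLimiter class that is not taken from the PR, is a mutex-protected counter:

```cpp
#include <mutex>

// Hypothetical helper: caps the number of compilations in flight.
class CompilationLimiter {
 public:
  explicit CompilationLimiter(int max_ongoing) : max_ongoing_(max_ongoing) {}

  // Reserves a slot and returns true, or returns false when the cap is
  // reached, in which case the caller simply stays on the fallback path.
  bool TryStart() {
    std::lock_guard<std::mutex> lock(mu_);
    if (ongoing_ >= max_ongoing_) return false;
    ++ongoing_;
    return true;
  }

  // Releases a slot once a compilation has finished (successfully or not).
  void Finish() {
    std::lock_guard<std::mutex> lock(mu_);
    --ongoing_;
  }

 private:
  std::mutex mu_;
  const int max_ongoing_;
  int ongoing_ = 0;
};
```

A worker would call TryStart() before scheduling a compilation and Finish() when it completes, so a burst of new clusters cannot queue an unbounded amount of compile work.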