Skip to content

Commit

Permalink
Merge pull request #43034 from bas-aarts:bas-devel-async-compilation
Browse the repository at this point in the history
PiperOrigin-RevId: 363120731
Change-Id: I4860b80e3e7a851c59539f7c3fadf9443ce1a715
  • Loading branch information
tensorflower-gardener committed Mar 16, 2021
2 parents d688099 + 597499a commit 7005f42
Show file tree
Hide file tree
Showing 14 changed files with 436 additions and 145 deletions.
5 changes: 5 additions & 0 deletions tensorflow/compiler/jit/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ void AllocateAndParseFlags() {

ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
ops_flags->tf_xla_async_compilation = false;

jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
Expand Down Expand Up @@ -216,6 +217,10 @@ void AllocateAndParseFlags() {

Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation,
"When lazy compilation is enabled, asynchronous compilation starts "
"the cluster compilation in the background, and the fallback path "
"is executed until the compilation has finished."),

Flag("tf_introduce_floating_point_jitter_to_tensors",
setter_for_jitter_tensor_names, "",
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ struct XlaOpsCommonFlags {
// If true, _XlaCompile always refuses to compile the cluster, which means the
// XLA clusters always run in the TF executor. Defaults to false.
bool tf_xla_always_defer_compilation;
// If true, _XlaCompile compiles the cluster asynchronously with respect to
// the main execution. The fallback path is taken while compilation happens.
bool tf_xla_async_compilation;
};

// Flags for the build_xla_ops pass.
Expand Down
7 changes: 3 additions & 4 deletions tensorflow/compiler/jit/get_compiler_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
const XlaCompiler::Options& options,
const XlaCompiler::CompileOptions& compile_options,
const NameAttrList& function, XlaCompilationCache* cache,
absl::Span<XlaCompiler::Argument const> args, const XlaCompiler& compiler) {
const std::vector<XlaCompiler::Argument>& args,
const XlaCompiler& compiler) {
const XlaCompiler::CompilationResult* compilation_result = nullptr;
xla::LocalExecutable* executable = nullptr;
TF_RETURN_IF_ERROR(cache->Compile(options, function, args, compile_options,
Expand Down Expand Up @@ -100,12 +101,10 @@ xla::StatusOr<std::string> GetCompilerIr(
}));
core::ScopedUnref cache_ref(cache);

absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;

XlaCompiler::Options options =
GenerateCompilerOptions(*cache, *flr, dev,
/*stream=*/nullptr, platform_info,
/*has_ref_vars=*/false, &tf_allocator_adapter);
/*has_ref_vars=*/false);

XlaCompiler::CompileOptions compile_options;
compile_options.always_return_tuple = false;
Expand Down
48 changes: 25 additions & 23 deletions tensorflow/compiler/jit/kernels/xla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ static Status CompileToLocalExecutable(
const XlaPlatformInfo& platform_info,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
xla::LocalClient** client,
absl::Span<const int> constants,
XlaCompilationCache::CompileMode compile_mode,
bool may_alias_resource_update, xla::LocalClient** client,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
// We store information about the JIT-compiled XLA computation
Expand All @@ -190,11 +191,10 @@ static Status CompileToLocalExecutable(

*client = static_cast<xla::LocalClient*>(cache->client());

absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options = GenerateCompilerOptions(
*cache, *ctx->function_library(), ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info, has_ref_vars, &tf_allocator_adapter);
platform_info, has_ref_vars);

XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
Expand All @@ -209,9 +209,7 @@ static Status CompileToLocalExecutable(
constants, inputs, variable_infos,
static_cast<Device*>(ctx->device()));
TF_RETURN_IF_ERROR(args.status());
return cache->Compile(options, function, *args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
return cache->Compile(options, function, *args, compile_options, compile_mode,
compilation_result, executable);
}

Expand All @@ -232,7 +230,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
variable_infos, constants_, /*lazy=*/false,
variable_infos, constants_, XlaCompilationCache::CompileMode::kStrict,
/*may_alias_resource_update=*/true, &client, &compilation_result,
&executable);
OP_REQUIRES_OK(ctx, s);
Expand All @@ -245,12 +243,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {

se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
GetAllocator(ctx->device(), stream, platform_info_);
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
int device_ordinal = stream ? stream->parent()->device_ordinal()
: client->default_device_ordinal();
XlaComputationLaunchContext launch_context(
Expand Down Expand Up @@ -380,6 +375,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
mutex_lock guard(cannot_compile_cluster_mu_);
cannot_compile_cluster = cannot_compile_cluster_;
}
XlaCompilationCache::CompileMode compile_mode = [&] {
if (must_compile_) {
return XlaCompilationCache::CompileMode::kStrict;
}
return GetXlaOpsCommonFlags().tf_xla_async_compilation
? XlaCompilationCache::CompileMode::kAsync
: XlaCompilationCache::CompileMode::kLazy;
}();

if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
cannot_compile_cluster) {
Expand All @@ -395,12 +398,12 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
// unlocking them in XlaRun may lead to deadlocks.
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
constants_,
/*lazy=*/!must_compile_,
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
constants_, compile_mode, /*may_alias_resource_update=*/false, &client,
&kernel, &executable);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
if (compile_mode != XlaCompilationCache::CompileMode::kLazy ||
status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}

Expand All @@ -422,6 +425,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
host_alloc_attrs.set_on_host(true);
Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

// Async compilation returns nullptr executable without an error.
if (!executable) {
DCHECK(!must_compile_);
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
Expand Down Expand Up @@ -462,13 +466,11 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
XlaExecutableClosure closure =
XlaExecutableClosureStore::Global()->Consume(key);

absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
GetAllocator(ctx->device(), stream, platform_info_);
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
int device_ordinal = stream ? stream->parent()->device_ordinal()
: closure.client()->default_device_ordinal();
XlaComputationLaunchContext launch_context(
Expand Down

0 comments on commit 7005f42

Please sign in to comment.