address comments on commit cf346df
Bas Aarts committed Oct 8, 2020
1 parent cf346df commit a312c37
Showing 3 changed files with 18 additions and 18 deletions.
25 changes: 13 additions & 12 deletions tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -55,7 +55,7 @@ constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;
 
 XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                          DeviceType device_type)
-    : client_(client), device_type_(std::move(device_type)), async_compilation_() {}
+    : client_(client), device_type_(std::move(device_type)) {}
 
 XlaCompilationCache::~XlaCompilationCache() {
   // Ensure any use of our programs have completed by waiting for all stream
@@ -387,32 +387,33 @@ Status XlaCompilationCache::CompileAsynchronous(
   entry->compile_state = CompileState::kCompiling;  // Still under caller's lock.
   {
     mutex_lock lock(async_compilation_.async_compilation_mu_);
-    async_compilation_.nrof_ongoing_compilations++;
+    async_compilation_.num_ongoing_compilations++;
   }
   // Don't move the above code into the thread function!!!
 
   // Passing options by value into the lamba increases the refcount on
   // options.device_allocator, keeping it alive for the duration of the
   // compilation.
   // Passing args by value as well. Doing this here only when an asynchronous
-  // compilation is performed, as copying many argS incurs an overhead.
+  // compilation is performed, as copying many args incurs an overhead.
   async_compilation_.compiler_threads.Schedule([=] {
-    Entry tmp;
+    Entry local_entry;
     VLOG(2) << "Starting asynchronous compilation of cluster "
             << function_name << '.';
-    (void)CompileStrict(&tmp, options, args, function_name, compile_fn);
+    (void)CompileStrict(&local_entry, options, args, function_name,
+                        compile_fn);
     VLOG(2) << "Finished asynchronous compililation of cluster "
             << function_name << '.';
     {
       mutex_lock lock(async_compilation_.async_compilation_mu_);
-      async_compilation_.nrof_ongoing_compilations--;
+      async_compilation_.num_ongoing_compilations--;
     }
     {  // Populate original entry with compilation result.
       mutex_lock entry_lock(entry->mu);
-      entry->compilation_result = tmp.compilation_result;
-      entry->compile_state = tmp.compile_state;
-      entry->compilation_status = tmp.compilation_status;
-      entry->executable = std::move(tmp.executable);
+      entry->compilation_result = local_entry.compilation_result;
+      entry->compile_state = local_entry.compile_state;
+      entry->compilation_status = local_entry.compilation_status;
+      entry->executable = std::move(local_entry.executable);
     }
   }
   );
@@ -526,8 +527,8 @@ Status XlaCompilationCache::CompileImpl(
   // asynchronous compilation is enabled.
   {
     mutex_lock lock(async_compilation_.async_compilation_mu_);
-    if (async_compilation_.nrof_ongoing_compilations >=
-        async_compilation_.kMaxNrofOngoingCompilations) {
+    if (async_compilation_.num_ongoing_compilations >=
+        async_compilation_.kMaxNumOngoingCompilations) {
       VLOG(2) << "Not asynchronously compiling cluster " << function_name
               << " because of too many ongoing compilations.";
       return false;
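
For readers skimming the hunks above, the shape of CompileAsynchronous is: claim a slot in a mutex-guarded counter on the caller's thread, compile into a thread-local Entry, then publish the result into the shared cache entry under that entry's own lock. The sketch below reproduces that shape with only the C++ standard library; Entry, CompileAsync, and the counter are illustrative stand-ins, not TensorFlow's API.

#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

struct Entry {
  std::mutex mu;
  std::string compilation_result;  // stand-in for the XLA compilation result
  bool done = false;
};

std::mutex counter_mu;
int num_ongoing_compilations = 0;  // guarded by counter_mu
constexpr int kMaxNumOngoingCompilations = 10;

// Returns an unjoinable thread if the compilation was throttled.
std::thread CompileAsync(Entry* entry, std::string input) {
  {
    std::lock_guard<std::mutex> lock(counter_mu);
    if (num_ongoing_compilations >= kMaxNumOngoingCompilations) {
      return std::thread();  // too many ongoing compilations; bail out
    }
    ++num_ongoing_compilations;  // claim the slot before the thread starts
  }
  // Capture `input` by value so it outlives the caller's frame, mirroring
  // how the commit copies options/args into the scheduled lambda.
  return std::thread([entry, input] {
    Entry local_entry;  // compile into a local entry first
    local_entry.compilation_result = "compiled(" + input + ")";
    local_entry.done = true;
    {
      std::lock_guard<std::mutex> lock(counter_mu);
      --num_ongoing_compilations;  // release the slot
    }
    {  // publish the result under the shared entry's own lock
      std::lock_guard<std::mutex> lock(entry->mu);
      entry->compilation_result = std::move(local_entry.compilation_result);
      entry->done = local_entry.done;
    }
  });
}

int main() {
  Entry e;
  std::thread worker = CompileAsync(&e, "cluster_0");
  if (worker.joinable()) worker.join();  // TF's ThreadPool owns this lifetime
  std::lock_guard<std::mutex> lock(e.mu);
  std::cout << e.compilation_result << "\n";  // prints compiled(cluster_0)
}

Note that the slot is claimed before the work is scheduled, matching the in-diff warning "Don't move the above code into the thread function!!!": incrementing inside the lambda would let callers race past the cap before any worker runs.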
9 changes: 4 additions & 5 deletions tensorflow/compiler/jit/xla_compilation_cache.h
@@ -222,21 +222,20 @@ class XlaCompilationCache : public ResourceBase {
   mutex async_compilation_mu_;
 
   // Number of threads for asynchronous compilations.
-  static constexpr int64 kNrofCompilerThreads = 10;
+  static constexpr int64 kNumCompilerThreads = 10;
 
   // Maximum number of ongoing compilations.
-  static constexpr int64 kMaxNrofOngoingCompilations = kNrofCompilerThreads;
+  static constexpr int64 kMaxNumOngoingCompilations = kNumCompilerThreads;
 
   // Pool of threads for asynchronous compilations.
   thread::ThreadPool compiler_threads;
 
   // Number of ongoing compilations.
-  int64 nrof_ongoing_compilations GUARDED_BY(async_compilation_mu_) = 0;
+  int64 num_ongoing_compilations GUARDED_BY(async_compilation_mu_) = 0;
 
   AsyncCompilation()
       : compiler_threads(tensorflow::Env::Default(), "aync_compiler_threads",
-                         kNrofCompilerThreads) {}
-  ~AsyncCompilation() {}
+                         kNumCompilerThreads) {}
 
 } async_compilation_;
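
The header hunk is mostly a rename (nrof → num), but the GUARDED_BY on the counter deserves a note: TensorFlow's macro wraps Clang's thread-safety attribute, so unlocked accesses can be diagnosed at compile time with -Wthread-safety. A minimal sketch, assuming Clang (other compilers, and unannotated mutex types, simply ignore the attribute); the names are illustrative:

#include <mutex>

// Minimal stand-in for TF's GUARDED_BY macro; expands to Clang's
// thread-safety attribute and is a no-op elsewhere.
#if defined(__clang__)
#define GUARDED_BY(x) __attribute__((guarded_by(x)))
#else
#define GUARDED_BY(x)
#endif

struct AsyncCompilationState {
  std::mutex mu;

  // Constants sized off each other, as in the renamed header members.
  static constexpr int kNumCompilerThreads = 10;
  static constexpr int kMaxNumOngoingCompilations = kNumCompilerThreads;

  // With -Wthread-safety (and an annotated mutex type), Clang warns on any
  // read or write of this field made without holding `mu`.
  int num_ongoing_compilations GUARDED_BY(mu) = 0;
};

bool AtCapacity(AsyncCompilationState& s) {
  std::lock_guard<std::mutex> lock(s.mu);  // correct: lock before reading
  return s.num_ongoing_compilations >=
         AsyncCompilationState::kMaxNumOngoingCompilations;
}

int main() { AsyncCompilationState s; return AtCapacity(s) ? 1 : 0; }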
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/xla_platform_info.h
@@ -74,7 +74,7 @@ class XlaPlatformInfo {
   // If the op associated with this XlaPlatformInfo is placed on an XLA device
   // then device_allocator_ is the xla::Backend's memory allocator. If the op
   // is placed on a regular CPU or GPU device then device_allocator_ is null.
-  // The allocator is of unknowm provenance; keep it in a shared pointer to
+  // The allocator is of unknown provenance; keep it in a shared pointer to
   // set an artificial refcount of one.
   std::shared_ptr<se::DeviceMemoryAllocator> device_allocator_;
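
The comment fixed above alludes to a standard trick: wrap a pointer you do not own in a shared_ptr whose deleter is a no-op, so that value-copies (such as the options capture in the first file) manipulate a real refcount without the pointee ever being freed. A minimal sketch, with an illustrative stand-in type rather than TF's se::DeviceMemoryAllocator:

#include <iostream>
#include <memory>

// Illustrative stand-in for se::DeviceMemoryAllocator.
struct DeviceMemoryAllocator {
  const char* name = "backend-owned allocator";
};

int main() {
  DeviceMemoryAllocator owned_elsewhere;  // lifetime managed by its backend

  // "Artificial refcount": a shared_ptr that counts references but whose
  // deleter does nothing, because the allocator is not ours to free.
  std::shared_ptr<DeviceMemoryAllocator> alias(
      &owned_elsewhere, [](DeviceMemoryAllocator*) { /* no-op deleter */ });

  auto copy = alias;  // e.g. captured by value in a compile lambda
  std::cout << alias.use_count() << " " << copy->name << "\n";  // prints "2 ..."
}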
