Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate intra_op_parallelism_threads from SessionOptions to xla::LocalClientOptions #30996

Merged
merged 2 commits into from Jul 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion tensorflow/compiler/jit/xla_device.cc
Expand Up @@ -203,6 +203,7 @@ XlaDevice::XlaDevice(const SessionOptions& session_options,
device_ordinal_(options.device_ordinal),
jit_device_name_(options.compilation_device_name),
platform_(options.platform),
intra_op_parallelism_threads_(session_options.config.intra_op_parallelism_threads()),
use_multiple_streams_(options.use_multiple_streams),
shape_representation_fn_(options.shape_representation_fn),
allowed_devices_(options.allowed_devices) {
Expand Down Expand Up @@ -233,9 +234,13 @@ xla::LocalClient* XlaDevice::client() const {
// don't want to do it until we get a chance to hook the platform up
// to a simulator.

xla::LocalClientOptions options;
options.set_platform(platform_)
.set_allowed_devices(allowed_devices_)
.set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
// TODO(b/78468222): This can fail, at least when the backend is GPU and
// there is no GPU on the host.
return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_)
return xla::ClientLibrary::GetOrCreateLocalClient(options)
.ValueOrDie();
}

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/jit/xla_device.h
Expand Up @@ -202,6 +202,8 @@ class XlaDevice : public LocalDevice {
const DeviceType jit_device_name_;
// The platform for this device.
se::Platform* const platform_; // Not owned.
// Intra-op threads to spawn (from SessionOptions).
const int intra_op_parallelism_threads_;
// Memory allocator associated with this device.
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.

Expand Down