
Branch 194551042 #18935
Merged 57 commits on Apr 27, 2018
Commits
12129fc
Disable gather_test under ASAN since it times out.
tensorflower-gardener Apr 26, 2018
98d0386
Added metadata to the TFLite model.
tensorflower-gardener Apr 26, 2018
0ecfb35
[XLA] Redesign: migrate other xla/tests to use the new builder.
tensorflower-gardener Apr 26, 2018
521606d
Support CuDNN RNN layers in tf.keras.
pavithrasv Apr 26, 2018
59a4b48
Clarify limitation of `deps` in tf_gen_op_wrapper_py
tensorflower-gardener Apr 26, 2018
8148895
Support matching against shape string in HLO testing matchers
tensorflower-gardener Apr 26, 2018
481f229
- Adding support for Cholesky (inverse) factor multiplications.
james-martens Apr 26, 2018
4eac28a
Format header guards under tensorflow/core/grappler.
tensorflower-gardener Apr 26, 2018
acb5563
tfdbg: disable grpc_large_data_test on ASAN
caisq Apr 26, 2018
e563b56
Disable vector_diffeomixture_test under ASAN to avoid timeouts.
tensorflower-gardener Apr 26, 2018
85e4dc4
Fixing issue #13258: y is actually the square of the Mahalanobis distance.
bignamehyp Apr 26, 2018
509ffc3
Simplify tfe.defun capture by not using convert_to_tensor
alextp Apr 26, 2018
18f1349
Disable wrappers_test under ASAN since it sometimes times out.
tensorflower-gardener Apr 26, 2018
f495e32
Limit the number of single allocation memory warnings.
rjpower Apr 26, 2018
f63a8d6
Remove "everything matched" assertions from CuDNN object-based checkp…
allenlavoie Apr 26, 2018
efa789e
Add a skeleton dispatch context object, that can be used to control t…
Apr 26, 2018
c7dce75
Updates on https://www.tensorflow.org/community/swift as part of the …
Apr 26, 2018
a0af355
Automated g4 rollback of changelist 192536085
tensorflower-gardener Apr 26, 2018
6b6976e
Deprecate tfe.Network and associated utilities in favor of tf.keras.M…
allenlavoie Apr 26, 2018
a848183
Removing @@ comments from core TensorFlow. They are no longer needed …
annarev Apr 26, 2018
d66adb4
Simplify, test and document logic in instruction fusion that decides …
tensorflower-gardener Apr 26, 2018
b6adaab
Move */logging.cc into :platform_base since it already exposes the he…
tensorflower-gardener Apr 26, 2018
667077c
Optimize functions in the function library.
tensorflower-gardener Apr 26, 2018
f637506
For tf.gradients(), do not backpropagate through integer tensors.
tensorflower-gardener Apr 26, 2018
bb28101
gRPC worker cache owns a shared_ptr to the channel cache
Apr 26, 2018
38244c3
Automated g4 rollback of changelist 194269675
tensorflower-gardener Apr 26, 2018
5f06514
Fix build by adding op_lib dependencies to trt_engine_op_loader, and …
aaroey Apr 26, 2018
d3c18b5
Delay deleting RingReducer until group_size_tensor_ready_ has
tensorflower-gardener Apr 26, 2018
2c105ac
Run 2 passes of rewrites by default
benoitsteiner Apr 26, 2018
b6189a2
[TF:XLA] Add INTEL MKL_DNN Conv2d method to XLA/CPU backend
Apr 26, 2018
7b0e865
Adding some slightly more exhaustive strided_slice test parameters.
tensorflower-gardener Apr 26, 2018
35bf3bf
Remove unnecessary TF_NEED_GCP from build scripts.
tensorflower-gardener Apr 26, 2018
f67a78c
Disable densenet_test on MSAN due to flaky time outs.
tensorflower-gardener Apr 26, 2018
4386296
Adds optimization to convert division of sqrt to multiplication of rsqrt
tensorflower-gardener Apr 26, 2018
eceb3a2
Edit tensorflow.org/community/swift page.
dan-zheng Apr 26, 2018
5dd3d19
Disable triangular_solve_test on ASAN due to flaky time outs.
tensorflower-gardener Apr 26, 2018
2ce60cd
Add support for variables in tf.custom_gradient
tensorflower-gardener Apr 26, 2018
c9be1f2
- Default values of cov and inv variables are now 0. Zero-debiasing …
james-martens Apr 26, 2018
7ec93b4
[tf.data] Changes description for `bytes_produced_stats` and `latency…
shivaniag Apr 26, 2018
2808c3f
[tf.data] Adds support for adding scalar value to `StatsAggregator`.
shivaniag Apr 26, 2018
ab5de48
Remove the inter-op thread pool
Apr 26, 2018
3ab696e
Handle variations in scoping of batch norms for correct unfused batch…
raghuraman-k Apr 26, 2018
04a5547
Internal change.
tensorflower-gardener Apr 26, 2018
bcefec3
Fix some flakiness in test.
shashishekhar Apr 26, 2018
7d3e3fd
More informative error message when loading a graph_def which uses un…
malcolmreynolds Apr 26, 2018
236120d
Split out SaveableObjects into their own file
allenlavoie Apr 26, 2018
0b02fd4
Implements linear no-offset (aka symmetric) quantizer.
tensorflower-gardener Apr 27, 2018
0a1d311
Free scratch memory in ~BaseGPUDevice.
tensorflower-gardener Apr 27, 2018
84b3322
Automated g4 rollback of changelist 194442428
tensorflower-gardener Apr 27, 2018
e41e70e
Implement floor operator
tensorflower-gardener Apr 27, 2018
7c845cb
Reenable factorization_ops_test on ASAN after adding shard_count = 4.…
tensorflower-gardener Apr 27, 2018
f88add4
Automated g4 rollback of changelist 194306629
miaout17 Apr 27, 2018
4f69331
[TF:XLA] Bump open source llvm revision to r330926
Apr 27, 2018
ec56b53
Fix bug in @custom_gradient in Eager mode with numpy inputs
tensorflower-gardener Apr 27, 2018
f7f0248
Added string conversion operator to tensorflow::StringPiece.
tensorflower-gardener Apr 27, 2018
2c9a67f
Merge commit for internal changes
Apr 27, 2018
5e0f151
Fix manual-merge error from merge conflict resolution.
Apr 27, 2018
1 change: 1 addition & 0 deletions tensorflow/compiler/tests/BUILD
@@ -755,6 +755,7 @@ tf_xla_py_test(
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
tags = ["noasan"], # times out, http://b/78599043
deps = [
":xla_test",
"//tensorflow/python:array_ops",
4 changes: 3 additions & 1 deletion tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -472,7 +472,9 @@ def _testTensorArrayGradientWriteReadType(self, dtype):
self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1])

def testTensorArrayGradientWriteRead(self):
for dtype in self.numeric_types:
for dtype in self.float_types:
self._testTensorArrayGradientWriteReadType(dtype)
for dtype in self.complex_types:
self._testTensorArrayGradientWriteReadType(dtype)

def _testTensorArrayGradientWritePackConcatAndRead(self):
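The hunk above narrows the gradient test from all numeric dtypes to float and complex only, consistent with commit f637506 ("For tf.gradients(), do not backpropagate through integer tensors."). A minimal sketch of that partition, using hypothetical stand-in dtype sets (the real sets are attributes of the XLA test base class in tensorflow/compiler/tests):

```python
# Hypothetical stand-ins for the test class's dtype sets; the real sets are
# defined on the XLA test base class, not here.
float_types = {"float32", "float64"}
complex_types = {"complex64"}
int_types = {"int32", "int64"}
numeric_types = float_types | complex_types | int_types

def gradient_test_dtypes():
    # Gradients are exercised only for float and complex dtypes; integer
    # tensors carry no meaningful gradient, so int_types are skipped.
    return sorted(float_types) + sorted(complex_types)
```

After the change, the test loops over exactly these two subsets instead of the full `numeric_types` union.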
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/lib/BUILD
@@ -91,6 +91,7 @@ cc_library(
xla_test(
name = "triangular_solve_test",
srcs = ["triangular_solve_test.cc"],
tags = ["noasan"], # sometimes times out, http://b/78650012
deps = [
":triangular_solve",
"//tensorflow/compiler/xla:array2d",
11 changes: 0 additions & 11 deletions tensorflow/compiler/xla/executable_run_options.cc
@@ -45,17 +45,6 @@ stream_executor::Stream* ExecutableRunOptions::stream() const {
return stream_;
}

ExecutableRunOptions& ExecutableRunOptions::set_inter_op_thread_pool(
tensorflow::thread::ThreadPool* inter_op_thread_pool) {
inter_op_thread_pool_ = inter_op_thread_pool;
return *this;
}

tensorflow::thread::ThreadPool* ExecutableRunOptions::inter_op_thread_pool()
const {
return inter_op_thread_pool_;
}

ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool(
const Eigen::ThreadPoolDevice* intra_op_thread_pool) {
intra_op_thread_pool_ = intra_op_thread_pool;
7 changes: 0 additions & 7 deletions tensorflow/compiler/xla/executable_run_options.h
@@ -65,12 +65,6 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_stream(stream_executor::Stream* stream);
stream_executor::Stream* stream() const;

// Sets the thread pool on which to run parallel CPU backend
// computations. Does not take ownership.
ExecutableRunOptions& set_inter_op_thread_pool(
tensorflow::thread::ThreadPool* inter_op_thread_pool);
tensorflow::thread::ThreadPool* inter_op_thread_pool() const;

// Sets the thread pool device on which to run Eigen subcomputations.
// Does not take ownership.
ExecutableRunOptions& set_intra_op_thread_pool(
@@ -93,7 +87,6 @@
int device_ordinal_ = -1;
DeviceAssignment* device_assignment_ = nullptr;
stream_executor::Stream* stream_ = nullptr;
tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
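The header diff above drops the inter-op thread pool from `ExecutableRunOptions`, leaving only the intra-op pool. The class uses chainable setters (each returns `*this`). A hypothetical Python sketch of that chaining pattern, simplified to the two fields relevant here:

```python
class RunOptions:
    """Hypothetical, simplified sketch of ExecutableRunOptions' chainable
    setters; not the real C++ API."""

    def __init__(self):
        self.device_ordinal = -1
        # Only the intra-op pool survives this PR; the inter-op pool
        # setter and member were removed.
        self.intra_op_thread_pool = None

    def set_device_ordinal(self, ordinal):
        self.device_ordinal = ordinal
        return self  # returning self enables chaining, like returning *this

    def set_intra_op_thread_pool(self, pool):
        self.intra_op_thread_pool = pool
        return self

opts = RunOptions().set_device_ordinal(0).set_intra_op_thread_pool("eigen_pool")
```

Call sites that previously chained `set_inter_op_thread_pool(...)` simply lose that link in the chain, as the `local_computation_builder.cc` hunks below show.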
3 changes: 0 additions & 3 deletions tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -197,8 +197,6 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
ExecutableRunOptions options;
options.set_device_ordinal(device_ordinal);
options.set_allocator(client->backend().memory_allocator());
options.set_inter_op_thread_pool(
client->backend().inter_op_thread_pool());
options.set_intra_op_thread_pool(
client->backend().eigen_intra_op_thread_pool_device());
options.set_device_assignment(&device_assignment);
@@ -242,7 +240,6 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
// Execute
ExecutableRunOptions options;
options.set_allocator(client->backend().memory_allocator());
options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
options.set_intra_op_thread_pool(
client->backend().eigen_intra_op_thread_pool_device());
ScopedShapedBuffer result_buffer =
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/BUILD
@@ -1206,6 +1206,7 @@ tf_cc_test(
":instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)

7 changes: 0 additions & 7 deletions tensorflow/compiler/xla/service/backend.cc
@@ -138,9 +138,6 @@ Backend::Backend(
<< "Service found no devices for backend " << platform_->Name() << '.';

if (platform->id() == se::host::kHostPlatformId) {
inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
tensorflow::Env::Default(), "xla_inter_op",
tensorflow::port::NumSchedulableCPUs()));
const int num_threads = intra_op_parallelism_threads > 0
? intra_op_parallelism_threads
: tensorflow::port::NumSchedulableCPUs();
@@ -155,10 +152,6 @@ int Backend::default_device_ordinal() const {
return default_stream_executor()->device_ordinal();
}

tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
return inter_op_thread_pool_.get();
}

const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device()
const {
if (intra_op_thread_pool_wrapper_ == nullptr) {
7 changes: 0 additions & 7 deletions tensorflow/compiler/xla/service/backend.h
@@ -140,10 +140,6 @@ class Backend {
// be equivalent to an executable compiled for the other.
StatusOr<bool> devices_equivalent(int device_ordinal_a, int device_ordinal_b);

// For the host platform, returns the threadpool to use when scheduling
// parallel operators. For other platforms, returns NULL.
tensorflow::thread::ThreadPool* inter_op_thread_pool() const;

// For the host platform, returns the configured eigen threadpool device to be
// used for scheduling work. For other platforms, returns NULL.
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
@@ -178,9 +174,6 @@ class Backend {
// The default memory allocator to use.
std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;

// For the CPU backend, a threadpool for scheduling parallel operators.
std::unique_ptr<tensorflow::thread::ThreadPool> inter_op_thread_pool_;

// For the CPU backend, an Eigen threadpool device for use by Eigen code.
std::unique_ptr<EigenThreadPoolWrapper> intra_op_thread_pool_wrapper_;
};
22 changes: 22 additions & 0 deletions tensorflow/compiler/xla/service/cpu/BUILD
@@ -169,6 +169,7 @@ cc_library(
":orc_jit_memory_mapper",
":runtime_fp16",
":runtime_conv2d",
":runtime_conv2d_mkl",
":runtime_fft",
":runtime_fork_join",
":runtime_matmul",
@@ -470,6 +471,27 @@ cc_library(
],
)

cc_library(
name = "runtime_conv2d_mkl",
srcs = [
"runtime_conv2d_mkl.cc",
],
hdrs = ["runtime_conv2d_mkl.h"],
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
":runtime_conv2d",
":runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
] + if_mkl([
"@mkl_dnn",
"//third_party/mkl:intel_binary_blob",
]),
)

cc_library(
name = "runtime_fft",
srcs = [
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF32SymbolName =
"__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
"__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
extern const char* const kMKLMatMulF32SymbolName =
"__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -44,6 +44,7 @@ namespace runtime {
extern const char* const kEigenMatMulF16SymbolName;
extern const char* const kEigenMatMulF32SymbolName;
extern const char* const kEigenMatMulF64SymbolName;
extern const char* const kMKLConvF32SymbolName;
extern const char* const kMKLMatMulF32SymbolName;
extern const char* const kMKLMatMulF64SymbolName;
extern const char* const kMKLSingleThreadedMatMulF32SymbolName;
20 changes: 16 additions & 4 deletions tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -854,6 +854,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();

// TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
// different data layouts.
if (PotentiallyImplementedAsEigenConvolution(*convolution)) {
const Shape& lhs_shape = lhs->shape();
const Shape& rhs_shape = rhs->shape();
@@ -942,16 +944,26 @@
int64_type, int64_type, int64_type, int64_type, int64_type,
int64_type, int64_type, int64_type, int64_type},
/*isVarArg=*/false);
bool multi_threaded_eigen =
bool multi_threaded =
hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
bool use_mkl_dnn =
hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();

// TODO(b/78639006) Single-threaded MKL conv2d is not implemented due to a
// potential race condition when setting omp_num_threads.
const char* fn_name =
primitive_type == F16
? (multi_threaded_eigen
? (multi_threaded
? runtime::kEigenConvF16SymbolName
: runtime::kEigenSingleThreadedConvF16SymbolName)
: (multi_threaded_eigen
? runtime::kEigenConvF32SymbolName
: (multi_threaded
? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
: runtime::kEigenConvF32SymbolName)
: runtime::kEigenSingleThreadedConvF32SymbolName);
if (!multi_threaded && use_mkl_dnn) {
LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
"conv2d function.";
}
llvm::Function* conv_func = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(fn_name, conv_type));
conv_func->setCallingConv(llvm::CallingConv::C);
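The hunk above selects a runtime conv2d symbol from the primitive type and two flags, and falls back to Eigen with a warning when MKL-DNN is requested without multi-threading. A sketch of that decision table follows; only the `"__xla_cpu_runtime_MKLConvF32"` string is confirmed by this diff (cpu_runtime.cc), while the Eigen symbol strings are assumptions inferred from the `kEigen*SymbolName` naming convention:

```python
import logging

def select_conv_symbol(primitive_type, multi_threaded, use_mkl_dnn):
    """Mirrors the symbol choice in IrEmitter::HandleConvolution (sketch).

    Only "__xla_cpu_runtime_MKLConvF32" appears verbatim in this PR; the
    Eigen symbol strings below are assumed from the naming convention.
    """
    if primitive_type == "F16":
        # The MKL-DNN path exists only for F32 in this change.
        return ("__xla_cpu_runtime_EigenConvF16" if multi_threaded
                else "__xla_cpu_runtime_EigenSingleThreadedConvF16")
    if multi_threaded:
        return ("__xla_cpu_runtime_MKLConvF32" if use_mkl_dnn
                else "__xla_cpu_runtime_EigenConvF32")
    if use_mkl_dnn:
        # Single-threaded MKL conv2d is unimplemented (b/78639006), so the
        # emitter falls back to Eigen and logs a warning.
        logging.warning("Using Eigen instead of MKL-DNN for "
                        "single-threaded conv2d function.")
    return "__xla_cpu_runtime_EigenSingleThreadedConvF32"
```

Note that `use_mkl_dnn` only changes the outcome on the multi-threaded F32 path; every other combination resolves to an Eigen kernel.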