Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 195443326 #19090

Merged
merged 34 commits into from
May 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit. Hold shift + click to select a range.
9b43bd6
Documentation for tf.contrib.eager.py_func
akshayka May 3, 2018
86d3435
Fix a typo.
May 3, 2018
518dfea
[XLA:CPU] Remove dead function + DCHECK, NFC
May 3, 2018
a4a9e37
Optimize idempotent ops, e.g., Snapshot(Snapshot(x)) => Snapshot(x)
tensorflower-gardener May 3, 2018
05425f2
[TF:XLA] clean up interface to xla::VerifyHloModule
nickdesaulniers May 3, 2018
316e0ba
Add separate get_read and get_updated helpers that work on code excep…
tensorflower-gardener May 3, 2018
28d43e5
Add tflite listed models with accuracy and performance numbers.
zhixianyan May 3, 2018
4f4b15c
Fix bug that disabled loop invariant node motion optimizer. Disable i…
tensorflower-gardener May 3, 2018
f25dd60
Use tuple instead of list to reduce the chance of it being picked by …
May 3, 2018
549d63a
Do not hoist nodes that modify frame info.
tensorflower-gardener May 3, 2018
200f4a2
Fix oom_test so that it doesn't try to allocate a giant host buffer when
May 3, 2018
04d5adb
Fix bugs in LogicalBuffer::ToString and BufferValue::ToProto: these f…
fdxmw May 3, 2018
c9a9280
Adjust worker shutdown hooks for TPUs
rjpower May 3, 2018
4a74a50
Fix flaky test time-outs for dnn_test and rnn_test.
tensorflower-gardener May 3, 2018
213a98d
[XLA] Redesign: deprecate ComputationBuilder.
tensorflower-gardener May 4, 2018
fc7b593
Clear the stat cache of the target when renaming the file.
rxsang May 4, 2018
fa7b5a9
[XLA] Make LocalShapedBuffer::FromLiteral fallible by passing StatusO…
cdleary May 4, 2018
0abbff6
[XLA] Redesign: cleanup client_library_test_base.
tensorflower-gardener May 4, 2018
8ec11ae
Add the MultiWorkerMirroredStrategy
May 4, 2018
da0dcb2
Internal change.
tensorflower-gardener May 4, 2018
0bb55f0
Automated g4 rollback of changelist 194829761
hyouklee May 4, 2018
1284047
* Don't copy on-host and on-device shapes locally.
tensorflower-gardener May 4, 2018
73a1908
Prefer non-nested GradientTape.gradient call when only one source is …
tomhennigan May 4, 2018
c183c56
Fixing some linter errors in TF documentation (Github > GitHub, the t…
tensorflower-gardener May 4, 2018
7a7bbc3
Do not crash on ROOT outfeed operations.
tensorflower-gardener May 4, 2018
34bb664
Fix HloSharding::GetSubSharding to return correct array shardings
tensorflower-gardener May 4, 2018
2d6170f
[XLA] Remove template keyword on non-template methods.
d0k May 4, 2018
3db0e54
Change RecvTensor RPC implementation to use DeviceContext::CopyDevice…
hawkinsp May 4, 2018
a5f44b3
Implement neg op
alanchiao May 4, 2018
47f1bd9
TFTS: Make it easier to swap in different autoregressive models.
allenlavoie May 4, 2018
e32c42a
Improve broadcast add implementation.
tensorflower-gardener May 4, 2018
09d0e30
Internal clean up: change scanf to use int64_t instead of int64
tensorflower-gardener May 4, 2018
01a70dc
Add operations before Identity operations should be quantized.
May 4, 2018
9a48796
Merge commit for internal changes
caisq May 4, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 25 additions & 0 deletions tensorflow/compiler/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,31 @@ tf_cc_test(
],
)

# C++ test target for xla_launch_util; srcs contains microbenchmarks for
# internal::ExtractSubShapedBuffer (see xla_launch_util_test.cc).
tf_cc_test(
name = "xla_launch_util_test",
size = "small",
srcs = ["xla_launch_util_test.cc"],
deps = [
":common",
":xla_compilation_cache",
":xla_launch_util",
":xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/kernels:variable_ops",
],
)

# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
Expand Down
22 changes: 13 additions & 9 deletions tensorflow/compiler/jit/xla_launch_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) {
return Status::OK();
}

namespace {
namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
ScopedShapedBuffer ExtractSubShapedBuffer(
ShapedBuffer* shaped_buffer, int index,
xla::DeviceMemoryAllocator* allocator) {
xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape(
const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_host_shape(), index);
xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape(
const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_device_shape(), index);

ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
Expand All @@ -98,14 +98,18 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
sub_shape_tree.CopySubtreeFrom(shape_tree,
/*source_base_index=*/{index},
/*target_base_index=*/{});
for (auto& index_to_buffer : shape_tree) {
if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) {
index_to_buffer.second = se::DeviceMemoryBase(nullptr, 0);
}
}
shape_tree.ForEachMutableElement(
[index](const xla::ShapeIndex& shape_index,
tensorflow::se::DeviceMemoryBase* data) {
// shape_index is empty for the root node. Ignore that.
if (!shape_index.empty() && shape_index[0] == index) {
*data = tensorflow::se::DeviceMemoryBase(nullptr, 0);
}
});
return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator);
}
} // namespace
} // namespace internal
using internal::ExtractSubShapedBuffer;

XlaComputationLaunchContext::XlaComputationLaunchContext(
int64 num_resource_args, xla::LocalClient* client,
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/compiler/jit/xla_launch_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ class XlaTensorBuffer : public TensorBuffer {
Allocator* allocator_;
};

// Exposed in this header file for microbenchmarking purposes, but this is an
// internal implementation detail.
namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
xla::ScopedShapedBuffer ExtractSubShapedBuffer(
xla::ShapedBuffer* shaped_buffer, int index,
xla::DeviceMemoryAllocator* allocator);
} // namespace internal

} // namespace tensorflow

#endif
64 changes: 64 additions & 0 deletions tensorflow/compiler/jit/xla_launch_util_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Contains microbenchmarks for performance critical functions in
// xla_launch_util.cc.

#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs
// (cardinality of each non-leaf node's children).
void BM_ExtractSubBuffer(int iters, int depth, int fan_out) {
tensorflow::testing::StopTiming();
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128});
for (int i = 0; i < depth; ++i) {
std::vector<xla::Shape> shapes(fan_out, shape);
shape = xla::ShapeUtil::MakeTupleShape(shapes);
}
xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr,
/*device_ordinal=*/0);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
// Extract a buffer from approximately the middle of the first level of the
// tree.
tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
/*index=*/fan_out / 2,
/*allocator=*/nullptr)
.release();
}
}

// Register the benchmark over (depth, fan_out) argument pairs: depth 1 with
// fan-outs from 4 to 512, and depth 2 with fan-outs from 4 to 128.
BENCHMARK(BM_ExtractSubBuffer)
->ArgPair(1, 4)
->ArgPair(1, 8)
->ArgPair(1, 32)
->ArgPair(1, 64)
->ArgPair(1, 128)
->ArgPair(1, 256)
->ArgPair(1, 512)
->ArgPair(2, 4)
->ArgPair(2, 8)
->ArgPair(2, 32)
->ArgPair(2, 64)
->ArgPair(2, 128);

// Entry point: initializes googletest, runs the registered benchmarks, then
// runs the gtest tests (none are defined in this file, so RUN_ALL_TESTS
// only reports the exit status).
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
tensorflow::testing::RunBenchmarks();
return RUN_ALL_TESTS();
}
29 changes: 22 additions & 7 deletions tensorflow/compiler/tests/oom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest


Expand All @@ -42,20 +44,33 @@ def testOutputOutOfMemory(self):
"""

def test_loop():
size = 2e8
size = int(2e8)
while True:
with self.test_session():
# Force the compiled code to not be constant by feeding in an addend.
p = array_ops.placeholder(dtypes.float32, shape=[])
# Force the compiled code to not be constant by feeding in a
# parameter.
p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1])
with self.test_scope():
# Create a large R1 tensor.
c = array_ops.zeros([size, 1]) + p
# Create a computation that produces a large R1 tensor as an
# intermediate result. Reduce it down so that if this file was
# compiled without --config=cuda, we don't force a D2H copy of a
# large tensor and potentially OOM the host.
#
# This is a bit tricky because XLA:GPU doesn't currently support RNG
# ops. Here we rely on the fact that XLA doesn't do algebraic
# simplifications on conv(<ones>, <filter>).
c = math_ops.reduce_sum(
nn_ops.convolution(
array_ops.ones([1, size, 1]),
p,
padding='SAME',
data_format='NWC'))

c.eval(feed_dict={p: 1.0})
c.eval(feed_dict={p: [[[1.0]], [[2.0]]]})
size *= 2

self.assertRaises(errors.ResourceExhaustedError, test_loop)


if __name__ == "__main__":
if __name__ == '__main__':
googletest.main()
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/client/computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace xla {
// Wraps a ComputationHandle protobuf with a lifetime. Computation is
// movable and not copyable to capture the same kind of unique
// ownership that std::unique_ptr represents.
//
// TODO(b/74197823): Deprecated. Use XlaComputation instead.
class Computation {
public:
// Creates a null Computation.
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/client/computation_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ namespace xla {
// deferred from being handled until Build() is called.
//
// Thread-compatible.
//
// TODO(b/74197823): Deprecated. Use XlaBuilder instead.
class ComputationBuilder {
public:
// client: client in which to build the computation.
Expand Down
5 changes: 1 addition & 4 deletions tensorflow/compiler/xla/client/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
Expand All @@ -43,9 +41,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
Expand Down
90 changes: 2 additions & 88 deletions tensorflow/compiler/xla/client/lib/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ limitations under the License.

#include <string>

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
Expand All @@ -27,28 +28,6 @@ limitations under the License.
namespace xla {
namespace {

using InstructionGenerator =
ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&,
const ComputationDataHandle&);

Computation CreateScalarComputation(const string& name, PrimitiveType type,
ComputationBuilder* builder,
InstructionGenerator generator) {
std::unique_ptr<ComputationBuilder> b;
if (type == PRED) {
b = builder->CreateSubBuilder(name);
} else {
b = builder->CreateSubBuilder(
tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type)));
}

const Shape scalar = ShapeUtil::MakeShape(type, {});
auto lhs = b->Parameter(0, scalar, "lhs");
auto rhs = b->Parameter(1, scalar, "rhs");
generator(b.get(), lhs, rhs);
return b->BuildAndNoteError();
}

using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);

XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
Expand All @@ -71,71 +50,6 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,

} // namespace

Computation CreateScalarAddComputation(PrimitiveType type,
ComputationBuilder* builder) {
return CreateScalarComputation(
"add", type, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); });
}

Computation CreateScalarMultiplyComputation(PrimitiveType type,
ComputationBuilder* builder) {
return CreateScalarComputation(
"mul", type, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); });
}

Computation CreateScalarGeComputation(PrimitiveType type,
ComputationBuilder* builder) {
return CreateScalarComputation(
"ge", type, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); });
}

Computation CreateScalarMaxComputation(PrimitiveType type,
ComputationBuilder* builder) {
return CreateScalarComputation(
"max", type, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); });
}

Computation CreateScalarMinComputation(PrimitiveType type,
ComputationBuilder* builder) {
return CreateScalarComputation(
"min", type, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); });
}

Computation CreateScalarAndComputation(ComputationBuilder* builder) {
return CreateScalarComputation(
"and", PRED, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->And(lhs, rhs); });
}

Computation CreateScalarOrComputation(ComputationBuilder* builder) {
return CreateScalarComputation(
"or", PRED, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); });
}

StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates,
ComputationBuilder* builder) {
auto f = builder->ConstantR0<bool>(false);
Computation logical_or = CreateScalarOrComputation(builder);
TF_ASSIGN_OR_RETURN(std::unique_ptr<Shape> predicates_shape,
builder->GetShape(predicates));
std::vector<int64> all_dimensions(ShapeUtil::Rank(*predicates_shape));
std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
return builder->Reduce(predicates, f, logical_or, all_dimensions);
}

XlaComputation CreateScalarAddComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation(
Expand Down