Branch 203518000 #20601

Merged: 40 commits, Jul 6, 2018
Changes from all commits
Commits
40 commits
a1810b4
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener Jul 3, 2018
1737b82
Add distributed model GetStepSequenceAsync implementation to
tensorflower-gardener Jul 3, 2018
bd94b10
Move InterpreterBuilder's destructor to the .cc file.
tensorflower-gardener Jul 3, 2018
fb58f1d
[tf.data / Bigtable] Log mutation_status errors.
saeta Jul 3, 2018
fc61f06
Make sure that nodes connected by reference/resource edges are c…
akshayka Jul 3, 2018
774a01f
Internal change
mkuperst Jul 4, 2018
ae5cdb7
Automated g4 rollback of changelist 203171335
tensorflower-gardener Jul 4, 2018
58a61c0
tf.keras sync 2.2.0
Jul 4, 2018
1e7dde8
I think this should fix the following failure we see in the MacOS build:
annarev Jul 4, 2018
8779f76
[TF:XLA] Split literal_util into {literal, literal_util}.
kayzhu Jul 4, 2018
6578bad
[Java]: Release 1.9.0-rc2
asimshankar Jul 4, 2018
3661a14
Exclude importing tensorflow.contrib.cloud on Windows (internal)
tensorflower-gardener Jul 4, 2018
cead016
Minor performance improvement for ComputeReachability
tensorflower-gardener Jul 4, 2018
51061e1
Profile SequentialThunks if they represent one HloInstruction.
akuegel Jul 4, 2018
014422b
Begin introducing NUMA support for CPU threads and memory
tensorflower-gardener Jul 4, 2018
8737cd1
[XLA] Make ShapeUtil::PermuteDimensions() properly permute layouts.
Jul 4, 2018
548601a
Automated g4 rollback of changelist 203283392
tensorflower-gardener Jul 4, 2018
eb8a110
[XLA:GPU] Enhance the tiled 0-2-1 transpose algorithm to handle fusion.
bixia1 Jul 4, 2018
9103612
[tf.data / Bigtable] Refactor BUILD file
saeta Jul 5, 2018
2c130a0
Improve the latency of small `tf.concat()` operations.
mrry Jul 5, 2018
6042e76
Add a utility function so broadcasting can be done over arbitrary Sha…
tensorflower-gardener Jul 5, 2018
f3e8ac9
Remove GatherExpander pass from the CPU backend.
akuegel Jul 5, 2018
c0e54af
Fix zero size hlo elimination to not replace empty constants
tensorflower-gardener Jul 5, 2018
baf41a3
[tf.data / Bigtable]: Augment test client's SampleKeys() implementation.
saeta Jul 5, 2018
14ef3b6
[tf.data / Bigtable] Support sampling row keys.
saeta Jul 5, 2018
4d080fb
[tf.data / Bigtable] Clean up BUILD files
saeta Jul 5, 2018
9879867
Format .bzl files with buildifier
tensorflower-gardener Jul 6, 2018
37eed82
Removing unused options from batch_ops.batch_function().
vinuraja Jul 6, 2018
e9de897
Add no_pip tag to revnet_test
tensorflower-gardener Jul 6, 2018
626760c
fix link
MarkDaoust Jul 6, 2018
8f5e2a7
Fix HLO domains mis-handling from HLO passes.
tensorflower-gardener Jul 6, 2018
ed494f1
Rolling back tensorflow .bzl file changes
rohan100jain Jul 6, 2018
8a9f023
AutoGraph notebook for workshop
tensorflower-gardener Jul 6, 2018
3acc7e3
Creating the new FeatureColumns that would be made public. This CL c…
rohan100jain Jul 6, 2018
7e9a018
Fix the update_version.py bug that is breaking nightlies.
Jul 6, 2018
0e616c1
Update pin to bazel toolchains repo
tensorflower-gardener Jul 6, 2018
ff06d67
Support extra identity operation added if batch_norms updates are for…
Jul 6, 2018
cbeaf29
Faster Shuffle
tensorflower-gardener Jul 6, 2018
602f8fa
Merge changes from github.
yifeif Jul 6, 2018
50e72c4
Merge commit for internal changes
yifeif Jul 6, 2018
2 changes: 1 addition & 1 deletion tensorflow/compiler/tests/BUILD
@@ -924,7 +924,7 @@ tf_xla_py_test(

tf_xla_py_test(
name = "sort_ops_test",
size = "small",
size = "medium",
srcs = ["sort_ops_test.py"],
# Times out in fastbuild mode.
tags = ["optonly"],
4 changes: 2 additions & 2 deletions tensorflow/compiler/tests/random_ops_test.py
@@ -140,10 +140,10 @@ def probit(x, sess=sess):
def testShuffle1d(self):
with self.test_session() as sess:
with self.test_scope():
x = math_ops.range(20)
x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
result = sess.run(shuffle)
expected = range(20)
expected = range(1 << 16)
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
self.assertAllEqual(set(result), set(expected))
8 changes: 5 additions & 3 deletions tensorflow/compiler/tf2xla/BUILD
@@ -162,7 +162,7 @@ cc_library(
":sharding_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -202,7 +202,7 @@ cc_library(
],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu_internal",
@@ -285,6 +285,7 @@ tf_cc_test(
deps = [
":tf2xla",
":tf2xla_proto",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
@@ -327,7 +328,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
@@ -364,6 +365,7 @@ tf_cc_test(
],
deps = [
":common",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/core:framework",
"//tensorflow/core:test",
6 changes: 4 additions & 2 deletions tensorflow/compiler/tf2xla/kernels/BUILD
@@ -114,6 +114,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -159,7 +160,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -175,7 +176,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -210,6 +211,7 @@ tf_kernel_library(
":index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client/lib:arithmetic",
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
8 changes: 4 additions & 4 deletions tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
&b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
&b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
&b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
&b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim)));
xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
}

xla::Shape xla_shape =
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
159 changes: 112 additions & 47 deletions tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -74,56 +74,121 @@ class RandomShuffleOp : public XlaOpKernel {
for (tensorflow::TensorShapeDim dimension : input_shape) {
num_elements *= dimension.size;
}

if (num_elements <= 1 || n <= 1) {
// No shuffling is required, so copy input directly to output
ctx->SetOutput(0, input);
} else {
// Generate the random swaps for the indices.
auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
auto swaps =
xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
xla::ConstantR0<int32>(builder, n), swaps_shape);

// Generate range(n) as the initial value for the indices to be swapped.
xla::XlaOp indices = xla::Iota(builder, xla::S32, n);

// Swap the indices at i and swaps[i].
auto swap_body_fn = [&](xla::XlaOp i,
gtl::ArraySlice<xla::XlaOp> loop_vars,
xla::XlaBuilder* builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
auto swaps = loop_vars[0];
auto indices = loop_vars[1];
i = xla::Reshape(i, {1});
// temp = indices[i]
auto temp = xla::DynamicSlice(indices, i, {1});
// swap_index = swaps[i]
auto swap_index = xla::DynamicSlice(swaps, i, {1});
// swap_value = indices[swaps[i]]
auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
// indices[i] = indices[swaps[i]]
indices = xla::DynamicUpdateSlice(indices, swap_value, i);
// indices[swaps[i]] = temp
indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
return std::vector<xla::XlaOp>{swaps, indices};
};
// for i in range(n):
auto swap_loop_result =
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
"indices_swap_loop", builder)
.ValueOrDie();
auto swapped_indices = swap_loop_result[1];

// Gather the data using the swapped indices as the shuffled order.
auto indices_tensor_shape = TensorShape({n});
DataType type = ctx->expected_output_dtype(0);
xla::XlaOp gather;
OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
indices_tensor_shape,
/*axis=*/0, /*indices_are_nd=*/false, type,
DT_INT32, builder, &gather));
ctx->SetOutput(0, gather);
return;
}

if (input_shape.dims() == 1) {
// For R1s, shuffle values by sorting instead of the obvious Fisher-Yates
// algorithm. Fisher-Yates is simple to implement and correct, but not
// easily parallelizable. For a sufficiently parallel architecture, it is
// faster to sort many times than Fisher-Yates shuffle once.

// Shuffle values by assigning each value a random key and sorting the
// keys. Keys can collide, causing detectable patterns in the shuffled
// output. Collisions translate into more ascending sub-sequences in the
// shuffled output than would be expected by chance. To avoid collisions,
// the number of possible key values must be sufficiently large.

// How are more than 2^32 keys created? In each loop iteration, the
// algorithm sorts by random keys. Conceptually, the earlier iterations
// are sorting on the lower-order bits of larger keys that are never
// actually assembled.

// The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
// the number of possible keys and n is the number of values. If d = n^2,
// then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
// as n goes to infinity is zero.

// This implementation ensures that the key-space is greater than or equal
// to the cube of the number of values. The risk of collisions can be
// further reduced by increasing Exponent at the expense of
// performance.

// For Exponent = 2, the expected number of collisions per shuffle is
// maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
// about 1/2.

// For Exponent = 3, the expected number of collisions per shuffle is
// maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
// about 1/3255.

// For Exponent = 4, the expected number of collisions per shuffle is
// maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
// about 1/132622.
constexpr int Exponent = 3;
const int rounds = static_cast<int>(
std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max)));

const xla::Shape key_shape =
xla::ShapeUtil::MakeShape(xla::U32, {num_elements});
xla::XlaOp zero = xla::ConstantR0(builder, 0U);

// Unfortunately, xla::RngUniform gives values in the half open interval
// rather than the closed interval, so instead of 2^32 possible keys there
// are only 2^32 - 1 (kuint32max).
xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max);

xla::XlaOp curr = input;
for (int i = 0; i < rounds; ++i) {
xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape);
xla::XlaOp sorted = xla::Sort(keys, curr);
curr = xla::GetTupleElement(sorted, 1);
}

ctx->SetOutput(0, curr);
return;
}

// The Fisher-Yates algorithm.

// Generate the random swaps for the indices.
auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
auto swaps =
xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
xla::ConstantR0<int32>(builder, n), swaps_shape);

// Generate range(n) as the initial value for the indices to be swapped.
xla::XlaOp indices = xla::Iota(builder, xla::S32, n);

// Swap the indices at i and swaps[i].
auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
xla::XlaBuilder* builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
auto swaps = loop_vars[0];
auto indices = loop_vars[1];
i = xla::Reshape(i, {1});
// temp = indices[i]
auto temp = xla::DynamicSlice(indices, i, {1});
// swap_index = swaps[i]
auto swap_index = xla::DynamicSlice(swaps, i, {1});
// swap_value = indices[swaps[i]]
auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
// indices[i] = indices[swaps[i]]
indices = xla::DynamicUpdateSlice(indices, swap_value, i);
// indices[swaps[i]] = temp
indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
return std::vector<xla::XlaOp>{swaps, indices};
};
// for i in range(n):
auto swap_loop_result =
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
"indices_swap_loop", builder)
.ValueOrDie();
auto swapped_indices = swap_loop_result[1];

// Gather the data using the swapped indices as the shuffled order.
auto indices_tensor_shape = TensorShape({n});
DataType type = ctx->expected_output_dtype(0);
xla::XlaOp gather;
OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
indices_tensor_shape,
/*axis=*/0, /*indices_are_nd=*/false, type,
DT_INT32, builder, &gather));
ctx->SetOutput(0, gather);
}

private:
@@ -220,5 +285,5 @@ REGISTER_XLA_OP(Name("TruncatedNormal")
.TypeConstraint("dtype", DT_FLOAT),
TruncatedNormalOp);

} // anonymous namespace
} // namespace
} // namespace tensorflow
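
A quick numerical check of the collision figures quoted in the comments of the new rank-1 shuffle path above: the sketch below evaluates E[collisions] = n - d + d(1 - 1/d)^n at the worst-case n stated for Exponent = 2, 3, and 4. This is illustrative only and not part of the change; the helper name `expected_collisions` is mine, and I assume (following the "lower-order bits" remark in the comments) that `rounds` sorting passes behave like a single key space of size kuint32max ** rounds.

```python
import math

KUINT32MAX = 2**32 - 1  # distinct keys per sorting round (RngUniform is half-open)


def expected_collisions(n, d):
    # E = n - d + d*(1 - 1/d)**n, rewritten via expm1/log1p to avoid
    # catastrophic cancellation when d is much larger than the result.
    return n + d * math.expm1(n * math.log1p(-1.0 / d))


for exponent in (2, 3, 4):
    # Worst-case n quoted in the kernel comments: floor(kuint32max ** (1/exponent)).
    n = int(KUINT32MAX ** (1.0 / exponent))
    # Number of sorting rounds, as computed in RandomShuffleOp.
    rounds = math.ceil(exponent * math.log(n) / math.log(KUINT32MAX))
    # Assumption: `rounds` passes act like one key space of size KUINT32MAX**rounds.
    d_effective = KUINT32MAX ** rounds
    e = expected_collisions(n, d_effective)
    print(f"Exponent={exponent}: n={n}, rounds={rounds}, "
          f"E[collisions] ~ {e:.3g} (~1/{1 / e:,.0f})")
```

Run as-is, this should print values near 1/2, 1/3255, and roughly 1/132600, in line with the figures quoted in the comments.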
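
For readers who want the two shuffle strategies side by side without the XLA plumbing, here is a minimal NumPy sketch, under the assumption that plain arrays stand in for the XLA ops used by the kernel (RngUniform, Sort, DynamicSlice/DynamicUpdateSlice, Gather). The function names are illustrative, not taken from the kernel.

```python
import numpy as np

KUINT32MAX = 2**32 - 1


def shuffle_by_sorting(values, exponent=3, rng=None):
    # Rank-1 path: repeatedly assign each element a fresh random uint32 key
    # and sort by the keys, mirroring the new rank-1 branch of RandomShuffleOp.
    rng = rng if rng is not None else np.random.default_rng()
    values = np.asarray(values)
    n = values.shape[0]
    rounds = int(np.ceil(exponent * np.log(n) / np.log(KUINT32MAX)))
    curr = values
    for _ in range(rounds):
        keys = rng.integers(0, KUINT32MAX, size=n, dtype=np.uint32)  # RngUniform, half-open
        curr = curr[np.argsort(keys, kind="stable")]                 # Sort(keys, curr)
    return curr


def shuffle_by_swaps(values, rng=None):
    # Higher-rank path: draw swaps[i] uniformly in [0, n), swap indices[i]
    # with indices[swaps[i]] for each i, then gather rows in that order.
    rng = rng if rng is not None else np.random.default_rng()
    values = np.asarray(values)
    n = values.shape[0]
    swaps = rng.integers(0, n, size=n)   # RngUniform(0, n)
    indices = np.arange(n)               # Iota
    for i in range(n):                   # XlaForEachIndex loop body
        j = swaps[i]
        indices[i], indices[j] = indices[j], indices[i]
    return values[indices]               # XlaGather along axis 0


if __name__ == "__main__":
    print(shuffle_by_sorting(np.arange(10)))
    print(shuffle_by_swaps(np.arange(20).reshape(10, 2)))
```

The sort-based path shuffles a rank-1 array directly, while the swap path permutes an index vector and then gathers along axis 0, which is why the kernel can use it for higher-rank inputs.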
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"