Branch 195061425 #19025

Merged
merged 85 commits on May 2, 2018
Changes from all commits
85 commits
a5a51ad
Adding a depthwise convolution kernel op (with label 'cudnn_grouped_c…
tensorflower-gardener Apr 30, 2018
aa2405e
Fixes to tape gradient for providing outputs and having multiple targ…
alextp Apr 30, 2018
1872f29
Clarify return type for defun as zero or more `tf.Tensor`s.
tomhennigan Apr 30, 2018
a3ae05d
Remove manifest_merger that is being removed from Bazel 0.13.0.
tensorflower-gardener Apr 30, 2018
09e529f
Prepare nodes that will be allocated using ScopedAllocator.
dubey Apr 30, 2018
9f2728b
Switch install get_started link
MarkDaoust Apr 30, 2018
d6da4aa
Add snippet illustrating discretized logistic mixture for WaveNet.
dustinvtran Apr 30, 2018
113cc4d
Add --keep_going flag to bazel query in pip_smoke_test to bypass baze…
yifeif Apr 30, 2018
bdaa70c
Miscellaneous code clean-up
tensorflower-gardener Apr 30, 2018
c3e9ca7
Fix bugs in AssignOp:
tensorflower-gardener Apr 30, 2018
8d5e87b
Use the default rewriter config instead of a custom one
benoitsteiner Apr 30, 2018
9d79acc
[TF:XLA] Bump open source llvm revision to r331173
Apr 30, 2018
8609ef4
When a mirrored variable is fetched in cross-tower mode, fetch its pr…
isaprykin Apr 30, 2018
6e9d8ab
Fix typos in tf.GradientTape documentation.
tomhennigan Apr 30, 2018
5cdcb47
Fix device assignment in xla/service/service.cc to build the assignme…
hyouklee Apr 30, 2018
1986f00
Cleanup handling of non-Tensor valued event_ndims in Bijector.
jvdillon Apr 30, 2018
57b7c7b
Fix a bug in profiler.
shashishekhar Apr 30, 2018
b7cb5fa
Extend SDCAOptimizer functionality to prune negative indices (the def…
tensorflower-gardener Apr 30, 2018
0ed712e
[tf.data] Adding support for `tf.SparseTensor` into `tf.contrib.data.…
jsimsa Apr 30, 2018
4041ae0
Push down const inputs into the function of specialized functions.
tensorflower-gardener Apr 30, 2018
fac11a7
Removing an obsolete TODO
petrosmol Apr 30, 2018
c0f0720
[TF:XLA] Fix some unexpected memory leak in hlo_graph_dumper_test.
dimvar Apr 30, 2018
ab02bce
Do not cast int64 to int32 in keras embedding lookups.
alextp Apr 30, 2018
d2a4227
Add XLA logo to its documentation page
eliben Apr 30, 2018
b8197b2
Implement unary chain hoisting optimization for Concat, Split, and Sp…
tensorflower-gardener Apr 30, 2018
9c961e8
Enhancements to GRAPHVIZ_DOT output:
tensorflower-gardener Apr 30, 2018
286d61b
Do not allocate memory for literal as it will be allocated later.
tensorflower-gardener Apr 30, 2018
30fcdec
Improve error message for pip_smoke_test.
yifeif Apr 30, 2018
1834361
[XLA] Change the TF2XLA bridge to perform F16 reduction using F32 dat…
bixia1 Apr 30, 2018
7141ed5
Add MultiNodeDataset and MultiNodeIterator which are intended to work…
Apr 30, 2018
1ff23a3
Small fix to prevent a crash if the delegate has not implemented Free…
tensorflower-gardener Apr 30, 2018
64bb1de
Faster reduce_logsoftmax (specially in eager) and bugfixes in broadca…
alextp May 1, 2018
b7978d4
Internal cleanup.
tensorflower-gardener May 1, 2018
c89a1d9
[tf.data] Adding an experimental `group_by_reducer` transformation wh…
jsimsa May 1, 2018
45bafe9
[XLA] Redesign: migrate tensorflow/compiler/tf2xla, tensorflow/compil…
tensorflower-gardener May 1, 2018
c8a0f6e
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener May 1, 2018
4072142
Remove proto header import from core/framework/tracking_allocator.h
yifeif May 1, 2018
85d30bf
Internal change.
tensorflower-gardener May 1, 2018
88821b0
[XLA] Redesign: dump HloSnapshot at the point where it used to dump t…
tensorflower-gardener May 1, 2018
79ccb99
Move LinearOperatorKronecker and LinearOperatorBlockDiag to core.
tensorflower-gardener May 1, 2018
37191e9
Update ops-related pbtxt files.
tensorflower-gardener May 1, 2018
95b3643
[XLA:CPU] Open source some tests.
May 1, 2018
a4343eb
Protocol buffer classes now list their fields in dir(cls)
tensorflower-gardener May 1, 2018
10337c9
Preventing RemoveTrivialBinary from removing broadcasts.
tensorflower-gardener May 1, 2018
a82e0e7
Fix crash in HloGraphDumper where it crashes on tuple shaped constants
tensorflower-gardener May 1, 2018
bb82203
Automated g4 rollback of changelist 194917415
gunan May 1, 2018
9477a96
eager: Update sample notebooks with API changes in the last few relea…
asimshankar May 1, 2018
07c5885
Boosted trees: support indicator column.
tensorflower-gardener May 1, 2018
8e918c3
Improve shape inference for tf.contrib.signal.frame.
rryan May 1, 2018
5c18dc6
Simplified shape inference.
benoitsteiner May 1, 2018
59677dc
Add device_util.resolve method which merges with current device as well.
May 1, 2018
87ebe11
Implements matrix multiply-accumulate for linear no-offset (aka symme…
tensorflower-gardener May 1, 2018
ee236bd
Add a pointer from Device to its owning DeviceMgr.
hawkinsp May 1, 2018
57207f2
Add utility to auto shard a dataset pipeline in the appropriate place…
guptapriya May 1, 2018
1a50cd4
Open source infeed test
May 1, 2018
9149558
Collective Ops Part 5
tensorflower-gardener May 1, 2018
3b7f22f
Relax the stringent memory allocator constraints in AssignOp if a Gra…
tensorflower-gardener May 1, 2018
594f970
Update schema.
shashishekhar May 1, 2018
75c1896
Update community/swift
MarkDaoust May 1, 2018
7cbbd35
Enable checkpointless eval and predict for tf.estimator.
tensorflower-gardener May 1, 2018
46bf1e8
Make tower-local variables non-trainable even with the default
tensorflower-gardener May 1, 2018
325d0ef
Merge changes from github.
May 1, 2018
fdcdf75
Fix bug in peak buffer accounting in buffer assignment.
meheffernan May 1, 2018
415ea73
Making ids unique in nn.embedding_lookup_sparse. This helps to reduce…
tensorflower-gardener May 1, 2018
707b0c9
Minor JNI performance improvement.
tensorflower-gardener May 1, 2018
6f10fb5
Fixed some outdated comments
benoitsteiner May 1, 2018
33978e8
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener May 1, 2018
b25e6fe
Implementation of the fully-connected TFLite Op using the symmetric q…
tensorflower-gardener May 1, 2018
210abeb
[TF:XLA] Separate on-host and on-device shape and layout in HloModule.
tensorflower-gardener May 1, 2018
fb8f040
Allow `warm_start_from` argument to be a SavedModel path.
tensorflower-gardener May 1, 2018
f5dbc1e
Check for overflow in shape calculation.
tensorflower-gardener May 1, 2018
1aa7aaa
Adds logistic_regression_head.
tensorflower-gardener May 1, 2018
2b54774
Avoid making a copy of the graph needlessly
benoitsteiner May 1, 2018
833803d
Fix wrongly ordered lines
May 1, 2018
448b2ca
Sharding for tensorflow/contrib/timeseries/python/timeseries/state_sp…
allenlavoie May 1, 2018
f453b62
test fix
ispirmustafa May 2, 2018
62356ad
Re-apply CL 194140820, which reverts #18251 (convolution change).
May 2, 2018
92939c5
Internal change.
tensorflower-gardener May 2, 2018
c8ae9e8
Internal change
tensorflower-gardener May 2, 2018
69b2c63
[XLA:CPU] Re-use the same llvm::GlobalVariable for identical literals
May 2, 2018
c0f1080
Make the CRF work when sequence_lengths are int32.
May 2, 2018
b50f632
Minor refactor: establish some operator naming conventions and apply …
May 2, 2018
7715b7b
Add missing colocated element in test in buffer_assignment_test.
meheffernan May 2, 2018
5e1448f
BUGFIX: Convert inputs and list of gradients into tuple if they are n…
tensorflower-gardener May 2, 2018
f6ce3e2
Merge commit for internal changes
caisq May 2, 2018
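One commit above ("Add snippet illustrating discretized logistic mixture for WaveNet.", dustinvtran, Apr 30) references a snippet that is not shown in this diff view. As a rough standalone sketch of what such a distribution computes — assuming PixelCNN++-style binning and ignoring the edge-bin corrections a production implementation would include; all names here are illustrative, not from the commit:

```python
import numpy as np

def discretized_logistic_mixture_logpmf(x, locs, scales, weights,
                                        bin_size=1 / 255.):
    """Log-probability of x under a mixture of discretized logistics.

    Each component integrates a logistic density over the bin
    [x - bin_size/2, x + bin_size/2], as used for quantized audio/image
    values. Edge bins are not special-cased in this sketch.
    """
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    # CDF difference across the bin, per mixture component.
    cdf_plus = sigmoid((x + bin_size / 2 - locs) / scales)
    cdf_minus = sigmoid((x - bin_size / 2 - locs) / scales)
    probs = np.maximum(cdf_plus - cdf_minus, 1e-12)
    return np.log(np.sum(weights * probs))

# Example: 3-component mixture evaluated at one sample value.
print(discretized_logistic_mixture_logpmf(
    0.2,
    locs=np.array([-0.5, 0.0, 0.4]),
    scales=np.array([0.1, 0.2, 0.1]),
    weights=np.array([0.2, 0.5, 0.3])))
```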
65 changes: 27 additions & 38 deletions tensorflow/c/eager/tape.h
@@ -380,49 +380,39 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<Gradient*> output_gradients,
const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
const gtl::FlatMap<int64, int64>& tensor_usage_counts,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
if (!output_gradients.empty() && output_gradients[i] != nullptr) {
// TODO(apassos) figure out how to print debugging information here.
return errors::InvalidArgument(
"A gradient was provided for a tensor which is used as part of the "
"computation.");
}
} else {
if (output_gradients.empty() || output_gradients[i] == nullptr) {
auto tensor_it = tensor_tape.find(id);
if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
auto op_it = op_tape.find(tensor_it->second);
if (op_it == op_tape.end()) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"failed to find operation producing a tensor");
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
if (op_it->second.output_tensor_info[j].id == id) {
found = true;
(*result)[id].push_back(
vspace.Ones(op_it->second.output_tensor_info[j].shape,
op_it->second.output_tensor_info[j].dtype));
break;
}
}
if (!found) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"none of operations outputs match expected tensor");
if (output_gradients.empty() || output_gradients[i] == nullptr) {
auto tensor_it = tensor_tape.find(id);
if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
auto op_it = op_tape.find(tensor_it->second);
if (op_it == op_tape.end()) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"failed to find operation producing a tensor");
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
if (op_it->second.output_tensor_info[j].id == id) {
found = true;
(*result)[id].push_back(
vspace.Ones(op_it->second.output_tensor_info[j].shape,
op_it->second.output_tensor_info[j].dtype));
break;
}
} else {
// No record of the target tensor found on the tape, so no gradient
// needs to be computed from it. Do nothing.
}
if (!found) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"none of operations outputs match expected tensor");
}
} else {
(*result)[id].push_back(output_gradients[i]);
// No record of the target tensor found on the tape, so no gradient
// needs to be computed from it. Do nothing.
}
} else {
(*result)[id].push_back(output_gradients[i]);
}
}
return Status::OK();
@@ -451,8 +441,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
tensor_tape_, state.op_tape,
state.tensor_usage_counts, &gradients);
tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
// Release all backprop functions
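The net effect of the tape.h hunk above is easier to see outside the C++ template machinery: after this change, InitialGradients seeds each target either with the caller-provided output gradient or with a ones tensor matching the op's recorded output, and it no longer consults tensor_usage_counts. A minimal Python sketch of that control flow — tensor_tape, op_tape, and ones_like are stand-ins for the C++ structures, not real APIs:

```python
def initial_gradients(target_ids, output_gradients, tensor_tape, op_tape,
                      ones_like):
    # tensor_tape: tensor id -> producing op id (or -1).
    # op_tape: op id -> record with .output_tensor_info (id, shape, dtype).
    result = {}
    for i, tid in enumerate(target_ids):
        if not output_gradients or output_gradients[i] is None:
            op_id = tensor_tape.get(tid, -1)
            if op_id != -1:
                op = op_tape.get(op_id)
                if op is None:
                    raise RuntimeError(
                        "invalid tape: failed to find producing operation")
                for info in op.output_tensor_info:
                    if info.id == tid:
                        # Seed with ones shaped like the recorded output.
                        result.setdefault(tid, []).append(
                            ones_like(info.shape, info.dtype))
                        break
                else:
                    raise RuntimeError(
                        "invalid tape: no op output matches target tensor")
            # else: tensor never recorded on the tape; nothing to seed.
        else:
            # Caller supplied an explicit output gradient; use it directly.
            result.setdefault(tid, []).append(output_gradients[i])
    return result
```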
12 changes: 6 additions & 6 deletions tensorflow/compiler/aot/compile.cc
@@ -44,7 +44,7 @@ namespace {

// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
const xla::Computation& computation,
const xla::XlaComputation& computation,
const xla::cpu::CpuAotCompilationOptions& aot_opts,
CompileResult* compile_result) {
// Retrieves arg and result layouts from the computation.
@@ -62,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
xla::CompileOnlyClient::AotComputationInstance instance;
xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
@@ -93,14 +93,14 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
xla::CompileOnlyClient* client =
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
xla::Computation computation;
xla::XlaComputation computation;
TF_RETURN_IF_ERROR(
ConvertGraphDefToXla(graph_def, config, client, &computation));
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());
// Serialize the SessionModule deterministically so that all the outputs of
// a tf_library genrule are deterministic.
// Serialize the HloSnapshot deterministically so that all the outputs of a
// tf_library genrule are deterministic.
string proto;
TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto));
TF_RETURN_IF_ERROR(
14 changes: 8 additions & 6 deletions tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -525,14 +525,16 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
"%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
"%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto add_profile_line = HasSubstr(
"%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
"%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto tuple_profile_line = HasSubstr(
"%tuple.2 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
"f32[2,2]{1,0} %add)");
auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
"%dot.0.2, f32[2,2]{1,0} %add.0.5)");
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");

hlo_profile_lines.erase(hlo_profile_lines.begin() + 7,
hlo_profile_lines.end());
64 changes: 64 additions & 0 deletions tensorflow/compiler/tests/reduce_ops_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function

import functools
import itertools
import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
@@ -155,5 +156,68 @@ def testReduceAny(self):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)


class ReduceOpPrecisionTest(XLATestCase):

def _testReduceSum(self,
expected_result,
dtype,
test_inputs,
rtol=1e-3,
atol=1e-4):
"""Tests reduce sum on a list of input arrays.

For each array in test_inputs, check that performing reduce sum on the array
produces a value that is close to the expected result.

Args:
expected_result: the expected result.
dtype: the data type of the reduce sum operation.
test_inputs: a list of input arrays for the reduce sum operation.
rtol: the relative error.
atol: the absolute error.
"""

for test_input in test_inputs:
with self.test_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtype)
index = array_ops.placeholder(dtypes.int32)
out = math_ops.reduce_sum(a, index)
result = sess.run(out, {
a: np.array(test_input, dtype=dtype),
index: [0]
})
# Compare the results using float32 type.
self.assertAllClose(
np.float32(result),
np.float32(expected_result),
rtol=rtol,
atol=atol)

def testReduceSumF16(self):
"""Tests the reduce sum of float16 doesn't lose too much precision."""

if np.float16 not in self.all_types:
return

f16_max = np.finfo(np.float16).max
self._testReduceSum(
f16_max, np.float16,
itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3))

def testReduceSumBF16(self):
"""Tests the reduce sum of bfloat16 doesn't lose too much precision."""

if dtypes.bfloat16.as_numpy_dtype not in self.all_types:
return

bf16_max = np.float32(dtypes.bfloat16.max)
f32_max = dtypes.float32.max
value = min(bf16_max, f32_max - bf16_max)
self._testReduceSum(
dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))


if __name__ == '__main__':
googletest.main()
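The precision issue these new tests guard against can be reproduced with plain NumPy, independent of XLA. A minimal sketch: summing [f16_max, f16_max, -f16_max] left to right in float16 overflows to inf on the first addition, even though the exact result (f16_max) is representable, while a float32 accumulator recovers it.

```python
import numpy as np

f16_max = np.float16(np.finfo(np.float16).max)  # 65504.0
values = np.array([f16_max, f16_max, -f16_max], dtype=np.float16)

naive = np.float16(0)
for v in values:
    naive = np.float16(naive + v)  # float16 accumulator: overflows to inf

wide = np.float16(values.astype(np.float32).sum())  # float32 accumulator

print(naive)  # inf: f16_max + f16_max exceeds the float16 range
print(wide)   # 65504.0 == f16_max
```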
7 changes: 3 additions & 4 deletions tensorflow/compiler/tf2xla/BUILD
@@ -81,7 +81,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -168,9 +168,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -215,7 +215,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
7 changes: 5 additions & 2 deletions tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -51,6 +51,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
const std::vector<const XlaExpression*>& expressions,
std::vector<XlaCompiler::Argument>* args) {
auto builder = ctx->builder();
auto client = ctx->compiler()->client();
std::vector<bool> compile_time_constant_flags(expressions.size());

TF_RETURN_IF_ERROR(
@@ -72,8 +73,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.kind = XlaCompiler::Argument::kConstant;
TF_RET_CHECK(expressions[i]->resource() == nullptr)
<< "Input with resource is not yet implemented.";
TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
expressions[i]->handle()));
TF_ASSIGN_OR_RETURN(auto literal,
builder->ComputeConstant(expressions[i]->handle()));
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
} else {
@@ -212,7 +215,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,

TF_RET_CHECK(arguments.size() == expressions.size());

std::vector<xla::ComputationDataHandle> handles;
std::vector<xla::XlaOp> handles;
for (int64 i = 0; i < expressions.size(); ++i) {
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
continue;
8 changes: 4 additions & 4 deletions tensorflow/compiler/tf2xla/kernels/BUILD
@@ -114,8 +114,8 @@ tf_kernel_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@@ -151,7 +151,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -167,7 +167,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -203,8 +203,8 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:argmax_op",
2 changes: 1 addition & 1 deletion tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
@@ -29,7 +29,7 @@ class AddNOp : public XlaOpKernel {
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("AddN requires at least one argument"));

xla::ComputationDataHandle sum = ctx->Input(0);
xla::XlaOp sum = ctx->Input(0);
for (int i = 1; i < ctx->num_inputs(); ++i) {
sum = ctx->builder()->Add(sum, ctx->Input(i));
}
18 changes: 9 additions & 9 deletions tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -48,9 +48,9 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));

xla::ComputationBuilder* builder = ctx->builder();
xla::XlaBuilder* builder = ctx->builder();

xla::ComputationDataHandle input = ctx->Input(0);
xla::XlaOp input = ctx->Input(0);
TensorShape input_shape = ctx->InputShape(0);

int feature_index =
@@ -62,7 +62,7 @@ class FusedBatchNormOp : public XlaOpKernel {
input = builder->ConvertElementType(input, scale_type);

if (is_training_) {
xla::ComputationDataHandle output = builder->BatchNormTraining(
xla::XlaOp output = builder->BatchNormTraining(
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);

// In training mode, outputs the normalized value as well as the
@@ -79,7 +79,7 @@ class FusedBatchNormOp : public XlaOpKernel {
ctx->SetOutput(3, builder->GetTupleElement(output, 1));
ctx->SetOutput(4, builder->GetTupleElement(output, 2));
} else {
xla::ComputationDataHandle output = builder->BatchNormInference(
xla::XlaOp output = builder->BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
ctx->SetOutput(0, builder->ConvertElementType(output, input_type));
@@ -118,7 +118,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
}

void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* const b = ctx->builder();
xla::XlaBuilder* const b = ctx->builder();
DataType input_dtype = ctx->input_type(0);
DataType scale_dtype = ctx->input_type(2);

@@ -137,11 +137,11 @@ class FusedBatchNormGradOp : public XlaOpKernel {
const int feature_index =
GetTensorFeatureDimIndex(input_dims, data_format_);

xla::ComputationDataHandle x_backprop;
xla::ComputationDataHandle scale_backprop;
xla::ComputationDataHandle offset_backprop;
xla::XlaOp x_backprop;
xla::XlaOp scale_backprop;
xla::XlaOp offset_backprop;
if (is_training_) {
xla::ComputationDataHandle output =
xla::XlaOp output =
b->BatchNormGrad(activations, scale, mean, var, grad_backprop,
epsilon_, feature_index);

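For context on the tuple this kernel unpacks: XLA's BatchNormTraining yields (normalized output, batch mean, batch variance), which the training branch above routes to the kernel's outputs via GetTupleElement. A NumPy sketch of that contract — a simplified illustration under assumed helper names, not XLA's implementation:

```python
import numpy as np

def batch_norm_training(x, scale, offset, epsilon=1e-3, feature_index=-1):
    """NumPy sketch of BatchNormTraining's (output, mean, var) tuple.

    x: activations; scale/offset: per-feature parameters along
    feature_index. Variance is the biased (batch) variance.
    """
    f = feature_index % x.ndim
    reduce_axes = tuple(i for i in range(x.ndim) if i != f)
    mean = x.mean(axis=reduce_axes)
    var = x.var(axis=reduce_axes)
    # Broadcast per-feature statistics back over the reduced axes.
    shape = [1] * x.ndim
    shape[f] = x.shape[f]
    norm = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + epsilon)
    out = norm * scale.reshape(shape) + offset.reshape(shape)
    return out, mean, var

# Loosely mirrors the three tuple elements consumed in the training branch.
x = np.random.randn(8, 4).astype(np.float32)
out, mean, var = batch_norm_training(
    x, np.ones(4, np.float32), np.zeros(4, np.float32))
```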