Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 195897321 #19164

Merged
merged 117 commits into from
May 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
f368558
[XLA] Cleanup client_library_test_base: move definition of CreatePara…
tensorflower-gardener May 4, 2018
be9b873
[XLA] Redesign: migrate the SWIG wrapped xla client.
tensorflower-gardener May 4, 2018
5ca373b
Some fixes to support another TF graph:
tensorflower-gardener May 4, 2018
67b5e72
[XLA:GPU] Mark floating-point division as an inexpensive op.
May 4, 2018
4d0388d
Fix build failure for macos py3
protoget May 4, 2018
cb1775e
Identify and prune nodes that can never be executed
benoitsteiner May 4, 2018
ab48fb5
[XLA] Print allowed attributes when the user specifies an invalid attr.
May 4, 2018
008a3b6
Add the ability to export separate SavedModels for train and eval mod…
May 4, 2018
037e52e
Expose read-only versions of tensors in tflite.
tensorflower-gardener May 4, 2018
fa1d92f
Add infrastructure for a backend-specific configuration for each op. …
broune May 4, 2018
bf228e1
[tf.data] Adding `num_parallel_calls` to `map_and_batch`.
jsimsa May 5, 2018
dedcf1d
Update ops-related pbtxt files.
tensorflower-gardener May 5, 2018
b81f2dc
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener May 5, 2018
ad2e7f0
Internal Change
tensorflower-gardener May 5, 2018
59f0618
GuaranteeConst is a NoOp for the op_level_cost_estiamtor.
yacoder May 5, 2018
dd5ef1b
Checkpointable: A small utility for exempting objects from __setattr_…
allenlavoie May 5, 2018
5fb53fe
add support for PadV2
tensorflower-gardener May 5, 2018
f21bead
OVIC Benchmarker App (current without the functionality to bind to a …
tensorflower-gardener May 5, 2018
939fc53
Fix landscape layout.
shashishekhar May 5, 2018
150089e
Remove uses of the kTransposeDot fusion
May 5, 2018
067be1a
Part 2 of Swift<->TF sends/recvs: support receiving tensors in TF from
May 5, 2018
f3c2191
[XLA] Always be willing to duplicate widening kConvert instructions d…
May 5, 2018
62ed0aa
Allow benchmark model graph to be specified in text proto format.
shashishekhar May 5, 2018
a5e1809
[TPU] Add option to only compile a replicated graph.
jpienaar May 5, 2018
094daf9
[XLA:GPU] Zero out input buffers before running cudnn conv autotune.
May 5, 2018
14f618e
Gracefully handle workers without heartbeat support enabled.
rjpower May 6, 2018
e0be7b7
Remove unused threadpool from stream executor.
tensorflower-gardener May 7, 2018
6304da2
Improve fusion logic of (a dot b) * alpha
tensorflower-gardener May 7, 2018
9dc76cd
Automated g4 rollback of changelist 195638795
tensorflower-gardener May 7, 2018
adcfdfc
Control flow graph with forward and backward analysis
tensorflower-gardener May 7, 2018
a0bdd4b
Internal Change
netfs May 7, 2018
ac630df
Removing quotations to fix the broken link found on https://www.tenso…
tensorflower-gardener May 7, 2018
b2888c6
Add EvaluateNodes to HoistFactorDiv test.
tensorflower-gardener May 7, 2018
9ba26ca
Extend block sparsity support for TPUs
tensorflower-gardener May 7, 2018
170634d
Replaced calls to tensorflow::StringPiece::ToString with std::string …
tensorflower-gardener May 7, 2018
f6a55cc
Add tests for broadcasting KL divergence calculations.
tensorflower-gardener May 7, 2018
93846ec
Extracts PartialConcatConstFolding into a method.
tensorflower-gardener May 7, 2018
eb5ee79
Release notes for TensorFlow Lite.
miaout17 May 7, 2018
f14123d
Generalize the input to TPU distribution strategy. Add cross-shard-r…
isaprykin May 7, 2018
aa57960
Register bool scatter_update for resource variables
alextp May 7, 2018
abe83fe
Disable autograph cfg_test in windows.
gunan May 7, 2018
0297d9c
[tf.data] Patch to unref iterator_resource in DeserializeIteratorOp.
shivaniag May 7, 2018
fb6d927
[XLA] Add FusionKind matcher to pattern_matcher.h.
May 7, 2018
a75f3f5
[TF:XLA:GPU] Allow the use of linear address when there are size one …
bixia1 May 7, 2018
c3fef21
Add 'optonly' directive to linear_operator_circulant tests.
langmore May 7, 2018
6f3a890
Adding Greater/GreaterEqual/LessEqual ops to complement Less.
tensorflower-gardener May 7, 2018
a0496a1
Make test tensorflow/python/keras:resnet50_test be size "medium"
tensorflower-gardener May 7, 2018
91fb950
Rename HloDotWithContractDimsMatcher to HloDotWithContractingDimsMatcher
May 7, 2018
914c971
Specialize functions only once per unique context.
tensorflower-gardener May 7, 2018
bd8d744
Fixes for accessing variables with a MirroredStrategy in a
tensorflower-gardener May 7, 2018
9a0c845
[TF:XLA] Bump open source llvm revision to r331624
May 7, 2018
92acc30
Change deprecation_version field in api_def.proto to a string.
annarev May 7, 2018
fa2132a
Use 64bit aggregation for gradients and hessians since the 32 bit ver…
tensorflower-gardener May 7, 2018
544dcc5
Move PadV2Options to the end, in order to maintain schema compatibility.
tensorflower-gardener May 7, 2018
cd065ca
[XLA] Shard compilation of HloEvaluator.
May 7, 2018
dfae6ff
Fix resource variable in cond gradient.
alextp May 7, 2018
8498672
Allow output has a different shape from input in the image.transform …
tensorflower-gardener May 7, 2018
80c7ebc
Disable automated testing of tensorflow/compiler/tests:extract_image_…
tensorflower-gardener May 7, 2018
860452f
Replace references to TensorInfo with XlaTensor
iganichev May 7, 2018
de177c4
Add support for tf.data.Dataset iterators in model training/eval meth…
pavithrasv May 7, 2018
ba0584a
Fix TypeError in update_version.py
May 7, 2018
3a2f1cf
Internal Change.
May 7, 2018
fc7f0b2
Add support for select (via tf.where) support to tflite.
tensorflower-gardener May 7, 2018
37b8860
[XLA] Fix a "we're we're" in the operation semantics.
mkuperst May 7, 2018
4a9beef
Add EvaluateNodes to tests: RemoveIdentityTransposesMultipleOutputs, …
tensorflower-gardener May 7, 2018
27e6ab7
[Remote functions] Only set the default runner *after* resolving the …
mrry May 7, 2018
94b0b2f
[XLA] Make post order a possible schedule as it sometimes uses less m…
blakehechtman May 7, 2018
5802096
Refactor TensorArray to avoid copies and memory allocations when exec…
akshayka May 7, 2018
6e1784b
ShapeRefiner fix: some variant-type tensors have handle data.
skye May 7, 2018
97fea64
Reorder executor NodeItem variable length data section so
tensorflower-gardener May 7, 2018
f64b16a
Add TFX section. Add Ecosystem page and dropdown menu.
lamberta May 7, 2018
a72ee2f
Fast-path to VarHandleOp
alextp May 7, 2018
3964bde
Delete kTransposeDot (it is no longer in use)
May 7, 2018
db63348
Add test with tf.cond.
jpienaar May 7, 2018
b67d8b2
Internal change
blakehechtman May 8, 2018
482ed8e
Raise an error if we try to take the gradient wrt to the initial valu…
skye May 8, 2018
b6906d1
Make eager functions runable on TPU
iganichev May 8, 2018
2585a81
Make conv2d_tranpose_test.py work with C API shapes enabled.
skye May 8, 2018
1af09b5
Add logic for StridedSlice ops in ShapeRefiner::ConstantPartialShape().
skye May 8, 2018
7a9e695
[tf.data] Move tensorflow::dataset::MakeIteratorContext to core/frame…
saeta May 8, 2018
069f312
Temporarily disable concat rewrite.
tensorflower-gardener May 8, 2018
a799cdb
Automated g4 rollback of changelist 195748721
tensorflower-gardener May 8, 2018
392ce20
Fix a test expectation.
tensorflower-gardener May 8, 2018
42115bd
ProfileHandler: Remove unnecessary interface method.
asimshankar May 8, 2018
0bd1408
Add missing #include for OpResponse. This class currently happens to …
tensorflower-gardener May 8, 2018
07fdb69
Automated g4 rollback of changelist 195723288
tensorflower-gardener May 8, 2018
a6a862e
[TF:XLA] Fix NaN in StatelessRandomNormal if the underlying uniform d…
hawkinsp May 8, 2018
4a6e663
Update comment clarifying continuous eval behavior.
tensorflower-gardener May 8, 2018
bd60650
Minor formatting tweaks to distribute.py and simple_tfkeras_example.py
caisq May 8, 2018
77bb984
Free ANeuralNetworksCompilation object in NNAPIDelegate destructor
tensorflower-gardener May 8, 2018
074d290
Add cost model of depthwiseConv2dNative. Tensorflow computes depthwis…
tensorflower-gardener May 8, 2018
83aa323
When building functions, capture tensors in `internal_convert_to_tens…
akshayka May 8, 2018
f0a506f
Fix Raspberry Pi build by making PNG not try to use Neon (by autodete…
aselle May 8, 2018
4ca46dd
Increase size of test //third_party/tensorflow/python:saver_large_var…
tensorflower-gardener May 8, 2018
b62573f
Add affinity binding functionality and documentation to OVIC benchmar…
tensorflower-gardener May 8, 2018
2674930
Change visibility of hlo_proto.
tensorflower-gardener May 8, 2018
211e3a2
Update version of downloadable clang toolchain
ilya-biryukov May 8, 2018
4f7a0bc
Fix docstring for flush() method
tensorflower-gardener May 8, 2018
59bffb7
Re-land: Optimize dot(DynamicSlice(ConstA), ConstantB) by memoizing d…
alinas May 8, 2018
b15500b
Remove outdated CUDA SDK string (the text is now consistent with othe…
tensorflower-gardener May 8, 2018
8fcf647
Better wrapping of stream executor's cuDNN API calls. Replacing mutex…
tensorflower-gardener May 8, 2018
62c50e1
Avoid string formatting in assert_same_float_dtype unless there's an …
allenlavoie May 8, 2018
1d94ed7
Increase size of test tensorflow/contrib/layers:rev_block_lib_test to…
tensorflower-gardener May 8, 2018
d3f3fb5
Increase shard count of tensorflow/contrib/distributions:mixture_test…
tensorflower-gardener May 8, 2018
f58effe
Do not differentiage integers in the eager API.
alextp May 8, 2018
96fa17d
Increase shard_count of tensorflow/python/estimator:estimator_test to…
tensorflower-gardener May 8, 2018
34f6241
Add missing ":haswell" match to list of platform selectors.
tensorflower-gardener May 8, 2018
7138715
Increase shard count of tensorflow/python/keras:lstm_test to avoid fl…
tensorflower-gardener May 8, 2018
8039c94
Increase size of tensorflow/contrib/distributions:batch_reshape_test …
tensorflower-gardener May 8, 2018
d4d9759
Increase shard count of tensorflow/contrib/learn:state_saving_rnn_est…
tensorflower-gardener May 8, 2018
241e828
Add test to test suite.
tensorflower-gardener May 8, 2018
0028bf8
add test for pruning useless function lib in graph.
tensorflower-gardener May 8, 2018
bbebae0
Only use integer values for event_ndims.
tensorflower-gardener May 8, 2018
79b773a
Set size of tensorflow/python/keras:normalization_test to medium to a…
tensorflower-gardener May 8, 2018
14d5f21
Make eager functions runable on TPU
iganichev May 8, 2018
24c9174
Merge commit for internal changes
yifeif May 9, 2018
d14a530
Hardcode EndpointSpec deprecated input to False for now after cl/1957…
yifeif May 9, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
4 changes: 2 additions & 2 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,8 +845,8 @@ def reformat_version_sequence(version_str, sequence_count):
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
ask_cuda_version = (
'Please specify the CUDA SDK version you want to use, '
'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
'Please specify the CUDA SDK version you want to use. '
'[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION

for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
Expand Down
48 changes: 48 additions & 0 deletions tensorflow/c/c_api_experimental.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8407,3 +8407,51 @@ TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
}
return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
TF_Tensor* tensor, TF_Status* status) {
assert(session);
{
tensorflow::mutex_lock c(session->graph->mu);
if (VLOG_IS_ON(1)) {
VLOG(1) << "Enqueuing named tensor with id " << tensor_id
<< ", with input graph: "
<< session->graph->graph.ToGraphDefDebug().DebugString();
tensorflow::Tensor internal_tensor;
if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
VLOG(1) << "Enqueu'ing tensor content: "
<< internal_tensor.DebugString();
}
}
}

TF_Operation* enqueue_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
if (enqueue_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the enqueue node in the TF graph.");
return;
}

TF_Operation* placeholder_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
if (placeholder_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the placeholder node as input to enqueue in the TF "
"graph.");
return;
}

VLOG(1) << "Running the enqueue op";
TF_Output input{placeholder_op, 0};
TF_SessionRun(session, /*run_options*/ nullptr,
// input related parameters
/*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
// output related parameters
/*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
/*targets*/ &enqueue_op, /*ntargets*/ 1,
/*run_metadata*/ nullptr, status);
VLOG(1) << "Enqueuing is done.";
}
23 changes: 21 additions & 2 deletions tensorflow/c/c_api_experimental.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,34 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
unsigned char is_mnist, TF_Status* status);

// On success, dequeues a tensor from a TF-managed FifoQueue given by
// `tensor_id`, associated with `session`. Caller must call TF_DeleteTensor()
// over the returned tensor. If the queue is empty, this call is blocked.
// `tensor_id`, associated with `session`. There must be a graph node named
// "fifo_queue_dequeue_<tensor_id>", to be executed by this API call.

// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is
// empty, this call is blocked.
//
// Tensors are enqueued via the corresponding TF enqueue op.
// TODO(hongm): Add support for `timeout_ms`.
TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session,
int tensor_id,
TF_Status* status);

// On success, enqueues `tensor` into a TF-managed FifoQueue given by
// `tensor_id`, associated with `session`. There must be a graph node named
// "fifo_queue_enqueue_<tensor_id>", to be executed by this API call. It reads
// from a placeholder node "arg_tensor_enqueue_<tensor_id>".
//
// `tensor` is still owned by the caller. This call will be blocked if the queue
// has reached its capacity, and will be unblocked when the queued tensors again
// drop below the capacity due to dequeuing.
//
// Tensors are dequeued via the corresponding TF dequeue op.
// TODO(hongm): Add support for `timeout_ms`.
TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TF_Tensor* tensor,
TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
#endif
Expand Down
36 changes: 29 additions & 7 deletions tensorflow/c/eager/tape.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,15 @@ class GradientTape {
}
}

bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes);

void Watch(int64 tensor_id);

void RecordOperation(const string& op_type,
gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter);

Expand Down Expand Up @@ -170,12 +172,30 @@ class GradientTape {

// Template instantiations here

inline bool IsDtypeTrainable(DataType dtype) {
switch (dtype) {
case DT_HALF:
case DT_BFLOAT16:
case DT_FLOAT:
case DT_DOUBLE:
case DT_COMPLEX64:
case DT_COMPLEX128:
case DT_RESOURCE:
case DT_VARIANT:
return true;
default:
return false;
}
}

template <typename Gradient, typename BackwardFunction>
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids) {
for (int64 i : tensor_ids) {
if (tensor_tape_.find(i) != tensor_tape_.end()) {
return true;
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
CHECK_EQ(tensor_ids.size(), dtypes.size());
for (int i = 0; i < tensor_ids.size(); ++i) {
if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
return IsDtypeTrainable(dtypes[i]);
}
}
return false;
Expand All @@ -189,9 +209,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id)) {
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
backward_function_deleter();
return;
}
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/compiler/aot/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ test_suite(
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfassert_eq_test",
":test_graph_tfcond_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
Expand Down Expand Up @@ -55,6 +56,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfassert_eq.pb",
"test_graph_tfcond.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
Expand Down Expand Up @@ -118,6 +120,17 @@ tf_library(
],
)

tf_library(
name = "test_graph_tfcond",
testonly = 1,
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
tags = [
"manual",
],
)

tf_library(
name = "test_graph_tffunction",
testonly = 1,
Expand Down Expand Up @@ -194,6 +207,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfassert_eq",
":test_graph_tfcond",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
Expand Down
29 changes: 19 additions & 10 deletions tensorflow/compiler/aot/tests/make_test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
f.write(saver.as_saver_def().SerializeToString())


def tfassert_eq(_):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
control_flow_ops.Assert(
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
math_ops.add(x, math_ops.negative(y), name='x_y_diff')


def tfcond(_):
p = array_ops.placeholder(dtypes.bool, name='p_hold')
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
z = control_flow_ops.cond(p, lambda: x, lambda: y)
array_ops.identity(z, name='result')


def tfgather(_):
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
Expand Down Expand Up @@ -126,14 +142,6 @@ def tfsplits(_):
array_ops.identity(y, name='result')


def tfassert_eq(_):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
control_flow_ops.Assert(
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
math_ops.add(x, math_ops.negative(y), name='x_y_diff')


def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
Expand All @@ -148,12 +156,13 @@ def main(_):
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)
write_graph(tfcond, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)


if __name__ == '__main__':
Expand Down
20 changes: 20 additions & 0 deletions tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "p_hold" }
shape {}
}
feed {
id { node_name: "x_hold" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_hold" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "result" }
}
26 changes: 26 additions & 0 deletions tensorflow/compiler/aot/tests/tfcompile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
Expand Down Expand Up @@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}

TEST(TFCompileTest, Cond) {
CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
cond.arg1() = 10;
cond.arg2() = 20;
{
cond.arg0() = true;
const int32 expected_result = cond.arg1();
EXPECT_TRUE(cond.Run());
EXPECT_EQ(cond.result0(), expected_result);
EXPECT_EQ(cond.result0_data()[0], expected_result);
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
}
{
cond.arg0() = false;
const int32 expected_result = cond.arg2();
EXPECT_TRUE(cond.Run());
EXPECT_EQ(cond.result0(), expected_result);
EXPECT_EQ(cond.result0_data()[0], expected_result);
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
}
}

TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
Expand Down
24 changes: 24 additions & 0 deletions tensorflow/compiler/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ cc_library(
name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
"create_xla_launch_op.h",
],
deps = [
":common",
Expand All @@ -270,6 +271,29 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
],
alwayslink = 1,
)

tf_cc_test(
name = "create_xla_launch_op_test",
srcs = [
"create_xla_launch_op.h",
"create_xla_launch_op_test.cc",
],
deps = [
":create_xla_launch_op",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/memory",
],
)

Expand Down