Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 177545934 #15024

Merged
merged 128 commits into from
Dec 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
128 commits
Select commit Hold shift + click to select a range
bc8718b
Removed unused variables from curl_http_request_test.
tensorflower-gardener Nov 28, 2017
cbd31dd
Eager function definition no longer adds control dependencies after c…
alextp Nov 28, 2017
570d277
Bump LLVM snapshot to r319150.
tensorflower-gardener Nov 28, 2017
ba87a80
Eager: Better errors for invalid options to enable_eager_execution.
asimshankar Nov 28, 2017
b911049
Continue to allow old argument names specified in tf.flags.DEFINE fun…
yilei Nov 28, 2017
a86f2d2
Added an option to assume that the shape of fed nodes in unknown sinc…
benoitsteiner Nov 28, 2017
d96e936
Add custom_estimators
MarkDaoust Nov 28, 2017
c68d35b
Change HLO verifier semantics for bitcasts to: Bitcasts that are not …
tensorflower-gardener Nov 28, 2017
9c7fd28
Allow unsorted_segment_sum and unsorted_segment_max to take int64 num…
jtkeeling Nov 28, 2017
f93c8a7
[XLA:CPU] Avoid using the untiled lowering for dot when possible
Nov 28, 2017
c294fcf
Dataset support within Estimator. With this cl Input_fn can return a …
ispirmustafa Nov 28, 2017
49bb801
Changed to allow removing side-effect instructions from an HLO comput…
hyouklee Nov 28, 2017
b5683d2
Change description of sparse_column_with_integerized_feature to make …
tensorflower-gardener Nov 28, 2017
bb33903
add parentheses because this test is failing in my current CL
itsmeolivia Nov 28, 2017
e917bf7
[TF] Improve LogSoftMax performance
tensorflower-gardener Nov 28, 2017
f252ea2
Deprecating `tf.data.Dataset.from_sparse_tensor_slices`.
jsimsa Nov 28, 2017
efe5658
Make 'input_map' argument to import_graph_def work with C API.
skye Nov 28, 2017
9306dd9
Add bool value type support for gauge metrics.
tensorflower-gardener Nov 28, 2017
8966a79
Add a log test for bfloat16.
yunxing Nov 28, 2017
d72e2a3
scatter_nd_update for resource variables
alextp Nov 28, 2017
a99e9a2
Support tfe.Network.losses
allenlavoie Nov 28, 2017
a6ee905
Add non_trainable_variables to templates.
sguada Nov 28, 2017
d8de0d9
Fixing the windows nightly build.
Nov 28, 2017
5f1b61b
Check per HLO instruction only at vlog=1 in non-opt build.
jpienaar Nov 28, 2017
b8969d1
Mark Supervisor deprecated. Please use MonitoredTrainingSession instead.
martinwicke Nov 28, 2017
5a1e22b
Remove temp_workaround_http_archive.
Nov 28, 2017
9049b44
Fix tensorflow-android jcenter link
andrewharp Nov 29, 2017
a80fd2a
C API: fix bug in ValidateNoCycles().
skye Nov 29, 2017
f22261e
Add depthwise ops for NAS cell in nn_ops_test to improve the inferenc…
tensorflower-gardener Nov 29, 2017
b89251c
[TF:XLA] Implement Cumsum and Cumprod using the XLA ReduceWindow oper…
hawkinsp Nov 29, 2017
f2f6356
Automated g4 rollback of changelist 177191521
Nov 29, 2017
625ae88
Round-robin variables across local devices with `replicate_model_fn`.
isaprykin Nov 29, 2017
57839fb
Re-using (the more general) DeserializeSparse kernel to implement Des…
jsimsa Nov 29, 2017
73a803f
Fix flakiness in map_dataset_op_test.
saxenasaurabh Nov 29, 2017
e2f9107
Adds minor-dim pooling tests for cases in which windows exist entir…
tayo Nov 29, 2017
782ec4e
Silenced noisy log
benoitsteiner Nov 29, 2017
bf05a2d
Support shape inference (i.e., shapes containing -1) in the Reshape b…
tensorflower-gardener Nov 29, 2017
60d2e51
Introduce tf.contrib.summary.flush
jart Nov 29, 2017
d2e7a2e
Add VLOG-ging to gcs_file_system
saeta Nov 29, 2017
bdde4d0
[XLA] Support transposing the spatial dimensions of a convolution's a…
majnemer Nov 29, 2017
a55ee58
Include the filename when we encounter EOF
saeta Nov 29, 2017
05f5785
Bugfixes: Gather's gradient for 2+ dimensional indices with eager exe…
asimshankar Nov 29, 2017
2ab7e9d
Go: Add some more detail to an error message.
asimshankar Nov 29, 2017
a7c13b3
Update TF Android build instructions to warn about NDK 16
angerson Nov 29, 2017
9aeb0ee
Add logging to help differentiate multiple stacktraces
saeta Nov 29, 2017
4c2ca8b
Fix typo in GradientTape.persistent_ comment.
iganichev Nov 29, 2017
bc87c28
Register fp16 Reduce min on GPU.
protoget Nov 29, 2017
6196d30
Add alpha support for dealing with shardings to batchnorm rewriter
tensorflower-gardener Nov 29, 2017
4b60f92
Change "safe pointers" to make the deleters stateless (i.e. a type, n…
tensorflower-gardener Nov 29, 2017
e6d823d
Make NCCL code ready for NVIDIA's NCCL 2.
tensorflower-gardener Nov 29, 2017
18a36a8
[XLA:CPU] Factor IR function building logic out of IrEmitter into its…
tensorflower-gardener Nov 29, 2017
e7e1cab
[XLA:CPU] Factor out parallel loop emission into its own file so it c…
tensorflower-gardener Nov 29, 2017
667282e
[TFXLA] Return nullopt if no merge node found.
jpienaar Nov 29, 2017
537ecc5
[tf.data] Remove GraphDefBuilder and NodeBuilder dependencies from "d…
mrry Nov 29, 2017
fa8bfa8
Add feature_columns doc.
MarkDaoust Nov 29, 2017
c27a90d
[TF:XLA] VariableShape op support.
asimshankar Nov 29, 2017
2229a6c
Internal Change
tensorflower-gardener Nov 29, 2017
7921d01
Raise an exception when converting lists with invalid lengths to Tens…
allenlavoie Nov 29, 2017
c572bc4
Outline generated LLVM IR matrix-vector dot kernels
Nov 29, 2017
78a4873
Go: Bugfix: Make list-of-shape attributes in an operation work.
asimshankar Nov 29, 2017
71f22bb
(Temporarily) call Graph._add_op outside of Operation.__init__ again.
skye Nov 29, 2017
ad1310a
Add RandomDataset which generates pseudo random number of type int64.
saxenasaurabh Nov 29, 2017
c0bd9df
[tf.data] Fix compiler warnings about unused captures in lambda expre…
mrry Nov 29, 2017
dcf9b03
Made sure the unknown shapes of placeholders always propagate to thei…
benoitsteiner Nov 29, 2017
97da160
Allow the toolchain defaults to be used instead of hard-coding -Os.
tensorflower-gardener Nov 29, 2017
037acad
Deleted unused method arguments
benoitsteiner Nov 29, 2017
d1aea3b
Clarify the role of replicate_model_fn.Mode better.
isaprykin Nov 29, 2017
197850f
enabling Tensor._set_shape() to work with the C API
itsmeolivia Nov 29, 2017
e00156b
Proper deallocation in the thread-local tape stack.
alextp Nov 29, 2017
d3a8bf0
Added comment/TODO concerning memory use of extract_images_patches.
tensorflower-gardener Nov 29, 2017
19f62f6
Re-enable Mul hoisting for aggregations other than Add when input sha…
tensorflower-gardener Nov 29, 2017
48347ee
Simplify const node creation.
Nov 29, 2017
1d0b073
Add a way to query a batch scheduler to determine the max task size.
Nov 29, 2017
cb5a63d
Check when session cannot run because its graph was modified
iganichev Nov 29, 2017
d0d8596
Add R1 slice tests.
tensorflower-gardener Nov 29, 2017
aeba523
Updating references to the `tf.data` API to `tf.data` from `Datasets`.
jsimsa Nov 29, 2017
32b861d
Automated g4 rollback of changelist 177353959
gunan Nov 29, 2017
963a521
Using the C API in eager mode for graph functions.
alextp Nov 29, 2017
bc7180f
Fix more clang-tidy warnings:
eliben Nov 30, 2017
4ada275
Change `tf.contrib.distributions` docstring examples to use `tfd` ali…
jvdillon Nov 30, 2017
cb4ef36
Add native dilated support for conv2d and its gradients in cudnn v>=6.
Nov 30, 2017
b97585f
Always leverage shapes inference now that it can handle fed nodes
benoitsteiner Nov 30, 2017
4422aaa
Automated g4 rollback of changelist 177375237
gunan Nov 30, 2017
bec3d96
Automated g4 rollback of changelist 177362829
tensorflower-gardener Nov 30, 2017
ad3213b
Disable baseline_test in asan.
gunan Nov 30, 2017
8a98563
Add support for int32 output types to the Multinomial op.
hawkinsp Nov 30, 2017
976049b
Implement Python-specific device and colocation logic in import_graph…
skye Nov 30, 2017
1c48101
Hoist function input placeholders out of any control flow context.
skye Nov 30, 2017
9308470
Rename tests.
tensorflower-gardener Nov 30, 2017
f283173
[TF:XLA] Add support for the V2 variants of the FusedBatchNorm operat…
hawkinsp Nov 30, 2017
1297674
Uses C API for eager functions.
alextp Nov 30, 2017
b8c9b75
Change "Datasets" to "`tf.data`" in the "Reading Data" API guide.
mrry Nov 30, 2017
3b9a26d
Turned a verbose log into a vlog
benoitsteiner Nov 30, 2017
0369392
Internal testing change.
rjpower Nov 30, 2017
4e8301b
Disable dnn_linear_combined_test
gunan Nov 30, 2017
ea1c295
Change depthwise convolution filter expansion and contraction with
blakehechtman Nov 30, 2017
4146ff1
[XLA] Adds Dot with DotDimensionNumbers proto for specifying arbitrar…
tensorflower-gardener Nov 30, 2017
eafa8ef
[XLA:CPU] Add Hlo profiling support to XlaJitCompiledCpuFunction
Nov 30, 2017
af36437
Add rules to replace nodes corresponding to operations with the neutr…
tensorflower-gardener Nov 30, 2017
5e54d87
Add an option to override maximum number of elements in the quantile
tensorflower-gardener Nov 30, 2017
62b70f5
Support binary operations with a scalar and a 4d tensor as input; ref…
Nov 30, 2017
15b06e0
Add R1 slice tests.
tensorflower-gardener Nov 30, 2017
39cac05
[TF:XLA] Allow bfloat16 types in more places.
hawkinsp Nov 30, 2017
bc2b4b0
Enable tests that pass now with the new copy insertion.
tensorflower-gardener Nov 30, 2017
99525c7
Automated g4 rollback of changelist 177499365
gunan Nov 30, 2017
3a011f9
Disable state_saving_rnn_estimator_test in asan mode.
gunan Nov 30, 2017
6ebb6d6
TF server should not crash when -v=1 is enabled.
tensorflower-gardener Nov 30, 2017
6bfc73a
Extract out a MathUtil::GCD helper
Nov 30, 2017
ce4200e
Fix profiler to track some missed persistent bytes.
tensorflower-gardener Nov 30, 2017
f69915d
[XLA] Sanitize hlo names to match regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
tensorflower-gardener Nov 30, 2017
0472116
[TF:XLA] Make tf_cnn_benchmarks run on CPU with XLA.
tensorflower-gardener Nov 30, 2017
186caed
Add int64 support to XLA Shape op.
tensorflower-gardener Nov 30, 2017
0438ac7
[TF:XLA] Use output spatial dimensions instead of a transpose for conv
majnemer Dec 1, 2017
b2db981
Merge changes from github.
Dec 1, 2017
7ab54c4
Support compressed TensorProto format in constant folding for types i…
tensorflower-gardener Dec 1, 2017
1a89cf5
Output unknown dimension root nodes with --vmodule=graph_properties=2
yacoder Dec 1, 2017
6e16af8
Register more ops with bfloat16 types.
hawkinsp Dec 1, 2017
87e2f20
Automated g4 rollback of changelist 177505909
Dec 1, 2017
79ad4a4
enabling Tensor._set_shape() to work with the C API
itsmeolivia Dec 1, 2017
6968ff0
Disable tuning for now. Re-enable when measurement-based estimator is…
Dec 1, 2017
1ec61fa
Use latest nsync in tensorflow. Latest nsync builds with bazel on Fr…
tensorflower-gardener Dec 1, 2017
2c4e8fc
Fix merge conflicts
Dec 1, 2017
f70c56e
Merge branch 'master' into branch_177545934
sb2nov Dec 1, 2017
232db28
Fix long stride op issue
Dec 1, 2017
e6c6b6d
Merge branch 'master' into branch_177545934
sb2nov Dec 2, 2017
728896d
Fix broken test
Dec 2, 2017
c8a86e5
Merge branch 'branch_177545934' of github.com:sb2nov/tensorflow into …
Dec 2, 2017
9c6b708
Disable the lite test
Dec 2, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
47 changes: 35 additions & 12 deletions tensorflow/c/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
// be less than the total node count.
Status ValidateNoCycles(const Graph& g) {
// TODO(nolivia): check this on a subset of the graph instead of all of it.
int total_num_nodes = g.num_node_ids();
// A node is ready when all of its inputs have been visited.
std::vector<const Node*> ready;
std::vector<int> pending_count(total_num_nodes, 0);
std::vector<int> pending_count(g.num_node_ids(), 0);

for (int i = 0; i < total_num_nodes; ++i) {
for (int i = 0; i < g.num_node_ids(); ++i) {
const Node* n = g.FindNodeId(i);
if (n == nullptr) continue;
pending_count[i] = n->in_edges().size();
Expand Down Expand Up @@ -421,7 +420,7 @@ Status ValidateNoCycles(const Graph& g) {
}
}

if (processed < total_num_nodes) {
if (processed < g.num_nodes()) {
std::vector<string> nodes_in_cycle;
for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
++i) {
Expand All @@ -430,7 +429,7 @@ Status ValidateNoCycles(const Graph& g) {
}
}
return errors::InvalidArgument(
"Graph is invalid, contains a cycle with ", total_num_nodes - processed,
"Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
" nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
}
return Status::OK();
Expand Down Expand Up @@ -625,6 +624,23 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
return Status::OK();
}

void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
// If any session has already run this node_id, mark this session as
// unrunnable.
for (auto it : graph->sessions) {
if (it.first->last_num_graph_nodes > op.node.id()) {
it.second = FailedPrecondition(
"Operation '", op.node.DebugString(), "' was changed by ",
mutation_type,
" after it was run by a session. Nodes can be mutated "
"only before they are executed by a session. Either don't modify "
"nodes after running them or create a new session.");
}
}
}

// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
Expand Down Expand Up @@ -1745,7 +1761,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
TF_Graph::TF_Graph()
: graph(tensorflow::OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
Expand All @@ -1755,7 +1770,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; }
void TF_DeleteGraph(TF_Graph* g) {
g->mu.lock();
g->delete_requested = true;
const bool del = g->num_sessions == 0;
const bool del = g->sessions.empty();
g->mu.unlock();
if (del) delete g;
}
Expand Down Expand Up @@ -2325,11 +2340,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
Session* session;
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
mutex_lock l(graph->mu);
graph->num_sessions += 1;
graph->sessions[new_session] = Status::OK();
}
return new TF_Session(session, graph);
return new_session;
} else {
DCHECK_EQ(nullptr, session);
return nullptr;
Expand Down Expand Up @@ -2393,7 +2409,7 @@ TF_Session* TF_LoadSessionFromSavedModel(

TF_Session* session = new TF_Session(bundle.session.release(), graph);

graph->num_sessions += 1;
graph->sessions[session] = Status::OK();
session->last_num_graph_nodes = graph->graph.num_node_ids();
return session;
#endif // __ANDROID__
Expand All @@ -2408,8 +2424,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) {
TF_Graph* const graph = s->graph;
if (graph != nullptr) {
graph->mu.lock();
graph->num_sessions -= 1;
const bool del = graph->delete_requested && graph->num_sessions == 0;
graph->sessions.erase(s);
const bool del = graph->delete_requested && graph->sessions.empty();
graph->mu.unlock();
if (del) delete graph;
}
Expand All @@ -2425,6 +2441,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
mutex_lock session_lock(session->mu);
session->graph->mu.lock();
const Graph& graph = session->graph->graph;

status->status = session->graph->sessions[session];
if (!status->status.ok()) {
session->graph->mu.unlock();
return false;
}

const auto num_nodes = graph.num_node_ids();
if (session->last_num_graph_nodes < num_nodes) {
status->status = tensorflow::ValidateNoCycles(session->graph->graph);
Expand Down
21 changes: 16 additions & 5 deletions tensorflow/c/c_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,20 @@ struct TF_Graph {
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);

// TF_Graph may only / must be deleted when
// num_sessions == 0 && delete_requested == true

// num_sessions incremented by TF_NewSession, and decremented by
// The keys of this map are all the active sessions using this graph.
// Each value is the current "runnability" status of the corresponding
// session. Under normal conditions all statuses are Status::OK(), but
// if some operation is mutated after it was run by a session (this
// is detected in RecordMutation function), that session is no longer
// safe to run. Its status will contain the error that will be returned
// to the user, should she try running this session.
//
// Sessions are added to this map in TF_NewSession, and removed in
// TF_DeleteSession.
int num_sessions GUARDED_BY(mu);
// TF_Graph may only / must be deleted when
// sessions.size() == 0 && delete_requested == true
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::Status> sessions
GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph

// Used to link graphs contained in TF_WhileParams to the parent graph that
Expand Down Expand Up @@ -167,6 +175,9 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);

Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);

void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type);

} // end namespace tensorflow

#endif // TENSORFLOW_C_C_API_INTERNAL_H_
2 changes: 1 addition & 1 deletion tensorflow/c/eager/tape.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class GradientTape {
// the tape refer to it); to aid in tape garbage collection.
std::unordered_map<int64, int64> tensor_usage_;

// If true, all activations are deleted in the first call to ComputeGradient.
// If false, all activations are deleted in the first call to ComputeGradient.
// Else, only when this is destructed.
bool persistent_;
};
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/c/python_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace tensorflow {
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
mutex_lock l(graph->mu);
graph->graph.AddControlEdge(&input->node, &op->node);
RecordMutation(graph, *op, "adding control input");
}

void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
Expand All @@ -36,11 +37,13 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,

mutex_lock l(graph->mu);
op->node.AddAttr(attr_name, attr_val);
RecordMutation(graph, *op, "setting attribute");
}

void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
mutex_lock l(graph->mu);
op->node.set_requested_device(device);
RecordMutation(graph, *op, "setting device");
}

void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
Expand Down Expand Up @@ -75,6 +78,13 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);

if (status->status.ok()) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
RecordMutation(graph, *dst.oper, "updating input tensor");
}
}

} // namespace tensorflow
8 changes: 4 additions & 4 deletions tensorflow/compiler/aot/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ namespace xla { class ExecutableRunOptions; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
void* result, const xla::ExecutableRunOptions* run_options,
const void** args, void** temps);
const void** args, void** temps, tensorflow::int64* profile_counters);

{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
Expand Down Expand Up @@ -483,7 +483,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData;
}

{{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS)
{{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

{{CLASS}}(const {{CLASS}}&) = delete;
Expand All @@ -496,8 +496,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an affect. Must be
// called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
// to set the argument buffers.
// called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
// argument, to set the argument buffers.
//
// T* argN_data()
// Returns the buffer of type T for positional argument N.
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/compiler/aot/codegen_test_h.golden
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace xla { class ExecutableRunOptions; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void entry_point(
void* result, const xla::ExecutableRunOptions* run_options,
const void** args, void** temps);
const void** args, void** temps, tensorflow::int64* profile_counters);

namespace foo {
namespace bar {
Expand Down Expand Up @@ -86,7 +86,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData;
}

MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS)
MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

MyClass(const MyClass&) = delete;
Expand All @@ -99,8 +99,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an affect. Must be
// called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
// to set the argument buffers.
// called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
// argument, to set the argument buffers.
//
// T* argN_data()
// Returns the buffer of type T for positional argument N.
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/aot/tests/tfcompile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) {
// Run tests that use set_argN_data separately, to avoid accidentally re-using
// non-existent buffers.
TEST(TFCompileTest, Add_SetArg) {
AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);

int32 arg_x = 10;
int32 arg_y = 32;
Expand Down Expand Up @@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) {
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

foo::bar::MatMulComp matmul(
foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
matmul.set_thread_pool(&device);

// Test using the set_argN_data() methods.
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ Status FindCompilationCandidates(
!IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
continue;
}
// _Retval nodes in a top-level function represent fetches.
// Do not compile them.
if (node->type_string() == "_Retval") {
VLOG(2) << "Compilation rejected node: return value " << node->name()
<< ": " << node->type_string();
continue;
}
candidates->insert(node);
}
return Status::OK();
Expand Down
27 changes: 27 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -525,5 +525,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
"+-- c\n"));
}

TEST(XlaCompilationTest, Retval) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::UnaryOp("_Retval", b,
builder.opts()
.WithName("R")
.WithAttr("T", DT_FLOAT)
.WithAttr("index", 0));

TF_EXPECT_OK(builder.ToGraph(graph.get()));
}

TF_ASSERT_OK(MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);

EXPECT_EQ(2, clusters.size());
EXPECT_TRUE(clusters.find("R") == clusters.cend());
EXPECT_EQ(clusters["A"], clusters["B"]);
}

} // namespace
} // namespace tensorflow
14 changes: 14 additions & 0 deletions tensorflow/compiler/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,20 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "scan_ops_test",
size = "small",
srcs = ["scan_ops_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)

tf_xla_py_test(
name = "segment_reduction_ops_test",
size = "medium",
Expand Down