Branch 192771889 #18497

Merged
110 commits, merged on Apr 13, 2018
Commits
a9a3b98
Import FunctionDef as GrapplerFunctionItem
tensorflower-gardener Apr 11, 2018
08a12ca
Add a clear error message for when a doc does not have a title.
MarkDaoust Apr 11, 2018
8f75385
Add gradient in cond test to match CallGradInLoop.
jpienaar Apr 11, 2018
5757d09
Use tf.train.get_or_create_global_step() instead of deprecated variab…
tensorflower-gardener Apr 11, 2018
48b2bdc
Fix uninitialized value.
tensorflower-gardener Apr 11, 2018
2ea5c1e
Disable prelu tests for real now.
tensorflower-gardener Apr 11, 2018
8b17a17
Script to create custom_ops inside a TensorFlow graphdef.
tensorflower-gardener Apr 11, 2018
abc26c1
Update docs for softmax_cross_entropy_with_logits.
adria-p Apr 11, 2018
5eccb5a
Increase size of tensorflow/contrib/data/python/kernel_tests:batch_da…
mrry Apr 11, 2018
cc15251
Internal TF Lite test changes
angerson Apr 11, 2018
d983832
Adding hp5y back.
rohan100jain Apr 11, 2018
c5d59c6
Internal change.
annarev Apr 11, 2018
371d513
DepthwiseConv Optimizations
tensorflower-gardener Apr 11, 2018
744a5cc
When not necessary, avoid the creation of a `placeholder_with_default…
fchollet Apr 11, 2018
3fa224a
Factor out the syntactic function scope tracking into the transformer…
tensorflower-gardener Apr 11, 2018
1a36eb1
Replace examples/image_retraining by a pointer to TensorFlow Hub.
tensorflower-gardener Apr 11, 2018
73aef57
Support for removing unfused quantized activation functions and min/max.
tensorflower-gardener Apr 11, 2018
1e283d6
Porting tests for the `decode_proto` and `encode_proto` to OS.
jsimsa Apr 11, 2018
eed6828
BREAKING_CHANGE: Remove event_ndims in Bijector, and require `log_det…
tensorflower-gardener Apr 11, 2018
e520167
Adds a nodedef_fn parameter to copy_op_handler, allowing customizatio…
tensorflower-gardener Apr 11, 2018
21fb4ee
Adding support for batch_to_space_nd op with crops.
tensorflower-gardener Apr 11, 2018
d2690cf
Extend support to remove transpose/reverse on dimensions of size 1.
tensorflower-gardener Apr 11, 2018
079d63d
GCS Filesystem should not cache checkpoint file as we need to read th…
Apr 11, 2018
4b08b66
Fixes issue where name scope collisions could lead to an invalid vari…
alextp Apr 11, 2018
f029631
Increase size of //tensorflow/python/kernel_tests:sets_test to "medium".
mrry Apr 11, 2018
9ce7791
Revealing the range of node ids in the latest layer via resource's state
tensorflower-gardener Apr 11, 2018
acad702
Adding support of core feature columns and losses to gradient boosted…
tensorflower-gardener Apr 11, 2018
d6e2513
support profiling multiple tpu through one grpc and one session.
tensorflower-gardener Apr 11, 2018
e7cfede
Speed up computing mean confidence intervals by avoiding tf.while_loop.
tensorflower-gardener Apr 11, 2018
2b94b44
Move callback into bound function to avoid copying.
mrry Apr 11, 2018
3734bb6
boosted_trees: make sure ensemble deserialization happens for the non…
yk5 Apr 11, 2018
81a9cea
Update ops-related pbtxt files.
tensorflower-gardener Apr 12, 2018
d62a5a1
Automated g4 rollback of changelist 192516190
tensorflower-gardener Apr 12, 2018
1a721ec
Internal testing changes
angerson Apr 12, 2018
70d9935
Add `tf.contrib.stateless.stateless_multinomial()`.
mrry Apr 12, 2018
58029d1
In model_to_estimator, only run get_weights when there are initialize…
yifeif Apr 12, 2018
ff6c110
Increase size of //tensorflow/python/kernel_tests:linalg_ops_test to …
mrry Apr 12, 2018
4e29ebd
[XLA] Redesign: test sharding.
tensorflower-gardener Apr 12, 2018
b52b5a4
Switch to WaitForNotification to fix the flaky test.
Apr 12, 2018
e7e01ac
[XLA] Redesign: fix GetComputationGraphStats.
tensorflower-gardener Apr 12, 2018
6f67893
Update ops-related pbtxt files.
tensorflower-gardener Apr 12, 2018
6c9f882
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener Apr 12, 2018
9026254
Enable a reduce window test case.
tensorflower-gardener Apr 12, 2018
5b0cb6c
Add closure_js_proto_library build for tf.example protos.
jameswex Apr 12, 2018
ac9be81
Fix description of DynamicUpdateSlice.
tensorflower-gardener Apr 12, 2018
96aba78
Enable an r2 reduce window test case.
tensorflower-gardener Apr 12, 2018
e688642
Make DType, TensorShape, and Dimension "reducable" for pickling purpo…
tensorflower-gardener Apr 12, 2018
cf542ae
Special-case the name scoping for operator methods. TensorFlow disall…
tensorflower-gardener Apr 12, 2018
b0978aa
Updating tests containing graphs with Variables so that they Evaluate…
tensorflower-gardener Apr 12, 2018
cbea753
Fixing dependencies.
tensorflower-gardener Apr 12, 2018
8a24797
Collective Ops Part 3
tensorflower-gardener Apr 12, 2018
ffbf77d
Introduced tool to run an HLO module in replicated fashion, by infeed…
tensorflower-gardener Apr 12, 2018
844b8ca
[TF] Add TensorListPushBackBatch.
ebrevdo Apr 12, 2018
151c31c
Make default weights initializer in `base_layers.Layer` suitable for …
tensorflower-gardener Apr 12, 2018
dc2d1c2
Fix shape inference for outside_compilation clusters that include cyc…
tensorflower-gardener Apr 12, 2018
4a405bc
Update ops-related pbtxt files.
tensorflower-gardener Apr 12, 2018
3ebe39c
Fix lost dependency
tensorflower-gardener Apr 12, 2018
024b037
Fixed error where no background audio samples were being used when te…
petewarden Apr 12, 2018
10e6021
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener Apr 12, 2018
454a22a
Construct Orthogonal kernels for 2d convolutions.
tensorflower-gardener Apr 12, 2018
583ee0e
Add testCompileTimeConstantsInDefun in xla
iganichev Apr 12, 2018
d1ee67c
Start moving Checkpointable utilities toward core
allenlavoie Apr 12, 2018
6d2316d
Add FunctionTest.testLayerInDefun
iganichev Apr 12, 2018
6308e58
Add softsign bijector.
tensorflower-gardener Apr 12, 2018
ecacd20
[XLA] Redesign: add XlaComputation::IsNull.
tensorflower-gardener Apr 12, 2018
1a014c6
Restore dependency on estimator utils from model.
tensorflower-gardener Apr 12, 2018
f959065
Fix comment of bucket_by_sequence_length about return type of
tensorflower-gardener Apr 12, 2018
e40fec4
Upgrade libjpeg-turbo
jart Apr 12, 2018
64eb9b4
Separate out distribute dependency out of training, as it needs to be…
guptapriya Apr 12, 2018
322580a
Fix build breakage on metagraph exporting when caching_device is set
alextp Apr 12, 2018
111ee9b
Make new build target public.
jameswex Apr 12, 2018
7c0172e
ResolveConstantReshape transformation and fix for ResolveConstantTran…
tensorflower-gardener Apr 12, 2018
0161bb7
K-FAC: Deprecate tf.contrib.kfac.
tensorflower-gardener Apr 12, 2018
fffbe5a
Check if the session has been deleted before releasing a callable.
mrry Apr 12, 2018
d49cbc2
[tf.data] Clean up //tensorflow/contrib/data/python/ops/BUILD.
mrry Apr 12, 2018
0195d6b
Added a utility to compute a topo ordering of a graph
benoitsteiner Apr 12, 2018
cc108a7
Add support for RNN state array of type tf.identity.
tensorflower-gardener Apr 12, 2018
dde6aaf
Exposing tensorflow.contrib.proto in the pip package.
jsimsa Apr 12, 2018
9908cb1
Change assertions to use the tensor 'x' rather than 'x.op.name'. This…
tensorflower-gardener Apr 12, 2018
5d442be
Propagate sharding of the source instruction to the copies added by l…
tensorflower-gardener Apr 12, 2018
5f7929b
[XLA:GPU] Pass all four args to custom-call convs when they're created.
Apr 12, 2018
59c828c
Document support for boolean values in tf.contrib.training.HParams.
shoyer Apr 12, 2018
4d56813
Misc. small optimizations in Grappler and shape inference code.
tensorflower-gardener Apr 12, 2018
3755128
Fix a typo in cross_tower_ops.
Apr 13, 2018
fffd3ca
Move dummy AssertOp and CheckNumericsOp to //third_party/tensorflow/c…
tensorflower-gardener Apr 13, 2018
d42e4bd
Porting tests for `rpc_op` to OS.
jsimsa Apr 13, 2018
457e8b3
Print error msg in CUDATimer.Init() when CreateEvent() is not ok().
protoget Apr 13, 2018
5a53c9b
Reintroducing support for constants as outputs of tf.data.map(). This…
jsimsa Apr 13, 2018
7d89bfc
Adding autograph built-in function checker.
tensorflower-gardener Apr 13, 2018
93afca5
Convert GrapplerFunctionItem to (Specialized)FunctionDef.
tensorflower-gardener Apr 13, 2018
c4526e5
Avoid calling K.learning_phase() when not necessary in Dropout layer …
fchollet Apr 13, 2018
5a6d5a1
Enable efficient feeding of symbolic tensors to placeholders in the K…
fchollet Apr 13, 2018
4f615ad
Automated g4 rollback of changelist 192691078
jsimsa Apr 13, 2018
3c98705
Add boolean type to tflite in favor of comparison implementations.
tensorflower-gardener Apr 13, 2018
3438c3f
Automated g4 rollback of changelist 192504411
jsimsa Apr 13, 2018
1c88fac
Automated g4 rollback of changelist 192698931
jsimsa Apr 13, 2018
68f0f1a
[XLA] Rename Interpreter{Executor,Platform} -> XlaInterpreter{Executo…
Apr 13, 2018
73cc1d5
-- Add a new histogram/cdf computation method compatible with the TPU.
tensorflower-gardener Apr 13, 2018
1b0c277
Implementation of Less
tensorflower-gardener Apr 13, 2018
bb80410
Fix bug in converted_call, and add tests for it.
tensorflower-gardener Apr 13, 2018
b520022
Update for upstream LLVM *.def -> *.inc rename
tensorflower-gardener Apr 13, 2018
345414c
- Fixed small bug in example script
tensorflower-gardener Apr 13, 2018
bb8fcd5
Keep function doc string at the top of the function.
tensorflower-gardener Apr 13, 2018
8c47ec3
Demo: RNN colorbot with Estimators.
tensorflower-gardener Apr 13, 2018
3eb4e4f
Split byte_order.h off cpu_info.h
tensorflower-gardener Apr 13, 2018
f9de043
Automated g4 rollback of changelist 192768744
tensorflower-gardener Apr 13, 2018
91c3199
Add support to TFLite for dilated convolution.
tensorflower-gardener Apr 13, 2018
17aa70e
Refactor to remove the duplicate calls to obtain a function's namespa…
tensorflower-gardener Apr 13, 2018
554c587
Experiment with pre-shuffled fully-connected weights
tensorflower-gardener Apr 13, 2018
a194572
Merge commit for internal changes
qlzh727 Apr 13, 2018
Files changed
1 change: 1 addition & 0 deletions tensorflow/compiler/aot/BUILD
@@ -60,6 +60,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
15 changes: 15 additions & 0 deletions tensorflow/compiler/aot/tests/BUILD
@@ -14,6 +14,7 @@ test_suite(
":test_graph_tfadd_test",
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfassert_eq_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
@@ -33,6 +34,7 @@ py_binary(
"//tensorflow/python", # TODO(b/34059704): remove when fixed
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
@@ -52,6 +54,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.ckpt",
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfassert_eq.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
@@ -104,6 +107,17 @@ tf_library(
],
)

tf_library(
name = "test_graph_tfassert_eq",
testonly = 1,
config = "test_graph_tfassert_eq.config.pbtxt",
cpp_class = "AssertComp",
graph = "test_graph_tfassert_eq.pb",
tags = [
"manual",
],
)

tf_library(
name = "test_graph_tffunction",
testonly = 1,
@@ -170,6 +184,7 @@ tf_cc_test(
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfassert_eq",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
10 changes: 10 additions & 0 deletions tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -29,6 +29,7 @@
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
@@ -125,6 +126,14 @@ def tfsplits(_):
array_ops.identity(y, name='result')


def tfassert_eq(_):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
control_flow_ops.Assert(
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
math_ops.add(x, math_ops.negative(y), name='x_y_diff')


def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -144,6 +153,7 @@ def main(_):
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)


if __name__ == '__main__':
16 changes: 16 additions & 0 deletions tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt
@@ -0,0 +1,16 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_hold" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "x_y_diff" }
}
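For orientation: the tf_library rule earlier in this diff feeds this config, together with test_graph_tfassert_eq.pb, to tfcompile, which emits a C++ class named by cpp_class. Below is a minimal sketch of that generated interface, assuming standard tfcompile codegen — the real header is produced at build time, and only the members exercised by the new test in the next file are shown. Feeds become numbered args in config order; fetches become numbered results.

// Sketch of the tfcompile-generated header for cpp_class = "AssertComp".
// Not the verbatim generated file; reconstructed from the accessors used in
// tfcompile_test.cc below.
class AssertComp {
 public:
  bool Run();                        // Runs the compiled graph; false on error.
  const tensorflow::string& error_msg() const;  // Empty after a successful Run().

  tensorflow::int32& arg0();         // feed "x_hold", shape [1]
  tensorflow::int32* arg0_data();
  tensorflow::int32& arg1();         // feed "y_hold", shape [1]
  tensorflow::int32* arg1_data();

  tensorflow::int32& result0();      // fetch "x_y_diff", i.e. x - y
  tensorflow::int32* result0_data();

  void** args();                     // Raw argument/result buffer tables, as
  void** results();                  // probed by the new test's EXPECT_EQs.
};

Because Assert is compiled to a no-op under XLA, Run() is expected to succeed even when the two feeds differ; the AssertEqAndReturnDiff test in the next file checks exactly that.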
18 changes: 18 additions & 0 deletions tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@@ -413,6 +414,23 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}

TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
AssertComp assert;
EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
EXPECT_EQ(assert.arg1_data(), assert.args()[1]);

assert.arg0() = 2;
assert.arg1() = 1;
const int32 expected_result = assert.arg0() - assert.arg1();
EXPECT_TRUE(assert.Run());
EXPECT_EQ(assert.error_msg(), "");
EXPECT_EQ(assert.result0(), expected_result);
EXPECT_EQ(assert.result0_data()[0], expected_result);
EXPECT_EQ(assert.result0_data(), assert.results()[0]);
}

TEST(TFCompileTest, LookupNameIndex) {
// add doesn't have any names defined in its config.
AddComp add;
8 changes: 8 additions & 0 deletions tensorflow/compiler/jit/BUILD
@@ -183,6 +183,13 @@ cc_library(
],
)

cc_library(
name = "shape_inference_helpers",
srcs = ["shape_inference_helpers.cc"],
hdrs = ["shape_inference_helpers.h"],
deps = ["//tensorflow/core:graph"],
)

# Internal targets below this point.

cc_library(
@@ -293,6 +300,7 @@ cc_library(
deps = [
":common",
":graph_to_functiondef",
":shape_inference_helpers",
":union_find",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
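The shape_inference_helpers library added above supplies the BackEdgeHelper class used throughout the next file. Its header (shape_inference_helpers.h) is not included in this view, so here is a rough reconstruction of the interface from its two call sites below, Remove() and RemovedEdges() — treat it as a sketch, not the verbatim header.

// Hypothetical reconstruction of shape_inference_helpers.h; details may
// differ from the actual file in this PR.
#include <vector>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Temporarily detaches the back edges of while loops so that acyclic
// traversals (such as shape inference) can run, and records what was removed
// so callers can later patch the affected Merge nodes.
class BackEdgeHelper {
 public:
  struct BackEdge {
    const Edge* edge;
    Node* src;
    int src_output;
    Node* dst;
    int dst_input;
  };

  Status Remove(Graph* graph);                        // Detach and record.
  const std::vector<BackEdge>& RemovedEdges() const;  // What was detached.

 private:
  Graph* graph_ = nullptr;
  std::vector<BackEdge> back_edges_;
};

}  // namespace tensorflow

The single //tensorflow/core:graph dependency in the BUILD rule is consistent with this shape: the helper only needs Graph, Node, and Edge.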
103 changes: 89 additions & 14 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -36,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
@@ -576,7 +578,8 @@ class Encapsulator {
// satisfied, e.g., because send_node depends on a node that doesn't have a
// registered shape inference function.
Status DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
@@ -599,7 +602,7 @@
// to nodes in pruned_graph.
Status MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library);

@@ -1712,9 +1715,13 @@ namespace {
// matter because it will only be used subsequently for shape inference. (It
// would be possible to add a switch statement over data_type to create a value
// for the constant, but that would entail maintaining the logic as new types
// are added, and is not necessary.)
Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
Graph* graph_out) {
// are added, and is not necessary.) If the node being replaced was within a
// control flow frame, adds appropriate Enter nodes so that the use of the Const
// is well-formed.
Node* AddDummyShapedNode(const Node* src_node, int src_port,
const std::vector<ControlFlowInfo>& control_flow_info,
const TensorShapeProto& shape, Graph* graph_out) {
DataType data_type = src_node->output_type(src_port);
TensorProto dummy_proto;
dummy_proto.set_dtype(data_type);
*dummy_proto.mutable_tensor_shape() = shape;
@@ -1725,7 +1732,23 @@ Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
options.op_registry());
node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
return options.FinalizeBuilder(&node_builder);
Node* node = options.FinalizeBuilder(&node_builder);
// Add any Enter nodes required to bring the constant to the correct control
// flow frame.
while (!control_flow_info[src_node->id()].frame_name.empty()) {
NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
options.op_registry());
enter_builder.Attr("frame_name",
control_flow_info[src_node->id()].frame_name);
enter_builder.Attr("is_constant", true);
enter_builder.Input(node, 0);
Node* enter_node = options.FinalizeBuilder(&enter_builder);
// Adopt the new Enter node as the value in the current frame.
node = enter_node;
// Recurse to the parent frame to see if more Enter nodes need to be added.
src_node = control_flow_info[src_node->id()].parent_frame;
}
return node;
}

// Adds a copy of node_in to graph_out and adds the mapping to
@@ -1767,17 +1790,30 @@ Status CopyShapeInferenceNodeToGraph(
}
}
}
// Work around the fact that Enter nodes refuse to propagate shape information
// unless they are marked loop invariant. Since we are never going to execute
// this graph, marking them all loop invariant is fine.
if (node_out->type_string() == "Enter") {
node_out->ClearAttr("is_constant");
node_out->AddAttr("is_constant", true);
}
return Status::OK();
}

} // namespace

Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const Graph& graph_in, const BackEdgeHelper& back_edge_helper,
const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
std::unique_ptr<Graph>* graph_out) {
// Get the control flow structure of the input graph so we can build
// well-formed output graphs.
std::vector<ControlFlowInfo> control_flow_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info));

// Maps from nodes in graph_in to nodes in graph_out.
//
// When an edge has fully defined shape the source node in graph_in is
@@ -1802,7 +1838,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(

// We don't use the standard ReverseDFS because we want to cut off traversal
// whenever we find an output with fully defined shape.
// TODO(misard) make this work properly in the presence of control flow.
struct Work {
Node* node;
bool leave; // Are we entering or leaving node?
@@ -1840,8 +1875,9 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
dummy_node_images[src_node] = AddDummyShapedNode(
src_node->output_type(src_port), proto, graph_out->get());
dummy_node_images[src_node] =
AddDummyShapedNode(src_node, src_port, control_flow_info,
proto, graph_out->get());
}
// The final input to the send node is the dynamic key, which we
// don't include in the static shapes.
@@ -1889,6 +1925,38 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
}
}

for (const auto edge : back_edge_helper.RemovedEdges()) {
if (copied_node_images.find(edge.dst) != copied_node_images.end()) {
// The destination of this back edge was added to the inference graph, so
// fix it up.
Node* dst = copied_node_images[edge.dst];
if (dst->type_string() != "Merge") {
return errors::InvalidArgument(
"outside_compilation cluster contains a back-edge to node ",
dst->name(), " of type ", dst->type_string(),
". The analysis pass only supports back-edges to Merge nodes.");
}
const Edge* existing_input_edge;
if (edge.dst_input != 1 || dst->num_inputs() != 2 ||
!dst->input_edge(0, &existing_input_edge).ok()) {
// TODO(misard) if we see graphs built with a different structure, relax
// this constraint. Leaving it here for now to avoid writing unnecessary
// complex code since we believe graphs generated by front ends all have
// the back edge as the second input to the merge node.
return errors::Internal(
"Internal assumption failed while rewriting an outside_compilation "
"cluster that contains a while loop. Logic assumes back-edge is to "
"port 1 of a 2-input "
"Merge node.");
}
// Connect the existing edge to both inputs of the Merge node so that the
// graph will be well-formed.
(*graph_out)
->AddEdge(existing_input_edge->src(),
existing_input_edge->src_output(), dst, edge.dst_input);
}
}

return Status::OK();
}

@@ -1956,7 +2024,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(

Status Encapsulator::MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library) {
// Find all the send_from_host nodes in all subgraphs, to use as roots for the
@@ -1978,10 +2046,15 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(
// nodes, inlining any functions as needed.
TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
graph, send_from_host_nodes, pruned_graph, node_images, library));
FixupSourceAndSinkEdges(pruned_graph->get());

// Remove back edges from any cycles in the pruned graph to simplify shape
// inference traversal. They will be fixed up in the per-subgraph shape
// inference graphs stored in the function library.
TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get()));

// Perform shape inference on the pruned graph.
shape_refiner->set_require_shape_inference_fns(false);
FixupSourceAndSinkEdges(pruned_graph->get());
std::vector<Node*> post_order;
GetReversePostOrder(*(*pruned_graph), &post_order);
for (auto node : post_order) {
Expand All @@ -1999,11 +2072,13 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends(

Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
Graph* graph_out, FunctionLibraryDefinition* library) {
BackEdgeHelper back_edge_helper;
std::unique_ptr<Graph> pruned_graph;
ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
std::unordered_map<const Node*, Node*> node_images;
TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
*graph_out, &pruned_graph, &shape_refiner, &node_images, library));
*graph_out, &pruned_graph, &back_edge_helper, &shape_refiner,
&node_images, library));

if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference",
@@ -2033,7 +2108,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
std::unique_ptr<Graph> graph;
if (send_node != nullptr) {
TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
*pruned_graph, shape_refiner, recv_at_host_names,
*pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names,
node_images[send_node], library, &static_shape, &graph));
if (graph == nullptr) {
VLOG(2) << "Send node " << send_node->name() << " shapes";
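Stepping back, the hunks in this file implement one pattern: detach back edges, run shape inference on the now-acyclic graph, then repair the per-send inference graphs. A condensed sketch of the first two steps, using only calls visible in this diff (the wrapper function itself is hypothetical):

// Condensed sketch of the loop-aware inference flow; not verbatim PR code.
Status InferShapesOnLoopyGraph(Graph* graph, ShapeRefiner* refiner,
                               BackEdgeHelper* back_edges) {
  // 1. Detach while-loop back edges so the traversal below sees an acyclic
  //    graph (the diff does this right after pruning and inlining).
  TF_RETURN_IF_ERROR(back_edges->Remove(graph));
  // 2. Ordinary shape inference in reverse post order.
  std::vector<Node*> post_order;
  GetReversePostOrder(*graph, &post_order);
  for (Node* node : post_order) {
    TF_RETURN_IF_ERROR(refiner->AddNode(node));
  }
  return Status::OK();
}

Step 3 is the repair work shown above: each removed back edge must land on port 1 of a 2-input Merge node, whose surviving input is wired to both ports, and any dummy constants introduced for known shapes are re-wrapped in is_constant Enter nodes so they stay visible inside their original control-flow frames.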
9 changes: 9 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -51,6 +51,15 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node.type_string() == "SymbolicGradient") return false;
if (node.type_string() == "Const") {
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
// registered Const KernelDef says that it does, to support no-op Assert for
// tfcompile.
const AttrValue* attr = node.attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
return false;
}
}
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
