Skip to content

Commit

Permalink
Merge pull request #107 from tensorflow/master
Browse files Browse the repository at this point in the history
pull
  • Loading branch information
rajendraarora16 committed Nov 7, 2018
2 parents 3809b4b + 71e4954 commit 63e1e77
Show file tree
Hide file tree
Showing 237 changed files with 5,749 additions and 2,879 deletions.
1 change: 1 addition & 0 deletions tensorflow/BUILD
Expand Up @@ -352,6 +352,7 @@ package_group(
"//tensorflow/...",
"//tensorflow_estimator/...",
"//tensorflow_fold/llgtm/...",
"//tensorflow_text/...",
"//third_party/py/tensor2tensor/...",
],
)
Expand Down
42 changes: 16 additions & 26 deletions tensorflow/compiler/jit/BUILD
Expand Up @@ -21,7 +21,6 @@ package(
)

load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
Expand Down Expand Up @@ -255,6 +254,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu",
Expand All @@ -265,6 +265,21 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

tf_cc_test(
name = "xla_compilation_cache_test",
srcs = [
"xla_compilation_cache_test.cc",
],
deps = [
":xla_compilation_cache",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

Expand Down Expand Up @@ -651,31 +666,6 @@ tf_cc_test(
],
)

tf_cc_test(
name = "xla_launch_util_test",
size = "small",
srcs = ["xla_launch_util_test.cc"],
deps = [
":common",
":xla_compilation_cache",
":xla_launch_util",
":xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/kernels:variable_ops",
],
)

cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/compiler/jit/kernels/xla_ops.cc
Expand Up @@ -286,8 +286,10 @@ static Status CompileToLocalExecutable(
// rather than a one-element tuple.
compile_options.always_return_tuple = false;

return cache->Compile(options, function, constant_args, *variables, ctx,
compile_options,
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_args, *variables, ctx, &args));
return cache->Compile(options, function, args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
kernel, executable);
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
Expand Up @@ -61,6 +61,10 @@ struct OperationFilter {
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
// auto-clustering stateful RNG ops.
bool allow_stateful_rng_ops;

// TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
// to cluster ControlTrigger because of how we use deadness analysis.
bool allow_control_trigger;
};

bool IsStatefulRandomOp(absl::string_view op_name) {
Expand Down Expand Up @@ -225,6 +229,9 @@ bool IsCompilableCall(const NodeDef& call_def,
IsStatefulRandomOp(node->type_string())) {
return false;
}
if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
return false;
}
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
lib_runtime)) {
Expand Down Expand Up @@ -455,6 +462,9 @@ Status FindCompilationCandidates(
op_filter.allow_stateful_rng_ops =
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);
op_filter.allow_control_trigger =
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);

if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
Expand All @@ -469,6 +479,10 @@ Status FindCompilationCandidates(
VLOG(2) << "Rejecting " << node->name() << ": stateful random operation";
continue;
}
if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op";
continue;
}

if (!op_filter.allow_resource_ops &&
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
Expand Down Expand Up @@ -604,6 +618,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
OperationFilter op_filter;
op_filter.allow_resource_ops = true;
op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
}

Expand Down
12 changes: 4 additions & 8 deletions tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
Expand Up @@ -817,14 +817,10 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {

std::unordered_map<string, string> clusters = GetClusters(*graph);

ASSERT_FALSE(clusters.empty());
string cluster_name = clusters.begin()->second;

// ctrl_trigger_a has inputs with mismatching deadness so it won't be
// clustered. ctrl_trigger_b is okay to cluster.
std::unordered_map<string, string> expected_clusters(
{{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}});
EXPECT_EQ(clusters, expected_clusters);
// TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
// it won't be clustered. ctrl_trigger_b is okay to cluster but we don't
// cluster it because of b/118970344.
EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, RandomShape) {
Expand Down

0 comments on commit 63e1e77

Please sign in to comment.