Branch 199311231 #19787
Merged — 32 commits, Jun 5, 2018

Changes from all 32 commits:
6b2a088
Add various missing aliases for symbols in tf.keras submodules.
fchollet Jun 4, 2018
06c4fb6
Fixes a cleanup bug in BatchFunction op.
vinuraja Jun 4, 2018
142ccf3
Add rip-offs of LLVM's cast, dyn_cast, cast_or_null, dyn_cast_or_null…
tensorflower-gardener Jun 4, 2018
e2d3008
Move benchmarking code to a new directory and add some documentation.
shashishekhar Jun 4, 2018
d947e2c
Remove tf_export decorator from contrib. tf_export decorators current…
annarev Jun 4, 2018
18995ec
Adds update_ops to train_op for all heads.
tensorflower-gardener Jun 4, 2018
eab2e4d
nit: FlatBuffer -> FrozenGraph
Jun 4, 2018
69613d2
More handle_data fixing.
skye Jun 4, 2018
cf01d11
Add support for kDomain parsing in HLO parser.
tensorflower-gardener Jun 4, 2018
14d4d16
Add TOKEN primitive type.
meheffernan Jun 4, 2018
7d195d0
Fix a floating point inaccuracy issue in precision_recall_at_equal_t…
tensorflower-gardener Jun 4, 2018
ff5ad20
Updated include path for internal protobuf implementation.
tensorflower-gardener Jun 4, 2018
310a51b
HloParser: use uint16 in U16 case
yunxing Jun 5, 2018
35c8574
[XLA] Don't dump subgraphs twice in hlo_graph_dumper.
Jun 5, 2018
76801dd
Enable XLA fusions as a Grappler optimization.
tensorflower-gardener Jun 5, 2018
fedfc47
Resolve device names when passed into DistributionStrategy methods.
tensorflower-gardener Jun 5, 2018
d660ab0
[TF:XLA] Add method CreateNewModule to HloVerifiedTestBase, and remem…
dimvar Jun 5, 2018
bf8d058
Windows: Refactor bazel_test_lib.sh and common_env.sh
tensorflower-gardener Jun 5, 2018
5403336
Added missing backtick in tf.ones_like documentation
tensorflower-gardener Jun 5, 2018
92789d7
Handle scalar input to assert_equal in eager.
tomhennigan Jun 5, 2018
22a8c24
Remove test dependencies that are no longer needed.
Jun 5, 2018
c0dc76a
Fix generated_zip_test failure caused by regex matching failures.
tensorflower-gardener Jun 5, 2018
274f951
Remove _USE_C_API staging from ops.py.
skye Jun 5, 2018
3653e80
Address compiler warnings in tensorflow/core/distributed_runtime.
mrry Jun 5, 2018
e1f31d4
Expose `@tfe.run_all_tests_in_graph_and_eager_modes`.
tomhennigan Jun 5, 2018
51445a7
Add computed receptive field parameters from popular convnets.
tensorflower-gardener Jun 5, 2018
72f6b4d
Delete "RuntimeWarning"; it is not having the intended effect.
MarkDaoust Jun 5, 2018
16a4b1e
Automated g4 rollback of changelist 199244092
gunan Jun 5, 2018
ad1fc6b
Eliminate nested try/catch's in Distribution._call_prob and friends. …
csuter Jun 5, 2018
b8b93f3
Edit error message to make it clear which yaml module you need.
MarkDaoust Jun 5, 2018
8c9afdf
Fix docstring formatting.
tensorflower-gardener Jun 5, 2018
bf1c227
Merge commit for internal changes
Jun 5, 2018
46 changes: 46 additions & 0 deletions tensorflow/compiler/jit/BUILD
@@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")

# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -312,6 +313,7 @@ cc_library(
":common",
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
@@ -332,6 +334,18 @@ cc_library(
],
)

cc_library(
name = "xla_cluster_util",
srcs = ["xla_cluster_util.cc"],
hdrs = ["xla_cluster_util.h"],
deps = [
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core/kernels:bounds_check",
],
)

cc_library(
name = "union_find",
hdrs = ["union_find.h"],
@@ -408,6 +422,38 @@ tf_cc_test(
],
)

cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
hdrs = ["xla_fusion_optimizer.h"],
visibility = ["//visibility:public"],
deps = [
":common",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
],
)

tf_cuda_cc_test(
name = "xla_fusion_optimizer_test",
srcs = ["xla_fusion_optimizer_test.cc"],
deps = [
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/utils:grappler_test",
],
)

# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
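The new `xla_fusion_optimizer` target depends on `custom_graph_optimizer_registry`, which lets an optimizer announce itself at static-initialization time so Grappler can look it up by name later. Below is a minimal, self-contained sketch of that general registration pattern; the type and function names are hypothetical stand-ins, not TensorFlow's actual API.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

// Toy stand-in for Grappler's graph-optimizer interface.
struct GraphOptimizer {
  virtual ~GraphOptimizer() = default;
  virtual std::string Name() const = 0;
};

// Registry mapping names to factories, populated before main() runs.
class OptimizerRegistry {
 public:
  using Factory = std::function<std::unique_ptr<GraphOptimizer>()>;
  static std::map<std::string, Factory>& Registry() {
    static std::map<std::string, Factory> registry;  // Meyers singleton
    return registry;
  }
  static void Register(const std::string& name, Factory f) {
    Registry()[name] = std::move(f);
  }
  // Returns nullptr when no optimizer was registered under `name`.
  static std::unique_ptr<GraphOptimizer> Create(const std::string& name) {
    auto it = Registry().find(name);
    return it == Registry().end() ? nullptr : it->second();
  }
};

// Helper whose constructor performs the registration as a side effect.
struct Registrar {
  Registrar(const std::string& name, OptimizerRegistry::Factory f) {
    OptimizerRegistry::Register(name, std::move(f));
  }
};

// The optimizer being registered, analogous in spirit to XlaFusionOptimizer.
struct ToyFusionOptimizer : GraphOptimizer {
  std::string Name() const override { return "ToyFusion"; }
};

// File-scope registrar: static initialization triggers the registration.
static Registrar toy_fusion_registrar("ToyFusion", [] {
  return std::unique_ptr<GraphOptimizer>(new ToyFusionOptimizer);
});
```

Registering via a file-scope object keeps the optimizer discoverable without the core framework linking against it directly, which is why the BUILD target above only needs a dependency edge to the registry.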
161 changes: 19 additions & 142 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -41,9 +42,6 @@ limitations under the License.

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";

namespace {

// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward
@@ -191,16 +189,6 @@ bool IsCompilableCall(const NodeDef& call_def,
return true;
}

// Returns the DeviceType corresponding to 'device'.
Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::Internal("Malformed assigned device '", device, "'");
}
*device_type = DeviceType(parsed.type);
return Status::OK();
}

// Tests whether `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node) {
return std::find(node.input_types().begin(), node.input_types().end(),
@@ -209,18 +197,11 @@ bool HasResourceInputOrOutput(const Node& node) {
DT_RESOURCE) != node.output_types().end();
}

struct NodeCompare {
bool operator()(const Node* a, const Node* b) const {
return a->id() < b->id();
}
};
using OrderedNodeSet = std::set<Node*, NodeCompare>;

// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
// TODO(hpucha): Consider a black list instead of a white list as
// implemented below.
// TODO(hpucha): Remove this code since this functionality is subsumed by
// Grappler XlaFusionOptimizer.
bool IsXlaFusable(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
@@ -390,7 +371,7 @@ Status FindCompilationCandidates(
for (Node* node : graph.op_nodes()) {
sorted_nodes.push_back(node);
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare());
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());

for (Node* node : sorted_nodes) {
VLOG(2) << "Fuel: " << fuel;
@@ -405,9 +386,13 @@

DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
DeviceToDeviceType(node->assigned_device_name(), &device_type));

if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
VLOG(2) << "Compilation rejected node: not compilable " << node->name()
<< ": " << node->type_string();
continue;
}

const XlaOpRegistry::DeviceRegistration* registration;
CHECK(
@@ -456,46 +441,6 @@ struct Cluster {
int representative = -1;
};

// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
int dst) {
int32 max_path_size = graph.num_node_ids() + 1;
std::vector<int32> path(max_path_size);
int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data());
if (path_size == 0) {
return "";
}

auto node_name = [&cycles, &graph](int node_id) {
if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
return string("(null)");
}
auto* node = graph.FindNodeId(node_id);
if (node == nullptr) {
return string("(null)");
}
return node->name();
};

string description;
strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
node_name(dst), " would create a cycle.\n");
path.resize(path_size);
for (int32 node_id : path) {
string ascii_art;
if (node_id == dst) {
ascii_art = "+-> ";
} else if (node_id != src) {
ascii_art = "| ";
} else {
ascii_art = "+-- ";
}
strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
}
return description;
}

} // anonymous namespace

bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@@ -601,84 +546,13 @@ Status MarkForCompilationPass::RunImpl(
: Env::Default(),
is_compilable_fn, &compilation_candidates));

GraphCycles cycles;
for (int i = 0; i < graph->num_node_ids(); ++i) {
// We rely on the node IDs in the cycle detection graph being consecutive
// integers starting from 0.
CHECK_EQ(i, cycles.NewNode());
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
return Status::OK();
}

// Compute the loop structure of the graph.
std::vector<ControlFlowInfo> control_flow_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

// The clustering code must avoid adding cycles to the graph to prevent
// deadlock. However, the graph may contain loops, which would trigger the
// cycle detection code. To handle loops, we alter the structure of the cycle
// detection graph, disconnecting each loop from the enclosing graph.
// Specifically, we:
// * add a new "frame" node for each loop.
// * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
// to/from the corresponding frame node. In essence, we collapse the loop
// into a single node for the purpose of cycle detection in the enclosing
// graph.
// * the body of the loop should now be disconnected from the rest of the
// graph; we make it acyclic by breaking loop backedges (edges outgoing from
// "NextIteration" nodes.

// Map from frame name strings to node IDs in the cycle detection graph.
std::unordered_map<string, int> frame_nodes;

// Get the cycle graph node ID for frame 'frame_name', or add one if none
// exists.
auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
if (frame_id < 0) {
// The emplace succeeded; we have not allocated a frame node yet.
frame_id = cycles.NewNode();
}
return frame_id;
};

for (Edge const* edge : graph->edges()) {
if (edge->dst()->IsEnter()) {
// Lift edges to an "Enter" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->dst()->id()].frame_name;
int dst = GetOrAddFrameNodeId(frame_name);
if (!cycles.InsertEdge(edge->src()->id(), dst)) {
return errors::Internal(
"Cycle detected when adding enter->frame edge: ",
DescribeCycle(cycles, *graph, edge->src()->id(), dst));
}
continue;
}
if (edge->src()->IsExit()) {
// Lift edges from an "Exit" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->src()->id()].frame_name;
int src = GetOrAddFrameNodeId(frame_name);
if (!cycles.InsertEdge(src, edge->dst()->id())) {
return errors::Internal(
"Cycle detected when adding frame->exit edge: ",
DescribeCycle(cycles, *graph, src, edge->dst()->id()));
}
// Drop the original edge.
continue;
}
if (edge->src()->IsNextIteration()) {
// Break loop back-edges.
continue;
}
if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
// This should never happen. All cycles in the graph should contain
// a control flow operator.
return errors::Internal(
"Found cycle in graph without control flow operator during XLA "
"compilation: ",
DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
}
}
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));

// Each compilation candidate belongs to a cluster. The cluster's
// representative
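The comment block removed above describes how the pass makes loops safe for cycle detection: each loop frame is collapsed into a single synthetic node, Enter/Exit edges are lifted to that frame node, and NextIteration back-edges are dropped. A self-contained sketch of that transformation (toy node types, not TensorFlow's `Graph`/`GraphCycles` API) followed by a plain DFS cycle check:

```cpp
#include <functional>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Node kinds relevant to the loop-collapsing transformation.
enum class Kind { kPlain, kEnter, kExit, kNextIteration };

struct Node {
  std::string name;
  Kind kind = Kind::kPlain;
  std::string frame;  // Loop frame this node belongs to ("" = outermost).
};

// Lifts Enter/Exit edges to a per-frame synthetic node, drops NextIteration
// back-edges, then reports whether the transformed graph still has a cycle.
bool HasCycleAfterCollapse(const std::vector<Node>& nodes,
                           const std::vector<std::pair<int, int>>& edges) {
  const int n = static_cast<int>(nodes.size());
  std::map<std::string, int> frame_ids;  // frame name -> synthetic node id
  int next_id = n;
  auto frame_node = [&](const std::string& f) {
    auto it = frame_ids.find(f);
    if (it == frame_ids.end()) it = frame_ids.emplace(f, next_id++).first;
    return it->second;
  };
  std::map<int, std::set<int>> adj;
  for (auto [src, dst] : edges) {
    if (nodes[src].kind == Kind::kNextIteration) continue;  // break back-edge
    int s = src, d = dst;
    if (nodes[dst].kind == Kind::kEnter) d = frame_node(nodes[dst].frame);
    if (nodes[src].kind == Kind::kExit) s = frame_node(nodes[src].frame);
    adj[s].insert(d);
  }
  // DFS three-color cycle check over original plus synthetic frame nodes.
  std::map<int, int> color;  // 0 = white, 1 = grey, 2 = black
  std::function<bool(int)> visit = [&](int u) {
    color[u] = 1;
    for (int v : adj[u]) {
      if (color[v] == 1) return true;               // back edge found
      if (color[v] == 0 && visit(v)) return true;
    }
    color[u] = 2;
    return false;
  };
  for (int u = 0; u < next_id; ++u)
    if (color[u] == 0 && visit(u)) return true;
  return false;
}
```

With this collapse, a well-formed while loop contributes no cycle to the enclosing graph, so any cycle the checker does find reflects an edge that a clustering decision must refuse — which is exactly the invariant `CreateCycleDetectionGraph` now encapsulates.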
@@ -696,6 +570,9 @@

// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
//
// TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
// example, from the Grappler fusion pass).
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
@@ -804,7 +681,7 @@
// compilation.
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
DeviceToDeviceType(n->assigned_device_name(), &device_type));
const XlaOpRegistry::DeviceRegistration* registration;
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);

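The worklist loop above "repeatedly contract[s] edges between clusters that are on the same device, provided the contraction would not create a cycle" — the classic union-find clustering that the pass's `UnionFind<Cluster>` implements. A stripped-down sketch of the same-device contraction step (cycle checking against `GraphCycles` is omitted, and all names here are illustrative, not the pass's real ones):

```cpp
#include <numeric>
#include <string>
#include <utility>
#include <vector>

// Minimal union-find with path compression, analogous to UnionFind<Cluster>.
struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);  // each node starts alone
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Merge(int a, int b) { parent[Find(a)] = Find(b); }
};

// Contracts every edge whose endpoints are assigned to the same device,
// mimicking the clustering loop; the real pass additionally refuses any
// contraction that would introduce a cycle in the GraphCycles structure.
UnionFind ClusterByDevice(const std::vector<std::string>& device,
                          const std::vector<std::pair<int, int>>& edges) {
  UnionFind uf(static_cast<int>(device.size()));
  for (auto [src, dst] : edges) {
    if (device[src] == device[dst]) uf.Merge(src, dst);
  }
  return uf;
}
```

After contraction, each union-find representative names one compilation cluster, which is then stamped onto its member nodes via the cluster attribute.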