diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 663fec56f1cb3e..ec4e77818eb0f3 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -23,10 +23,11 @@ limitations under the License.
 #ifndef __ANDROID__
 #include "tensorflow/cc/framework/gradients.h"
 #include "tensorflow/cc/framework/ops.h"
-#include "tensorflow/cc/framework/scope_internal.h"
 #include "tensorflow/cc/saved_model/loader.h"
 #endif
 #include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/while_loop.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/shape_refiner.h"
 #include "tensorflow/core/framework/allocation_description.pb.h"
@@ -42,6 +43,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/while_context.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -62,6 +64,7 @@ using tensorflow::AllocationDescription;
 using tensorflow::DataType;
 using tensorflow::Graph;
 using tensorflow::GraphDef;
+using tensorflow::ImportGraphDefOptions;
 using tensorflow::NameRangeMap;
 using tensorflow::NameRangesForNode;
 using tensorflow::NewSession;
@@ -70,16 +73,21 @@ using tensorflow::NodeBuilder;
 using tensorflow::NodeDef;
 using tensorflow::OpDef;
 using tensorflow::OpRegistry;
+using tensorflow::Output;
+using tensorflow::OutputTensor;
 using tensorflow::PartialTensorShape;
 using tensorflow::RunMetadata;
 using tensorflow::RunOptions;
+using tensorflow::Scope;
 using tensorflow::Session;
+using tensorflow::ShapeRefiner;
 using tensorflow::Status;
 using tensorflow::Tensor;
 using tensorflow::TensorBuffer;
 using tensorflow::TensorId;
 using tensorflow::TensorShape;
 using tensorflow::TensorShapeProto;
+using tensorflow::WhileContext;
 using tensorflow::error::Code;
 using tensorflow::errors::FailedPrecondition;
 using tensorflow::errors::InvalidArgument;
@@ -831,6 +839,30 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
   return attr;
 }
 
+TensorId ToTensorId(const TF_Output& output) {
+  return TensorId(output.oper->node.name(), output.index);
+}
+
+std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
+                                                     int n) {
+  std::vector<tensorflow::Output> outputs(n);
+  for (int i = 0; i < n; ++i) {
+    outputs[i] =
+        tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
+  }
+  return outputs;
+}
+
+#ifndef __ANDROID__
+void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
+                          TF_Output* tf_outputs) {
+  for (int i = 0; i < outputs.size(); i++) {
+    tf_outputs[i].oper = ToOperation(outputs[i].node());
+    tf_outputs[i].index = outputs[i].index();
+  }
+}
+#endif  // __ANDROID__
+
 }  // namespace
 
 // Shape functions -----------------------------------------------------------
@@ -966,7 +998,7 @@ void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
 }
 
 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
-  desc->colocation_constraints.emplace_back(
+  desc->colocation_constraints.emplace(
       StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
 }
 
@@ -979,12 +1011,20 @@ void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
                           const void* const* values, const size_t* lengths,
                           int num_values) {
-  std::vector<tensorflow::StringPiece> v;
-  v.reserve(num_values);
-  for (int i = 0; i < num_values; ++i) {
-    v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
+  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
+    desc->colocation_constraints.clear();
+    for (int i = 0; i < num_values; ++i) {
+      desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
+                                           lengths[i]);
+    }
+  } else {
+    std::vector<tensorflow::StringPiece> v;
+    v.reserve(num_values);
+    for (int i = 0; i < num_values; ++i) {
+      v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
+    }
+    desc->node_builder.Attr(attr_name, v);
   }
-  desc->node_builder.Attr(attr_name, v);
 }
 
 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
@@ -1143,12 +1183,28 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
                           const void* proto, size_t proto_len,
                           TF_Status* status) {
   tensorflow::AttrValue attr_value;
-  if (attr_value.ParseFromArray(proto, proto_len)) {
-    desc->node_builder.Attr(attr_name, attr_value);
-    status->status = Status::OK();
-  } else {
+  if (!attr_value.ParseFromArray(proto, proto_len)) {
     status->status = InvalidArgument("Unparseable AttrValue proto");
+    return;
   }
+
+  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
+    if (attr_value.value_case() != tensorflow::AttrValue::kList &&
+        attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
+      status->status =
+          InvalidArgument("Expected \"list\" field for \"",
+                          tensorflow::kColocationAttrName, "\" attribute");
+      return;
+    }
+    desc->colocation_constraints.clear();
+    for (const tensorflow::string& location : attr_value.list().s()) {
+      desc->colocation_constraints.insert(location);
+    }
+  } else {
+    desc->node_builder.Attr(attr_name, attr_value);
+  }
+
+  status->status = Status::OK();
 }
 
 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
@@ -1160,10 +1216,12 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
     status->status = InvalidArgument("Duplicate node name in graph: '",
                                      desc->node_builder.node_name(), "'");
   } else {
-    std::sort(desc->colocation_constraints.begin(),
-              desc->colocation_constraints.end());
-    desc->node_builder.Attr(tensorflow::kColocationAttrName,
-                            desc->colocation_constraints);
+    if (!desc->colocation_constraints.empty()) {
+      desc->node_builder.Attr(
+          tensorflow::kColocationAttrName,
+          std::vector<tensorflow::string>(desc->colocation_constraints.begin(),
+                                          desc->colocation_constraints.end()));
+    }
     status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
 
     if (status->status.ok()) {
@@ -1695,14 +1753,6 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
   opts->opts.prefix = prefix;
 }
 
-namespace {
-
-TensorId ToTensorId(const TF_Output& output) {
-  return TensorId(output.oper->node.name(), output.index);
-}
-
-}  // namespace
-
 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
                                              const char* src_name,
                                              int src_index, TF_Output dst) {
@@ -1786,6 +1836,8 @@ void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
 
 // While loop functions -------------------------------------------------------
 
 namespace {
 
+// Creates a placeholder representing an input to the cond or body graph
+// TODO(skyewm): remove these from final graph
 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
                  TF_Output* input, TF_Status* status) {
   TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
@@ -1797,128 +1849,46 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
   return true;
 }
 
-bool CreateEnter(TF_Graph* g, const char* node_name, const char* frame_name,
-                 const TF_Output& input, TF_Output* enter, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Enter", node_name);
-  TF_AddInput(desc, input);
-  TF_SetAttrString(desc, "frame_name", frame_name, strlen(frame_name));
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *enter = {oper, 0};
-  return true;
-}
-
-bool CreateMerge(TF_Graph* g, const char* name, const TF_Output& input,
-                 const char* backedge_name, int backedge_index,
-                 TF_Output* merge, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Merge", name);
-
-  // The merge nodes accept the while loop's back edges as an input. Use the
-  // underlying NodeBuilder API directly to create an input to the
-  // not-yet-created back edge.
-  std::vector<NodeBuilder::NodeOut> input_list;
-  input_list.push_back(NodeBuilder::NodeOut(&input.oper->node, input.index));
-  // All merge inputs must have same type
-  DataType type = input.oper->node.output_type(input.index);
-  input_list.push_back(
-      NodeBuilder::NodeOut(backedge_name, backedge_index, type));
-
-  desc->node_builder.Input(input_list);
-
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *merge = {oper, 0};
-  return true;
-}
-
-bool CreateSwitch(TF_Graph* g, const char* name, const TF_Output& input,
-                  const TF_Output& predicate, TF_Output* switch_true,
-                  TF_Output* switch_false, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Switch", name);
-  TF_AddInput(desc, input);
-  TF_AddInput(desc, predicate);
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *switch_false = {oper, 0};
-  *switch_true = {oper, 1};
-  return true;
-}
-
-bool CreateNext(TF_Graph* g, const char* name, const TF_Output& input,
-                TF_Output* next, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc =
-      TF_NewOperationLocked(g, "NextIteration", name);
-  TF_AddInput(desc, input);
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *next = {oper, 0};
-  return true;
-}
-
-bool CreateExit(TF_Graph* g, const char* name, const TF_Output& input,
-                TF_Output* exit, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Exit", name);
-  TF_AddInput(desc, input);
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *exit = {oper, 0};
-  return true;
-}
-
-class ScopedImportGraphDefOptions {
- public:
-  ScopedImportGraphDefOptions() { opts_ = TF_NewImportGraphDefOptions(); }
-  ~ScopedImportGraphDefOptions() { TF_DeleteImportGraphDefOptions(opts_); }
-
-  TF_ImportGraphDefOptions* get() const { return opts_; }
-
- private:
-  TF_ImportGraphDefOptions* opts_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(ScopedImportGraphDefOptions);
-};
-
 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.
 // `prefix` will be prepended to copied node names. `return_nodes` are nodes
 // in `src_graph`, and the new corresponding nodes in `dst_graph` will be
 // returned. `return_nodes` should be preallocated to size `nreturn_nodes`.
-bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph,
-               const TF_Output* src_inputs,
-               const std::vector<TF_Output>& dst_inputs, const char* prefix,
-               const TF_Output* nodes_to_return, int nreturn_nodes,
-               TF_Output* return_nodes, TF_Status* s)
-    EXCLUSIVE_LOCKS_REQUIRED(dst_graph->mu) {
+Status CopyGraph(Graph* src_graph, Graph* dst_graph,
+                 tensorflow::ShapeRefiner* dst_refiner,
+                 const TF_Output* src_inputs,
+                 const std::vector<tensorflow::Output>& dst_inputs,
+                 const string& prefix, const TF_Output* nodes_to_return,
+                 int nreturn_nodes,
+                 std::vector<tensorflow::Output>* return_nodes) {
   GraphDef gdef;
-  src_graph->graph.ToGraphDef(&gdef);
+  src_graph->ToGraphDef(&gdef);
 
-  ScopedImportGraphDefOptions opts;
-  TF_ImportGraphDefOptionsSetPrefix(opts.get(), prefix);
+  tensorflow::ImportGraphDefOptions opts;
+  opts.prefix = prefix;
 
   for (int i = 0; i < dst_inputs.size(); ++i) {
-    TensorId src = ToTensorId(src_inputs[i]);
-    TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(),
-                                            src.second, dst_inputs[i]);
+    opts.input_map[ToTensorId(src_inputs[i])] =
+        TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
   }
+
   // We use the pivot node to control constants in `src_graph`
-  TF_Operation* pivot = dst_inputs[0].oper;
-  TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot);
+  Node* pivot = dst_inputs[0].node();
+  opts.control_dependencies.push_back(pivot->name());
 
   for (int i = 0; i < nreturn_nodes; ++i) {
-    TF_ImportGraphDefOptionsAddReturnOutput(
-        opts.get(), nodes_to_return[i].oper->node.name().c_str(),
-        nodes_to_return[i].index);
+    opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
   }
 
-  GraphImportGraphDefLocked(dst_graph, gdef, opts.get(), return_nodes,
-                            nreturn_nodes, s);
-  if (TF_GetCode(s) != TF_OK) return false;
-  return true;
+  // TODO(skyewm): change to OutputTensor
+  std::vector<std::pair<Node*, int>> return_tensors;
+  TF_RETURN_IF_ERROR(
+      ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors));
+
+  for (const auto& pair : return_tensors) {
+    return_nodes->emplace_back(pair.first, pair.second);
+  }
+  return Status::OK();
 }
 
 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
@@ -1967,6 +1937,39 @@ TF_WhileParams EmptyWhileParams() {
                          nullptr, nullptr, nullptr, nullptr};
 }
 
+// Utility function for converting to internal C++ datatypes
+OutputTensor ToOutputTensor(TF_Output output) {
+  return OutputTensor(&output.oper->node, output.index);
+}
+
+// Utility function for converting to internal C++ datatypes
+std::vector<OutputTensor> ToOutputTensors(
+    const std::vector<TF_Output>& outputs) {
+  std::vector<OutputTensor> result(outputs.size());
+  for (int i = 0; i < outputs.size(); ++i) {
+    result[i] = ToOutputTensor(outputs[i]);
+  }
+  return result;
+}
+
+// Utility function for converting to internal C++ datatypes
+std::vector<Node*> ToNodes(const std::vector<TF_Output>& outputs) {
+  std::vector<Node*> result(outputs.size());
+  for (int i = 0; i < outputs.size(); ++i) {
+    result[i] = (&outputs[i].oper->node);
+  }
+  return result;
+}
+
+// Utility function for converting to C++ datatypes
+std::vector<Output> ToOutputs(TF_Output* outputs, int noutputs) {
+  std::vector<Output> result(noutputs);
+  for (int i = 0; i < noutputs; ++i) {
+    result[i] = Output(&outputs[i].oper->node, outputs[i].index);
+  }
+  return result;
+}
+
 }  // namespace
 
 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
@@ -1977,8 +1980,8 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
     return EmptyWhileParams();
   }
 
-  TF_Graph* cond_graph = TF_NewGraph();
-  TF_Graph* body_graph = TF_NewGraph();
+  TF_Graph* cond_graph = new TF_Graph();
+  TF_Graph* body_graph = new TF_Graph();
   cond_graph->parent = g;
   cond_graph->parent_inputs = inputs;
   body_graph->parent = g;
@@ -2026,71 +2029,57 @@ void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
 
   mutex_lock l(parent->mu);
 
-  // Create Enter nodes
-  std::vector<TF_Output> enter_nodes(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateEnter(parent, StrCat(params->name, "/enter", i).c_str(),
-                     params->name, parent_inputs[i], &enter_nodes[i], status)) {
-      return;
-    }
-  }
-
-  // Create Merge nodes
-  std::vector<TF_Output> merge_nodes(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateMerge(parent, StrCat(params->name, "/merge", i).c_str(),
-                     enter_nodes[i], StrCat(params->name, "/next", i).c_str(),
-                     0, &merge_nodes[i], status)) {
-      return;
-    }
-  }
+  // 'cond_fn' copies the cond graph into the parent graph
+  tensorflow::ops::CondGraphBuilderFn cond_fn =
+      [params, parent](const tensorflow::Scope& scope,
+                       const std::vector<Output>& inputs,
+                       tensorflow::Output* output) {
+        DCHECK_EQ(scope.graph(), &parent->graph);
+        std::vector<Output> cond_output;
+        TF_RETURN_IF_ERROR(CopyGraph(
+            &params->cond_graph->graph, &parent->graph, &parent->refiner,
+            params->cond_inputs, inputs, scope.impl()->name(),
+            &params->cond_output, 1, &cond_output));
+        *output = cond_output[0];
+        return Status::OK();
+      };
+
+  // 'body_fn' copies the body graph into the parent graph
+  tensorflow::ops::BodyGraphBuilderFn body_fn =
+      [params, parent, n](const tensorflow::Scope& scope,
+                          const std::vector<Output>& inputs,
+                          std::vector<Output>* outputs) {
+        DCHECK_EQ(scope.graph(), &parent->graph);
+        TF_RETURN_IF_ERROR(CopyGraph(
+            &params->body_graph->graph, &parent->graph, &parent->refiner,
+            params->body_inputs, inputs, scope.impl()->name(),
+            params->body_outputs, n, outputs));
+        return Status::OK();
+      };
 
-  // Copy cond_graph to parent and replace input placeholders with merge node
-  // outputs, and get handle to new cond output
-  tensorflow::string cond_prefix = StrCat(params->name, "/cond");
-  TF_Output cond_output;
-  if (!CopyGraph(params->cond_graph, parent, params->cond_inputs, merge_nodes,
-                 cond_prefix.c_str(), &params->cond_output, 1, &cond_output,
-                 status)) {
-    return;
-  }
+  // Create the while loop using an internal scope
+  tensorflow::Scope scope =
+      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
+          .NewSubScope(params->name);
 
-  // Create Switch nodes
-  std::vector<TF_Output> switch_trues(n);
-  std::vector<TF_Output> switch_falses(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateSwitch(parent, StrCat(params->name, "/switch", i).c_str(),
-                      merge_nodes[i], cond_output, &switch_trues[i],
-                      &switch_falses[i], status)) {
-      return;
-    }
-  }
+  const int max_node_id_before = parent->graph.num_node_ids();
 
-  // Copy body_graph to parent, replace input placeholders with switch node
-  // true outputs, and get handles to new body outputs
-  tensorflow::string body_prefix = StrCat(params->name, "/body");
-  std::vector<TF_Output> body_outputs(n);
-  if (!CopyGraph(params->body_graph, parent, params->body_inputs, switch_trues,
-                 body_prefix.c_str(), params->body_outputs, n,
-                 body_outputs.data(), status)) {
-    return;
-  }
+  tensorflow::OutputList loop_outputs;
+  status->status = tensorflow::ops::BuildWhileLoop(
+      scope, OutputsFromTFOutputs(parent_inputs, n), cond_fn, body_fn,
+      params->name, true, &loop_outputs);
 
-  // Create Next nodes
-  std::vector<TF_Output> next_nodes(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateNext(parent, StrCat(params->name, "/next", i).c_str(),
-                    body_outputs[i], &next_nodes[i], status)) {
-      return;
-    }
+  // Update name_map with newly-created ops
+  for (int i = max_node_id_before; i < parent->graph.num_node_ids(); ++i) {
+    Node* n = parent->graph.FindNodeId(i);
+    if (n == nullptr) continue;
+    parent->name_map[n->name()] = n;
   }
 
-  // Create Exit nodes (which are the outputs of the while loop)
-  for (int i = 0; i < n; ++i) {
-    if (!CreateExit(parent, StrCat(params->name, "/exit", i).c_str(),
-                    switch_falses[i], &outputs[i], status)) {
-      return;
-    }
+  // Populate 'outputs'
+  DCHECK_LE(loop_outputs.size(), n);
+  for (int i = 0; i < loop_outputs.size(); ++i) {
+    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
   }
 }
 
@@ -2106,29 +2095,6 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
 
 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
 
-#ifndef __ANDROID__
-namespace {
-
-void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status,
-                          std::vector<tensorflow::Output>* outputs) {
-  outputs->resize(n);
-  for (int i = 0; i < n; i++) {
-    const TF_Output& tf_output = tf_outputs[i];
-    (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index);
-  }
-}
-
-void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
-                          TF_Output* tf_outputs) {
-  for (int i = 0; i < outputs.size(); i++) {
-    tf_outputs[i].oper = ToOperation(outputs[i].node());
-    tf_outputs[i].index = outputs[i].index();
-  }
-}
-
-}  // namespace
-#endif  // __ANDROID__
-
 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                      TF_Output* dx, TF_Status* status, TF_Output* dy) {
 #ifdef __ANDROID__
@@ -2137,11 +2103,9 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
       "https://github.com/tensorflow/tensorflow/issues if this feature is "
       "important to you");
 #else
-  std::vector<tensorflow::Output> y_arg;
-  std::vector<tensorflow::Output> x_arg;
+  std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
+  std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
   std::vector<tensorflow::Output> dy_arg;
-  OutputsFromTFOutputs(y, ny, status, &y_arg);
-  OutputsFromTFOutputs(x, nx, status, &x_arg);
 
   {
     // We need to hold on to the lock while we have a scope that uses TF_Graph.
@@ -2149,13 +2113,11 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
 
     const int max_node_id_before = g->graph.num_node_ids();
 
-    tensorflow::Scope scope =
-        NewInternalScope(&g->graph, &status->status, &g->refiner)
-            .NewSubScope("gradients");
+    Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner)
+                      .NewSubScope("gradients");
 
     if (dx != nullptr) {
-      std::vector<tensorflow::Output> dx_arg;
-      OutputsFromTFOutputs(dx, ny, status, &dx_arg);
+      std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
       status->status =
           AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
     } else {
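For reference, a hedged sketch of how a client is expected to drive the reworked C API entry points above. `graph` and `input_op` are assumed to already exist, and `BuildLessThanTen`/`BuildAddOne` are hypothetical helpers that populate the cond/body subgraphs; this is not code from the patch.

  // Usage sketch only; error handling elided for brevity.
  TF_Status* s = TF_NewStatus();
  TF_Output input = {input_op, 0};
  TF_WhileParams params = TF_NewWhile(graph, &input, 1, s);

  // The cond graph reads params.cond_inputs (placeholders) and must set
  // params.cond_output to a boolean output.
  params.cond_output =
      BuildLessThanTen(params.cond_graph, params.cond_inputs[0]);

  // The body graph reads params.body_inputs and writes params.body_outputs.
  params.body_outputs[0] =
      BuildAddOne(params.body_graph, params.body_inputs[0]);

  // TF_FinishWhile copies both subgraphs into `graph` via BuildWhileLoop and
  // frees the temporary cond/body graphs.
  TF_Output loop_outputs[1];
  TF_FinishWhile(&params, s, loop_outputs);
  TF_DeleteStatus(s);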
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stacktrace.h" namespace tensorflow { namespace { @@ -78,6 +81,13 @@ class SymbolicGradientBuilder { const std::vector& grad_inputs, std::vector* grad_outputs); + // Creates the gradient subgraph for a while loop (or just stores + // `summed_grads` if not all incoming gradients are available yet). All exit + // nodes (which are the first nodes of a loop encountered in the backwards + // pass) are passed to this function rather than processed + // normally. `summed_grads` is the sum of `exit_node`s gradients. + Status ProcessWhileLoop(Node* exit_node, Output summed_grads); + const Scope& scope_; const ops::GradOpRegistry* registry_; const std::vector& outputs_; @@ -85,8 +95,7 @@ class SymbolicGradientBuilder { const std::vector& grad_inputs_; std::vector* grad_outputs_; - // A vector of output endpoints which represents backpropagated - // gradients + // A vector of output endpoints which represents backpropagated gradients typedef std::vector BackpropedGradients; // backprops_ is a map from a node output to its accumulated @@ -105,14 +114,20 @@ class SymbolicGradientBuilder { // gradients from `grad_inputs_`. std::deque ready_; - // The set of node ids in `outputs_`. Used to identify nodes at which to stop - // backprop. + // The set of node ids in `outputs_`. Used to identify nodes at which to + // stop backprop. std::unordered_set output_nodes_; // The set of node ids in `inputs_`. Used to identify nodes at backprop // frontier. Maps from Output -> index into `grad_outputs_`. std::unordered_map input_nodes_; + // For each while loop in the graph, collects the summed gradients for each of + // the loop's exit nodes. Note that unlike backprops_, this map contains the + // output of SumGradients(), not the input (i.e. each exit node may have + // multiple incoming gradients, but we only store the combined Output here). + std::map> while_backprops_; + TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder); }; @@ -266,6 +281,66 @@ Status SymbolicGradientBuilder::CallGradFunction( return Status::OK(); } +Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, + Output summed_grads) { + // TODO(skyewm): handle NoGradient in while loop + if (summed_grads == NoGradient()) { + return errors::Unimplemented( + "Missing gradient into while loop not yet implemented"); + } + + DCHECK(exit_node->IsExit()); + WhileContext* while_ctx = exit_node->while_ctx(); + DCHECK(while_ctx != nullptr); + + // Record 'summed_grads' as the backprop input associated with 'exit_node' + std::map& backprops = while_backprops_[while_ctx]; + DCHECK(backprops.find(exit_node) == backprops.end()); + backprops[exit_node] = summed_grads; + + // Wait until we have all exit nodes' backprops collected before processing + // the while loop. + if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK(); + + // We've seen all the exit nodes for this loop and have collected all the + // backprops. Create the gradient graph for the while loop. + + Scope while_scope = scope_.NewSubScope(while_ctx->frame_name()); + + // Create forward loop counter, which counts how many times the while loop + // body executes. 
+  Output forward_loop_count;
+  TF_RETURN_IF_ERROR(AddForwardLoopCounter(
+      while_ctx, while_scope.NewSubScope("ForwardLoopCounter"),
+      &forward_loop_count));
+
+  // Create backprop loop counter, which executes 'forward_loop_count' times
+  // in order to drive the gradient computation.
+  Output backprop_counter_cond;
+  TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
+      while_ctx, forward_loop_count,
+      while_scope.NewSubScope("BackPropLoopCounter"), &backprop_counter_cond));
+
+  // Create the gradient while loop.
+  std::vector<Output> dy;
+  for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]);
+  std::vector<Output> dx;
+  TF_RETURN_IF_ERROR(AddWhileGradientLoop(
+      while_ctx, dy, backprop_counter_cond, while_scope, &dx));
+
+  // Backprop along the in edges to the while loop (i.e. the inputs to the
+  // enter nodes)
+  DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size());
+  for (int i = 0; i < dx.size(); ++i) {
+    Node* enter_node = while_ctx->enter_nodes()[i];
+    for (const Edge* e : enter_node->in_edges()) {
+      if (e->IsControlEdge()) continue;
+      TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()}));
+    }
+  }
+  return Status::OK();
+}
+
 Status SymbolicGradientBuilder::AddGradients() {
   // Initialize backprops.
   TF_RETURN_IF_ERROR(Initialize());
@@ -308,6 +382,15 @@ Status SymbolicGradientBuilder::AddGradients() {
       continue;
    }
 
+    // Special case: if we find an exit node, process the associated while loop
+    if (n->IsExit()) {
+      DCHECK_EQ(dy.size(), 1);
+      TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0]));
+      continue;
+    }
+
+    // All loop-specific control flow ops should have been handled above
+    DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString();
     const size_t num_no_grad = no_grad_dy_indices.size();
     if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
       // No grad defined for this op, or all outputs returned 'NoGradient':
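For orientation, the path into ProcessWhileLoop is an ordinary AddSymbolicGradients call over a graph whose Exit nodes carry a WhileContext. A hedged sketch, assuming `scope`, a loop input `init`, and loop outputs `outputs` produced by ops::BuildWhileLoop with create_while_ctx=true (a fuller builder example follows while_loop.h below):

  // Sketch under the assumptions stated above; not code from this patch.
  // outputs[0] is an Exit node whose while_ctx() was set by BuildWhileLoop,
  // so SymbolicGradientBuilder routes it into ProcessWhileLoop instead of a
  // registered gradient function.
  std::vector<Output> grad_outputs;
  Status s = AddSymbolicGradients(scope, {outputs[0]}, {init},
                                  {ops::OnesLike(scope, outputs[0])},
                                  &grad_outputs);
  // grad_outputs[0] is d(outputs[0])/d(init), computed by the three loops
  // built in while_gradients.cc below.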
diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc
new file mode 100644
index 00000000000000..6d439adb173441
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.cc
@@ -0,0 +1,148 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/while_gradients.h"
+
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+namespace {
+
+using ops::CondGraphBuilderFn;
+using ops::BodyGraphBuilderFn;
+using ops::BuildWhileLoop;
+
+Output ToOutput(OutputTensor output_tensor) {
+  return Output(output_tensor.node, output_tensor.index);
+}
+
+std::vector<Output> ToOutputVector(
+    const std::vector<OutputTensor>& output_tensors) {
+  int n = output_tensors.size();
+  std::vector<Output> result(n);
+  for (int i = 0; i < n; ++i) result[i] = ToOutput(output_tensors[i]);
+  return result;
+}
+
+}  // namespace
+
+Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
+                             Output* count) {
+  Output zero = ops::Const(scope, 0, {});
+
+  // Create while loop:
+  //   i = 0
+  //   while forward loop predicate is true:
+  //     ++i
+
+  // Condition function that returns condition output from original while loop
+  CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           Output* output) {
+    *output = ToOutput(while_ctx->cond_output());
+    return Status::OK();
+  };
+
+  // Body function that adds one to input
+  BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           std::vector<Output>* outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    outputs->emplace_back(ops::Add(scope, inputs[0], 1));
+    return scope.status();
+  };
+
+  std::vector<Output> outputs;
+  TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
+                                    while_ctx->frame_name(), false, &outputs));
+  *count = outputs[0];
+  return Status::OK();
+}
+
+Status AddBackPropLoopCounter(WhileContext* while_ctx, Output n,
+                              const Scope& scope,
+                              Output* backprop_execution_pred) {
+  // Create while loop: while n > 0: --n
+
+  // Condition function that returns input > 0
+  CondGraphBuilderFn cond_fn = [](const Scope& scope,
+                                  const std::vector<Output>& inputs,
+                                  Output* output) {
+    DCHECK_EQ(inputs.size(), 1);
+    *output = ops::Greater(scope, inputs[0], 0);
+    return scope.status();
+  };
+
+  // Body function that subtracts one from input
+  BodyGraphBuilderFn body_fn = [](const Scope& scope,
+                                  const std::vector<Output>& inputs,
+                                  std::vector<Output>* outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
+    return scope.status();
+  };
+
+  std::vector<Output> outputs;
+  TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {n}, cond_fn, body_fn,
+                                    while_ctx->frame_name(), false, &outputs,
+                                    backprop_execution_pred));
+  return Status::OK();
+}
+
+Status AddWhileGradientLoop(WhileContext* while_ctx,
+                            const std::vector<Output>& grad_inputs,
+                            Output backprop_execution_pred,
+                            const Scope& parent_scope,
+                            std::vector<Output>* grad_outputs) {
+  DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
+  DCHECK_EQ(while_ctx->body_inputs().size(),
+            while_ctx->body_outputs().size());
+
+  Scope scope = parent_scope.NewSubScope("while");
+
+  // Create while loop: while backprop_execution_pred: while body gradient
+
+  // Condition function that returns 'backprop_execution_pred'
+  CondGraphBuilderFn cond_fn = [backprop_execution_pred](
+      const Scope& scope, const std::vector<Output>& inputs, Output* output) {
+    *output = backprop_execution_pred;
+    return Status::OK();
+  };
+
+  // Body function that builds while body gradient subgraph
+  BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           std::vector<Output>* outputs) {
+    std::vector<Output> body_outputs =
+        ToOutputVector(while_ctx->body_outputs());
+    std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
+    return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
+                                outputs);
+  };
+
+  TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
+                                    while_ctx->frame_name(), false,
+                                    grad_outputs));
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h
new file mode 100644
index 00000000000000..e7ad08cded92b9
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/core/graph/while_context.h"
+
+// Utility functions for constructing while loop gradients
+
+namespace tensorflow {
+
+// Creates a loop that counts the number of iterations performed by the while
+// loop associated with `while_ctx`. The returned output yields the iteration
+// count.
+Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
+                             Output* count);
+
+// Creates a loop that executes `n` times. The returned output is the boolean
+// predicate indicating if the loop is still executing. This is used to drive
+// the gradient computation for the while loop associated with `while_ctx`.
+Status AddBackPropLoopCounter(WhileContext* while_ctx, Output n,
+                              const Scope& scope,
+                              Output* backprop_execution_pred);
+
+// Creates the main backprop loop that computes the gradient of the loop
+// associated with `while_ctx`. `grad_inputs` are the partial derivatives
+// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
+// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
+// The partial derivatives w.r.t. the loop inputs, i.e. the input loop
+// variables, are returned in `grad_outputs`.
+Status AddWhileGradientLoop(WhileContext* while_ctx,
+                            const std::vector<Output>& grad_inputs,
+                            Output backprop_execution_pred, const Scope& scope,
+                            std::vector<Output>* grad_outputs);
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
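Read together, the three helpers implement the usual while-loop gradient schedule. A schematic of the composition (commentary only, mirroring ProcessWhileLoop in gradients.cc above) for a forward loop that executes N iterations:

  // AddForwardLoopCounter:  i = 0; while <original cond>: i += 1   // i == N
  // AddBackPropLoopCounter: while i > 0: i -= 1
  //                         (its predicate is exported as
  //                          backprop_execution_pred, so it fires N times)
  // AddWhileGradientLoop:   while backprop_execution_pred:
  //                           dy = gradient-of-body(dy)            // N steps
  //
  // The backprop loop therefore applies the body's symbolic gradient exactly
  // as many times as the forward loop ran.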
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc
new file mode 100644
index 00000000000000..9e1a7f30285b59
--- /dev/null
+++ b/tensorflow/cc/ops/while_loop.cc
@@ -0,0 +1,173 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/while_loop.h"
+
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+namespace ops {
+
+namespace {
+
+// Utility function for converting to internal C++ datatypes
+OutputTensor ToOutputTensor(Output output) {
+  return OutputTensor(output.node(), output.index());
+}
+
+// Utility function for converting to internal C++ datatypes
+std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) {
+  std::vector<OutputTensor> result(outputs.size());
+  for (int i = 0; i < outputs.size(); ++i) {
+    result[i] = ToOutputTensor(outputs[i]);
+  }
+  return result;
+}
+
+// Utility function for converting to internal C++ datatypes
+std::vector<Node*> ToNodes(const std::vector<Output>& outputs) {
+  std::vector<Node*> result(outputs.size());
+  for (int i = 0; i < outputs.size(); ++i) {
+    result[i] = (outputs[i].node());
+  }
+  return result;
+}
+
+}  // namespace
+
+Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
+                      CondGraphBuilderFn cond, BodyGraphBuilderFn body,
+                      const string& frame_name, bool create_while_ctx,
+                      OutputList* outputs, Output* cond_output) {
+  DCHECK(!inputs.empty());
+  DCHECK(outputs != nullptr);
+  DCHECK(outputs->empty());
+
+  TF_RETURN_IF_ERROR(scope.status());
+  int n = inputs.size();
+
+  std::vector<Output> enter_outputs(n);
+  for (int i = 0; i < n; ++i) {
+    enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name);
+  }
+
+  // The merge nodes accept the while loop's back edges as an input (i.e. the
+  // not-yet-created next iteration nodes). Use the underlying NodeBuilder API
+  // directly to create an input to the not-yet-created back edge.
+
+  // Manually generate what the NextIteration node names will be.
+  TF_RETURN_IF_ERROR(scope.status());
+  std::vector<string> next_names(n);
+  next_names[0] = strings::StrCat(scope.impl()->name(), "/NextIteration");
+  for (int i = 1; i < n; ++i) {
+    next_names[i] = strings::StrCat(scope.impl()->name(), "/NextIteration_", i);
+  }
+
+  // Use NodeBuilder API to build merge nodes
+  TF_RETURN_IF_ERROR(scope.status());
+  std::vector<Output> merge_outputs(n);
+  for (int i = 0; i < n; ++i) {
+    NodeBuilder::NodeOut enter_input(enter_outputs[i].node(),
+                                     enter_outputs[i].index());
+
+    DataType dtype = enter_outputs[i].node()->output_type(0);
+    NodeBuilder::NodeOut next_input(next_names[i], 0, dtype);
+
+    std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input});
+    string unique_name = scope.GetUniqueNameForOp("Merge");
+    NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list);
+    scope.UpdateBuilder(&builder);
+
+    Node* merge_node;
+    TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node));
+    TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node));
+    merge_outputs[i] = Output(merge_node, 0);
+  }
+
+  TF_RETURN_IF_ERROR(scope.status());
+  // The control dependency is for constants in the cond graph
+  Scope cond_scope =
+      scope.NewSubScope("cond").WithControlDependencies(merge_outputs[0]);
+  Output raw_cond_out;
+  TF_RETURN_IF_ERROR(cond(cond_scope, merge_outputs, &raw_cond_out));
+  if (raw_cond_out.type() != DT_BOOL) {
+    return errors::InvalidArgument(
+        "BuildWhileLoop: 'cond' argument must return a boolean output, got ",
+        DataTypeString(raw_cond_out.type()));
+  }
+  Output cond_out = LoopCond(scope, raw_cond_out).output;
+  if (cond_output != nullptr) *cond_output = cond_out;
+
+  std::vector<Output> switch_trues(n);
+  std::vector<Output> switch_falses(n);
+  for (int i = 0; i < n; ++i) {
+    auto swtch = Switch(scope, merge_outputs[i], cond_out);
+    switch_trues[i] = swtch.output_true;
+    switch_falses[i] = swtch.output_false;
+  }
+
+  TF_RETURN_IF_ERROR(scope.status());
+  // The control dependency is for constants in the body graph
+  Scope body_scope =
+      scope.NewSubScope("body").WithControlDependencies(switch_trues[0]);
+  std::vector<Output> body_outputs;
+  TF_RETURN_IF_ERROR(body(body_scope, switch_trues, &body_outputs));
+  if (body_outputs.size() != n) {
+    return errors::InvalidArgument(
+        "BuildWhileLoop: 'body' argument expected to return ", n,
+        " outputs, got ", body_outputs.size());
+  }
+
+  std::vector<Output> next_outputs(n);
+  for (int i = 0; i < n; ++i) {
+    next_outputs[i] = NextIteration(scope, body_outputs[i]);
+    DCHECK_EQ(next_outputs[i].node()->name(), next_names[i]);
+  }
+
+  // Create the backedges from the NextIteration nodes to the Merge nodes
+  for (int i = 0; i < n; ++i) {
+    // TODO(skye): does this export correctly?
+    scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(),
+                           merge_outputs[i].node(), 1);
+  }
+
+  outputs->resize(n);
+  for (int i = 0; i < n; ++i) {
+    (*outputs)[i] = internal::Exit(scope, switch_falses[i]);
+  }
+  TF_RETURN_IF_ERROR(scope.status());
+
+  if (create_while_ctx) {
+    WhileContext* while_ctx;
+    TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
+        frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
+        ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
+        ToOutputTensors(body_outputs), &while_ctx));
+
+    // Set while_ctx for all exit nodes. We currently don't require knowing
+    // the while_ctx for any other nodes.
+    for (int i = 0; i < n; ++i) {
+      (*outputs)[i].node()->set_while_ctx(while_ctx);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace ops
+}  // namespace tensorflow
diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h
new file mode 100644
index 00000000000000..7c05bdd56db2b9
--- /dev/null
+++ b/tensorflow/cc/ops/while_loop.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+#define THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+namespace ops {
+
+// Function that takes cond subgraph inputs and returns the cond subgraph
+// boolean output. 'output' need not be set if an error is returned.
+typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
+                             Output* output)>
+    CondGraphBuilderFn;
+
+// Function that takes body subgraph inputs and returns the body subgraph
+// outputs. 'outputs' need not be populated if an error is returned.
+typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
+                             std::vector<Output>* outputs)>
+    BodyGraphBuilderFn;
+
+// Constructs a while loop.
+//
+// Arguments:
+// * scope: used to construct the while loop
+// * inputs: the initial values of the loop variables. Must be non-empty.
+// * cond: a function that builds the condition subgraph of the loop. Takes
+//   the current loop variables as inputs and returns a boolean output
+//   indicating whether the loop should continue.
+// * body: a function that builds the body subgraph of the loop. Takes the
+//   current loop variables as inputs and returns the updated loop variables.
+// * frame_name: the frame name to use for this while loop
+// * create_while_ctx: if true, registers a WhileContext for the new loop in
+//   the graph (required for gradient construction)
+// * outputs: output param that returns final loop variable outputs in
+//   non-error case. Must be non-null and empty.
+// * cond_output: if non-null, set to the output of the loop's LoopCond node
+//
+// Returns an error if the while loop could not be fully constructed.
+//
+// TODO(skyewm): clean up partially-constructed loop in error case
+// TODO(skyewm): create public interface to this method
+Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
+                      CondGraphBuilderFn cond, BodyGraphBuilderFn body,
+                      const string& frame_name, bool create_while_ctx,
+                      OutputList* outputs, Output* cond_output = nullptr);
+
+}  // namespace ops
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
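A hedged usage sketch of this builder (assumed names, not code from this patch): a two-variable loop "while (i < 10): i += 1; acc *= 2", built without a WhileContext and exporting the LoopCond output through `cond_output`:

  Scope root = Scope::NewRootScope();
  Output i0 = ops::Const(root, 0);
  Output acc0 = ops::Const(root, 1.0f);

  OutputList results;
  Output pred;  // receives the LoopCond output
  Status status = ops::BuildWhileLoop(
      root, {i0, acc0},
      [](const Scope& sc, const std::vector<Output>& in, Output* out) {
        *out = ops::Less(sc, in[0], 10);  // loop while i < 10
        return sc.status();
      },
      [](const Scope& sc, const std::vector<Output>& in,
         std::vector<Output>* out) {
        out->push_back(ops::Add(sc, in[0], 1));          // i += 1
        out->push_back(ops::Multiply(sc, in[1], 2.0f));  // acc *= 2
        return sc.status();
      },
      "example_frame", /*create_while_ctx=*/false, &results, &pred);
  // results[0] and results[1] are the Exit outputs holding the final i and
  // acc values.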
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -109,7 +110,8 @@ Node::Node() cost_id_(-1), class_(NC_UNINITIALIZED), props_(nullptr), - assigned_device_name_index_(0) {} + assigned_device_name_index_(0), + while_ctx_(nullptr) {} void Node::Initialize(int id, int cost_id, std::shared_ptr props) { @@ -378,6 +380,8 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) { DCHECK_EQ(y, kControlSlot) << dest->DebugString(); } + // Update props_ to reflect new input + Edge* e = nullptr; if (free_edges_.empty()) { e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new @@ -552,4 +556,30 @@ int Graph::InternDeviceName(const string& device_name) { return index; } +Status Graph::AddWhileContext(StringPiece frame_name, + std::vector enter_nodes, + std::vector exit_nodes, + OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs, + WhileContext** result) { + auto pair = while_ctxs_.insert(std::pair( + frame_name.ToString(), + WhileContext(frame_name, enter_nodes, exit_nodes, cond_output, + body_inputs, body_outputs))); + if (!pair.second) { + *result = nullptr; + return errors::InvalidArgument("WhileContext with frame name '", frame_name, + "' already exists"); + } + *result = &pair.first->second; + return Status::OK(); +} + +string Graph::DebugString() const { + GraphDef gdef; + ToGraphDef(&gdef); + return gdef.DebugString(); +} + } // namespace tensorflow diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 78a0e8fd79f8e8..2fabb636270e28 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -62,6 +62,7 @@ class Graph; class GraphDef; class Node; class VersionDef; +class WhileContext; class NeighborIter; // Declared below class NodeIter; // Declared below @@ -184,6 +185,13 @@ class Node { Status input_node(int idx, const Node** n) const; Status input_node(int idx, Node** n) const; + WhileContext* while_ctx() const { return while_ctx_; } + void set_while_ctx(WhileContext* while_ctx) { + DCHECK(IsExit()); + DCHECK(while_ctx_ == nullptr); + while_ctx_ = while_ctx; + } + private: friend class Graph; Node(); @@ -256,9 +264,35 @@ class Node { // field and reclaim that memory. Graph* graph_; + // If this is an exit node, set to the WhileContext associated with the while + // loop this node is part of. Otherwise null. (This is only set for exit nodes + // because they're the first nodes of a loop encountered while creating the + // gradient graph.) + WhileContext* while_ctx_; + TF_DISALLOW_COPY_AND_ASSIGN(Node); }; +// Represents an input to a node, i.e., the `index`-th input to `node`. +struct InputTensor { + Node* node; + int index; + + InputTensor(Node* n, int i) : node(n), index(i) {} + InputTensor() : node(nullptr), index(0) {} +}; + +// Represents an output of a node, i.e., the `index`-th output of `node`. Note +// that a single `OutputTensor` can correspond to multiple `Edge`s if the output +// is consumed by multiple destination nodes. 
+struct OutputTensor {
+  Node* node;
+  int index;
+
+  OutputTensor(Node* n, int i) : node(n), index(i) {}
+  OutputTensor() : node(nullptr), index(0) {}
+};
+
 class Edge {
  public:
   Node* src() const { return src_; }
@@ -503,6 +537,17 @@ class Graph {
     node->assigned_device_name_index_ = InternDeviceName(device_name);
   }
 
+  // Create and return a new WhileContext owned by this graph. This is called
+  // when a new while loop is created.
+  Status AddWhileContext(StringPiece frame_name,
+                         std::vector<Node*> enter_nodes,
+                         std::vector<Node*> exit_nodes,
+                         OutputTensor cond_output,
+                         std::vector<OutputTensor> body_inputs,
+                         std::vector<OutputTensor> body_outputs,
+                         WhileContext** result);
+
+  string DebugString() const;
+
   // TODO(josh11b): uint64 hash() const;
 
  private:
@@ -570,6 +615,12 @@ class Graph {
   // Maps unique device names to indices within device_names_[i].
   std::unordered_map<string, int> device_names_map_;
 
+  // All the while contexts owned by this graph, keyed by frame name,
+  // corresponding to all the while loops contained in this graph (including
+  // nested loops). The stored contexts are usually accessed via
+  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
+  std::map<string, WhileContext> while_ctxs_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(Graph);
 };
 
diff --git a/tensorflow/core/graph/while_context.cc b/tensorflow/core/graph/while_context.cc
new file mode 100644
index 00000000000000..57dc32cb0dd1e0
--- /dev/null
+++ b/tensorflow/core/graph/while_context.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/graph/while_context.h"
+
+namespace tensorflow {
+
+WhileContext::WhileContext(StringPiece frame_name,
+                           std::vector<Node*> enter_nodes,
+                           std::vector<Node*> exit_nodes,
+                           OutputTensor cond_output,
+                           std::vector<OutputTensor> body_inputs,
+                           std::vector<OutputTensor> body_outputs)
+    : frame_name_(frame_name.ToString()),
+      enter_nodes_(std::move(enter_nodes)),
+      exit_nodes_(std::move(exit_nodes)),
+      cond_output_(cond_output),
+      body_inputs_(std::move(body_inputs)),
+      body_outputs_(std::move(body_outputs)) {
+  int n = enter_nodes_.size();
+  DCHECK_EQ(exit_nodes_.size(), n);
+  DCHECK_EQ(body_inputs_.size(), n);
+  DCHECK_EQ(body_outputs_.size(), n);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/graph/while_context.h b/tensorflow/core/graph/while_context.h
new file mode 100644
index 00000000000000..ad4c57fe2d7532
--- /dev/null
+++ b/tensorflow/core/graph/while_context.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
+#define TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// State describing a while loop. Every user-defined while loop has an
+// associated WhileContext, i.e., there is a WhileContext for every execution
+// frame. Created with the while loop and used during gradient
+// construction. Note that the gradient graph of a while loop contains while
+// loops itself, but these do not generate separate WhileContexts.
+class WhileContext {
+ public:
+  WhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
+               std::vector<Node*> exit_nodes, OutputTensor cond_output,
+               std::vector<OutputTensor> body_inputs,
+               std::vector<OutputTensor> body_outputs);
+
+  const string& frame_name() const { return frame_name_; }
+  const std::vector<Node*>& enter_nodes() const { return enter_nodes_; }
+  const std::vector<Node*>& exit_nodes() const { return exit_nodes_; }
+  const OutputTensor& cond_output() const { return cond_output_; }
+  const std::vector<OutputTensor>& body_inputs() const { return body_inputs_; }
+  const std::vector<OutputTensor>& body_outputs() const {
+    return body_outputs_;
+  }
+
+ private:
+  const string frame_name_;
+
+  // The enter nodes defining the input loop variables to the while loop. This
+  // vector defines the order of the loop variables.
+  const std::vector<Node*> enter_nodes_;
+
+  // The exit nodes defining the outputs of the while loop. These are in loop
+  // variable order.
+  const std::vector<Node*> exit_nodes_;
+
+  // The boolean output of the loop predicate.
+  const OutputTensor cond_output_;
+
+  // The inputs and outputs to the loop body.
+  const std::vector<OutputTensor> body_inputs_;
+  const std::vector<OutputTensor> body_outputs_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
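Finally, a short sketch (assumed variable names, not code from this patch) of how gradient construction is expected to reach a WhileContext, mirroring ProcessWhileLoop in gradients.cc above:

  // `node` is a Node* visited during backprop.
  if (node->IsExit()) {
    WhileContext* ctx = node->while_ctx();  // set by ops::BuildWhileLoop
    if (ctx != nullptr) {
      // Loop variables are ordered by ctx->enter_nodes(); the i-th enter,
      // exit, body input, and body output all describe the same variable.
      const string& frame = ctx->frame_name();
      const std::vector<Node*>& exits = ctx->exit_nodes();
      const OutputTensor& pred = ctx->cond_output();
      // ... build the counter and gradient loops from frame/exits/pred ...
    }
  }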