408 changes: 185 additions & 223 deletions tensorflow/c/c_api.cc

Large diffs are not rendered by default.

93 changes: 89 additions & 4 deletions tensorflow/cc/framework/gradients.cc
@@ -18,15 +18,18 @@ limitations under the License.

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/while_gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"

namespace tensorflow {
namespace {
@@ -78,15 +81,21 @@ class SymbolicGradientBuilder {
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs);

// Creates the gradient subgraph for a while loop (or just stores
// `summed_grads` if not all incoming gradients are available yet). All exit
// nodes (which are the first nodes of a loop encountered in the backwards
// pass) are passed to this function rather than processed
// normally. `summed_grads` is the sum of `exit_node`'s gradients.
Status ProcessWhileLoop(Node* exit_node, Output summed_grads);

const Scope& scope_;
const ops::GradOpRegistry* registry_;
const std::vector<Output>& outputs_;
const std::vector<Output>& inputs_;
const std::vector<Output>& grad_inputs_;
std::vector<Output>* grad_outputs_;

// A vector of output endpoints which represents backpropagated
// gradients
// A vector of output endpoints which represents backpropagated gradients
typedef std::vector<Output> BackpropedGradients;

// backprops_ is a map from a node output to its accumulated
@@ -105,14 +114,20 @@ class SymbolicGradientBuilder {
// gradients from `grad_inputs_`.
std::deque<Node*> ready_;

// The set of node ids in `outputs_`. Used to identify nodes at which to stop
// backprop.
// The set of node ids in `outputs_`. Used to identify nodes at which to
// stop backprop.
std::unordered_set<int> output_nodes_;

// The set of node ids in `inputs_`. Used to identify nodes at backprop
// frontier. Maps from Output -> index into `grad_outputs_`.
std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_;

// For each while loop in the graph, collects the summed gradients for each of
// the loop's exit nodes. Note that unlike backprops_, this map contains the
// output of SumGradients(), not the input (i.e. each exit node may have
// multiple incoming gradients, but we only store the combined Output here).
std::map<WhileContext*, std::map<Node*, Output>> while_backprops_;

TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};

@@ -266,6 +281,66 @@ Status SymbolicGradientBuilder::CallGradFunction(
return Status::OK();
}

Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node,
Output summed_grads) {
// TODO(skyewm): handle NoGradient in while loop
if (summed_grads == NoGradient()) {
return errors::Unimplemented(
"Missing gradient into while loop not yet implemented");
}

DCHECK(exit_node->IsExit());
WhileContext* while_ctx = exit_node->while_ctx();
DCHECK(while_ctx != nullptr);

// Record 'summed_grads' as the backprop input associated with 'exit_node'
std::map<Node*, Output>& backprops = while_backprops_[while_ctx];
DCHECK(backprops.find(exit_node) == backprops.end());
backprops[exit_node] = summed_grads;

// Wait until we have all exit nodes' backprops collected before processing
// the while loop.
if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK();

// We've seen all the exit nodes for this loop and have collected all the
// backprops. Create the gradient graph for the while loop.

Scope while_scope = scope_.NewSubScope(while_ctx->frame_name());

// Create forward loop counter, which counts how many times the while loop
// body executes.
Output forward_loop_count;
TF_RETURN_IF_ERROR(AddForwardLoopCounter(
while_ctx, while_scope.NewSubScope("ForwardLoopCounter"),
&forward_loop_count));

// Create backprop loop counter, which executes 'forward_loop_count' times in
// order to drive the gradient computation.
Output backprop_counter_cond;
TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
while_ctx, forward_loop_count,
while_scope.NewSubScope("BackPropLoopCounter"), &backprop_counter_cond));

// Create the gradient while loop.
std::vector<Output> dy;
for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]);
std::vector<Output> dx;
TF_RETURN_IF_ERROR(AddWhileGradientLoop(
while_ctx, dy, backprop_counter_cond, while_scope, &dx));

// Backprop along the in edges to the while loop (i.e. the inputs to the enter
// nodes)
DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size());
for (int i = 0; i < dx.size(); ++i) {
Node* enter_node = while_ctx->enter_nodes()[i];
for (const Edge* e : enter_node->in_edges()) {
if (e->IsControlEdge()) continue;
TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()}));
}
}
return Status::OK();
}
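
For orientation (not part of this diff), here is a rough sketch of the client-level path this change enables: build a forward while loop with ops::BuildWhileLoop, then ask AddSymbolicGradients for gradients through it. The wrapper name, the frame name "my_while_loop", and the meaning of BuildWhileLoop's boolean argument (assumed to be create_while_ctx, matching the `false` passed for the counter-only loops in while_gradients.cc below) are assumptions rather than details taken from this PR.

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/core/lib/core/errors.h"

using namespace tensorflow;

// Sketch only: builds `while (v < 10) v = v * 2` and differentiates its output
// w.r.t. the loop input `x`.
Status BuildLoopAndGradients(std::vector<Output>* grad_outputs) {
  Scope scope = Scope::NewRootScope();
  Output x = ops::Placeholder(scope, DT_FLOAT);

  // Forward loop. Passing true (assumed create_while_ctx) records the
  // WhileContext that the gradient builder later uses to add the counter and
  // gradient loops.
  std::vector<Output> outputs;
  TF_RETURN_IF_ERROR(ops::BuildWhileLoop(
      scope, {x},
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10.0f);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        outputs->push_back(ops::Multiply(s, inputs[0], 2.0f));
        return s.status();
      },
      "my_while_loop", true, &outputs));

  // d(outputs[0])/dx, seeding the backprop with ones.
  return AddSymbolicGradients(scope, outputs, {x},
                              {ops::OnesLike(scope, outputs[0])},
                              grad_outputs);
}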

Status SymbolicGradientBuilder::AddGradients() {
// Initialize backprops.
TF_RETURN_IF_ERROR(Initialize());
@@ -276,6 +351,7 @@ Status SymbolicGradientBuilder::AddGradients() {
// n has collected all gradients.
Node* n = ready_.front();
ready_.pop_front();

// dy[i] is the sum of i-th output's backpropped gradients.
const int num_y = n->num_outputs();
@@ -308,6 +384,15 @@ Status SymbolicGradientBuilder::AddGradients() {
continue;
}

// Special case: if we find an exit node, process the associated while loop
if (n->IsExit()) {
DCHECK_EQ(dy.size(), 1);
TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0]));
continue;
}
// All loop-specific control flow ops should have been handled above
DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString();

const size_t num_no_grad = no_grad_dy_indices.size();
if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
// No grad defined for this op, or all outputs returned 'NoGradient':
148 changes: 148 additions & 0 deletions tensorflow/cc/framework/while_gradients.cc
@@ -0,0 +1,148 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/while_gradients.h"

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {
namespace {

using ops::CondGraphBuilderFn;
using ops::BodyGraphBuilderFn;
using ops::BuildWhileLoop;

Output ToOutput(OutputTensor output_tensor) {
return Output(output_tensor.node, output_tensor.index);
}

std::vector<Output> ToOutputVector(
const std::vector<OutputTensor>& output_tensors) {
int n = output_tensors.size();
std::vector<Output> result(n);
for (int i = 0; i < n; ++i) result[i] = ToOutput(output_tensors[i]);
return result;
}

} // namespace

Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
Output* count) {
Output zero = ops::Const(scope, 0, {});

// Create while loop:
// i = 0
// while forward loop predicate is true:
// ++i

// Condition function that returns condition output from original while loop
CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
const std::vector<Output>& inputs,
Output* output) {
*output = ToOutput(while_ctx->cond_output());
return Status::OK();
};

// Body function that adds one to input
BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
const std::vector<Output>& inputs,
std::vector<Output>* outputs) {
DCHECK_EQ(inputs.size(), 1);
outputs->emplace_back(ops::Add(scope, inputs[0], 1));
return scope.status();
};

std::vector<Output> outputs;
TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
while_ctx->frame_name(), false, &outputs));
*count = outputs[0];
return Status::OK();
}

Status AddBackPropLoopCounter(WhileContext* while_ctx, Output n,
const Scope& scope,
Output* backprop_execution_pred) {
// Create while loop: while n > 0: --n

// Condition function that returns input > 0
CondGraphBuilderFn cond_fn = [](const Scope& scope,
const std::vector<Output>& inputs,
Output* output) {
DCHECK_EQ(inputs.size(), 1);
*output = ops::Greater(scope, inputs[0], 0);
return scope.status();
};

// Body function that subtracts one from input
BodyGraphBuilderFn body_fn = [](const Scope& scope,
const std::vector<Output>& inputs,
std::vector<Output>* outputs) {
DCHECK_EQ(inputs.size(), 1);
outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
return scope.status();
};

std::vector<Output> outputs;
TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {n}, cond_fn, body_fn,
while_ctx->frame_name(), false, &outputs,
backprop_execution_pred));
return Status::OK();
}

Status AddWhileGradientLoop(WhileContext* while_ctx,
const std::vector<Output>& grad_inputs,
Output backprop_execution_pred,
const Scope& parent_scope,
std::vector<Output>* grad_outputs) {
DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
DCHECK_EQ(while_ctx->body_inputs().size(),
while_ctx->body_outputs().size());

Scope scope = parent_scope.NewSubScope("while");

// Create while loop: while backprop_execution_pred: while body gradient

// Condition function that returns 'backprop_execution_pred'
CondGraphBuilderFn cond_fn = [backprop_execution_pred](
const Scope& scope,
const std::vector<Output>& inputs,
Output* output) {
*output = backprop_execution_pred;
return Status::OK();
};

// Body function that builds while body gradient subgraph
BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
const std::vector<Output>& inputs,
std::vector<Output>* outputs) {
std::vector<Output> body_outputs =
ToOutputVector(while_ctx->body_outputs());
std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
outputs);
};

TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
while_ctx->frame_name(), false,
grad_outputs));
return Status::OK();
}

} // namespace tensorflow
53 changes: 53 additions & 0 deletions tensorflow/cc/framework/while_gradients.h
@@ -0,0 +1,53 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/graph/while_context.h"

// Utility functions for constructing while loop gradients

namespace tensorflow {

// Creates a loop that counts the number of iterations performed by the while
// loop associated with `while_ctx`. The returned output yields the iteration
// count.
Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
Output* count);

// Creates a loop that executes `n` times. The returned output is the boolean
// predicate indicating if the loop is still executing. This is used to drive
// the gradient computation for the while loop associated with `while_ctx`.
Status AddBackPropLoopCounter(WhileContext* while_ctx, Output n,
const Scope& scope,
Output* backprop_execution_pred);

// Creates the main backprop loop that computes the gradient of the loop
// associated with `while_ctx`. `grad_inputs` are the partial derivatives
// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
// returned in `grad_outputs`.
Status AddWhileGradientLoop(WhileContext* while_ctx,
const std::vector<Output>& grad_inputs,
Output backprop_execution_pred, const Scope& scope,
std::vector<Output>* grad_outputs);

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
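
Taken together, a gradient builder is expected to wire these three helpers up roughly as ProcessWhileLoop() in gradients.cc does. A minimal sketch, with an illustrative wrapper name that is not part of this change:

#include "tensorflow/cc/framework/while_gradients.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Illustrative wrapper only. `dy` holds the summed gradients flowing into the
// loop's exit nodes, in exit-node order; `dx` receives the gradients w.r.t.
// the loop inputs.
Status AddWhileLoopGradients(WhileContext* while_ctx, const Scope& scope,
                             const std::vector<Output>& dy,
                             std::vector<Output>* dx) {
  // 1. Count how many times the forward loop ran.
  Output count;
  TF_RETURN_IF_ERROR(AddForwardLoopCounter(
      while_ctx, scope.NewSubScope("ForwardLoopCounter"), &count));

  // 2. Turn that count into a predicate that stays true for exactly `count`
  //    backward iterations.
  Output pred;
  TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
      while_ctx, count, scope.NewSubScope("BackPropLoopCounter"), &pred));

  // 3. Run the while body's gradient under that predicate.
  return AddWhileGradientLoop(while_ctx, dy, pred, scope, dx);
}

}  // namespace tensorflow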