Skip to content

Commit

Permalink
Allow constants to propagate through while_loop even when explicitly …
Browse files Browse the repository at this point in the history
…in carried state.

PiperOrigin-RevId: 367507719
Change-Id: I900ea39b92c0d07f720d1872a575c05be82e25d5
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Apr 8, 2021
1 parent fb2c2cd commit a86ff79
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 2 deletions.
5 changes: 5 additions & 0 deletions tensorflow/compiler/tf2xla/const_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
Node* ret_i = fbody->ret_nodes[i];
const Node* ret_i_input_0;
TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0));
if (ret_i_input_0->type_string() == "Identity") {
// TODO(b/184727356): Support IdentityN, loop-invariant While.
VLOG(2) << "Propagate through Identity: input " << i;
TF_RETURN_IF_ERROR(ret_i_input_0->input_node(0, &ret_i_input_0));
}
if (ret_i_input_0->id() == arg_i->id()) {
const_input_idxs->push_back(i);
} else {
Expand Down
16 changes: 15 additions & 1 deletion tensorflow/compiler/tf2xla/tf2xla_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,21 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
}
const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
if (output_arg_input->second != input_arg.name()) {
continue;
// Attempt to tolerate _Retval = Identity(_Arg).
const NodeDef* output_arg_identity = nullptr;
for (const NodeDef& node : body_func->node_def()) {
if (node.op() == "Identity" &&
(node.name() == output_arg_input->second ||
node.name() + ":0" == output_arg_input->second ||
node.name() + ":output:0" == output_arg_input->second)) {
output_arg_identity = &node;
}
}
if (!output_arg_identity ||
output_arg_identity->input(0) != input_arg.name()) {
VLOG(1) << "While input/output mismatch; not propagating const " << i;
continue;
}
}

const_input_index_to_node[i] = input_node;
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/client/xla_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
break;
}
if (!*is_constant) {
VLOG(1) << "Non-constant: " << instr.name();
VLOG(1) << "Non-constant: " << instr.name() << " " << instr.opcode();
}
visited->insert(op_handle);
}
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/python/eager/def_function_xla_jit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
Expand Down Expand Up @@ -356,6 +357,23 @@ def g(x):
self.assertAllClose(40.0, f.get_concrete_function(2.0)(2.0))
self.assertAllClose([40.0, 28.0], g.get_concrete_function(2.0)(2.0))

def testWhileLoopWithUnmodifiedCarriedShape(self):
with ops.device('device:{}:0'.format(self.device)):
signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)]

# We define a signature that specifies unknown vector shape, then test
# that tf.shape constness gets properly propagated into the while_loop
# even when carried as part of the loop state.
@def_function.function(input_signature=signature, jit_compile=True)
def g(x):
return control_flow_ops.while_loop_v2(
lambda *_: True,
lambda y, shp: (y + random_ops.random_normal(shp)**2, shp),
(x, array_ops.shape(x)),
maximum_iterations=3)[0]

self.assertAllGreater(g(array_ops.zeros([7])), 0.)

def testMethodCompilation(self):

with ops.device('device:{}:0'.format(self.device)):
Expand Down

0 comments on commit a86ff79

Please sign in to comment.