Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make "smart_cond" api public and reusable #13954

Merged
merged 18 commits into from Feb 19, 2018
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions tensorflow/contrib/crf/python/ops/crf.py
Expand Up @@ -105,8 +105,8 @@ def _multi_seq_fn():
return utils.smart_cond(
pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)
true_fn=_single_seq_fn,
false_fn=_multi_seq_fn)


def crf_log_norm(inputs, sequence_lengths, transition_params):
Expand Down Expand Up @@ -513,5 +513,5 @@ def _multi_seq_fn():
return utils.smart_cond(
pred=math_ops.equal(
potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
fn1=_single_seq_fn,
fn2=_multi_seq_fn)
true_fn=_single_seq_fn,
false_fn=_multi_seq_fn)
2 changes: 2 additions & 0 deletions tensorflow/contrib/framework/__init__.py
Expand Up @@ -104,6 +104,8 @@

from tensorflow.python.framework.ops import prepend_name_scope
from tensorflow.python.framework.ops import strip_name_scope
from tensorflow.python.ops.control_flow_ops import smart_cond
from tensorflow.python.ops.control_flow_ops import smart_constant_value

from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
from tensorflow.python.framework.tensor_spec import TensorSpec
Expand Down
64 changes: 25 additions & 39 deletions tensorflow/python/layers/utils.py
Expand Up @@ -178,67 +178,53 @@ def deconv_output_length(input_length, filter_size, padding, stride):
return input_length


def smart_cond(pred, fn1, fn2, name=None):
"""Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
"""Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

If `pred` is a bool or has a constant value, we return either `fn1()`
or `fn2()`, otherwise we use `tf.cond` to dynamically route to both.
If `pred` is a bool or has a constant value, we return either `true_fn()`
or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

Arguments:
pred: A scalar determining whether to return the result of `fn1` or `fn2`.
fn1: The callable to be performed if pred is true.
fn2: The callable to be performed if pred is false.
pred: A scalar determining whether to return the result of `true_fn` or
`false_fn`.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
name: Optional name prefix when using `tf.cond`.

Returns:
Tensors returned by the call to either `fn1` or `fn2`.
Tensors returned by the call to either `true_fn` or `false_fn`.

Raises:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't you just do:

smart_cond = control_flow_ops.smart_cond

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did consider that before, but I think current way could make the user/developer access the function doc more easily. I also referred to some other wrapper codes in the project such as https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tpu/python/tpu/tpu.py#L499.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the doc be the same, regardless of whether you copy it over or just use a function alias? The docstring copy is a DRY violation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it. Thanks for pointing out. I have submitted a commit to change to function alias assignment.

TypeError: If `fn1` or `fn2` is not callable.
TypeError: If `true_fn` or `false_fn` is not callable.
"""
if not callable(fn1):
raise TypeError('`fn1` must be callable.')
if not callable(fn2):
raise TypeError('`fn2` must be callable.')

pred_value = constant_value(pred)
if pred_value is not None:
if pred_value:
return fn1()
else:
return fn2()
else:
return control_flow_ops.cond(pred, true_fn=fn1, false_fn=fn2, name=name)
if isinstance(pred, variables.Variable):
return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
return control_flow_ops.smart_cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make smart_cond work with Variables too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried but I found “The problem is that variables already import control_flow_ops and control_flow_ops cannot import variables. I already resolved it without loop dependencies.”. This is commented in the former comments due to a test error.
From the view the project structure, and also based on the original demands of boolean condition, I think it's ok to put it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sguada That is to say, tensorflow/python/ops/variables.py is using cond function of tensorflow/python/ops/control_flow_ops.py. If we do need consider using Variables class of variables.py in control_flow_ops.py. We may consider two solutions without circular dependencies:

  1. Move smart_cond to some other places that's not depended by tensorflow/python/ops/variables.py and do that inclusion.
  2. Just import that specific class Variables from /tensorflow/python/ops.variables.py

Otherwise, I think current version can at least work well with the demands of #13903. Let me know what you think is right and I will go for implementing that.



def constant_value(pred):
"""Return the bool value for `pred`, or None if `pred` had a dynamic value.

Arguments:
pred: A scalar, either a Python bool or a TensorFlow boolean variable
or tensor, or the Python integer 1 or 0.
Arguments:
pred: A scalar, either a Python bool or a TensorFlow boolean variable
or tensor, or the Python integer 1 or 0.

Returns:
True or False if `pred` has a constant boolean value, None otherwise.
Returns:
True or False if `pred` has a constant boolean value, None otherwise.

Raises:
TypeError: If `pred` is not a Variable, Tensor or bool.
"""
Raises:
TypeError: If `pred` is not a Variable, Tensor or bool, or Python
interger 1 or 0.
"""
# Allow integer booleans.
if pred == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be after the Variable check? There's a danger you'll get back a boolean Tensor if pred is a Variable?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or alternatively, should we check for type(pred) is int before we do this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the original code during the review of the PR. I can do that as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I already did that. This PR took much more time than I thought before

pred = False
elif pred == 1:
pred = True

if isinstance(pred, bool):
pred_value = pred
elif isinstance(pred, variables.Variable):
pred_value = None
elif isinstance(pred, ops.Tensor):
pred_value = tensor_util.constant_value(pred)
else:
raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.')
return pred_value
if isinstance(pred, variables.Variable):
return None
return control_flow_ops.smart_constant_value(pred)


def object_list_uid(object_list):
Expand Down
56 changes: 56 additions & 0 deletions tensorflow/python/ops/control_flow_ops.py
Expand Up @@ -23,6 +23,7 @@
@@no_op
@@count_up_to
@@cond
@@smart_cond
@@case
@@while_loop
@@logical_and
Expand Down Expand Up @@ -2123,6 +2124,61 @@ def f2(): return tf.add(y, 23)
# pylint: enable=redefined-outer-name


def smart_cond(pred, true_fn=None, false_fn=None, name=None):
"""Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

If `pred` is a bool or has a constant value, we return either `true_fn()`
or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

Arguments:
pred: A scalar determining whether to return the result of `true_fn` or
`false_fn`.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
name: Optional name prefix when using `tf.cond`.

Returns:
Tensors returned by the call to either `true_fn` or `false_fn`.

Raises:
TypeError: If `true_fn` or `false_fn` is not callable.
"""
if not callable(true_fn):
raise TypeError('`true_fn` must be callable.')
if not callable(false_fn):
raise TypeError('`false_fn` must be callable.')

pred_value = smart_constant_value(pred)
if pred_value is not None:
if pred_value:
return true_fn()
else:
return false_fn()
else:
return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)


def smart_constant_value(pred):
"""Return the bool value for `pred`, or None if `pred` had a dynamic value.

Arguments:
pred: A scalar, either a Python bool or tensor.

Returns:
True or False if `pred` has a constant boolean value, None otherwise.

Raises:
TypeError: If `pred` is not a Tensor or bool.
"""
if isinstance(pred, bool):
pred_value = pred
elif isinstance(pred, ops.Tensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call this smart_constant_value and add it to tf.contrib.framework?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree with you that the name is better and will submit a commit later.
I also considered whether to put this function in core or not and finally I found that the function was already used as internal interface in the core module before, I mean it's in tensorflow/python/layers/utils.py and used by some other modules like python/layers/normalization.py. What I did here should just make it public and reusable, so I didn't move it out of the core framework.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can continue to sit here, but for the public API expose it in contrib. The core API has much stricter behavior requirements and functions cannot be easily modified once exposed in core.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get to know what you mean and I agree. I tried to submit a commit for that.

pred_value = tensor_util.constant_value(pred)
else:
raise TypeError('`pred` must be a Tensor or a Python bool.')
return pred_value


def _resource_safe_shape(t):
"""Returns the shape of t or the variable it points to."""
if t.dtype == dtypes.resource:
Expand Down
38 changes: 38 additions & 0 deletions tensorflow/python/ops/control_flow_ops_test.py
Expand Up @@ -349,6 +349,44 @@ def testGradientThroughSingleBranchOutsideOfContext(self):
self.assertEquals(grad_x_false.eval(), 0.)


@test_util.with_c_api
class SmartCondTest(test_util.TensorFlowTestCase):

def testSmartCondTrue(self):
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(2)
y = constant_op.constant(5)
z = control_flow_ops.smart_cond(
True, lambda: math_ops.multiply(x, 16),
lambda: math_ops.multiply(y, 5))
self.assertEqual(z.eval(), 32)

def testSmartCondFalse(self):
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(4)
y = constant_op.constant(3)
z = control_flow_ops.smart_cond(
False, lambda: math_ops.multiply(x, 16),
lambda: math_ops.multiply(y, 3))
self.assertEqual(z.eval(), 9)

def testSmartCondMissingArg1(self):
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(1)
with self.assertRaises(TypeError):
control_flow_ops.smart_cond(True, false_fn=lambda: x)

def testSmartCondMissingArg2(self):
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(1)
with self.assertRaises(TypeError):
control_flow_ops.smart_cond(True, lambda: x)


@test_util.with_c_api
class CondTest(test_util.TensorFlowTestCase):

Expand Down