New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
make "smart_cond" api public and reusable #13954
Changes from 16 commits
4dcdd3b
9d061f6
53b970c
eea7913
7f982c1
bfad536
4b1ac6d
42b0d77
0519f21
7f4b68c
90aaa62
f776955
b982390
8935767
c7c1b42
a928c4c
310b2c6
8d792ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -178,67 +178,53 @@ def deconv_output_length(input_length, filter_size, padding, stride): | |
return input_length | ||
|
||
|
||
def smart_cond(pred, fn1, fn2, name=None): | ||
"""Return either `fn1()` or `fn2()` based on the boolean predicate `pred`. | ||
def smart_cond(pred, true_fn=None, false_fn=None, name=None): | ||
"""Return either `true_fn()` if predicate `pred` is true else `false_fn()`. | ||
|
||
If `pred` is a bool or has a constant value, we return either `fn1()` | ||
or `fn2()`, otherwise we use `tf.cond` to dynamically route to both. | ||
If `pred` is a bool or has a constant value, we return either `true_fn()` | ||
or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. | ||
|
||
Arguments: | ||
pred: A scalar determining whether to return the result of `fn1` or `fn2`. | ||
fn1: The callable to be performed if pred is true. | ||
fn2: The callable to be performed if pred is false. | ||
pred: A scalar determining whether to return the result of `true_fn` or | ||
`false_fn`. | ||
true_fn: The callable to be performed if pred is true. | ||
false_fn: The callable to be performed if pred is false. | ||
name: Optional name prefix when using `tf.cond`. | ||
|
||
Returns: | ||
Tensors returned by the call to either `fn1` or `fn2`. | ||
Tensors returned by the call to either `true_fn` or `false_fn`. | ||
|
||
Raises: | ||
TypeError: If `fn1` or `fn2` is not callable. | ||
TypeError: If `true_fn` or `false_fn` is not callable. | ||
""" | ||
if not callable(fn1): | ||
raise TypeError('`fn1` must be callable.') | ||
if not callable(fn2): | ||
raise TypeError('`fn2` must be callable.') | ||
|
||
pred_value = constant_value(pred) | ||
if pred_value is not None: | ||
if pred_value: | ||
return fn1() | ||
else: | ||
return fn2() | ||
else: | ||
return control_flow_ops.cond(pred, true_fn=fn1, false_fn=fn2, name=name) | ||
if isinstance(pred, variables.Variable): | ||
return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name) | ||
return control_flow_ops.smart_cond(pred, true_fn=true_fn, false_fn=false_fn, name=name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not make smart_cond work with Variables too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried but I found “The problem is that variables already import control_flow_ops and control_flow_ops cannot import variables. I already resolved it without loop dependencies.”. This is commented in the former comments due to a test error. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sguada That is to say,
Otherwise, I think current version can at least work well with the demands of #13903. Let me know what you think is right and I will go for implementing that. |
||
|
||
|
||
def constant_value(pred): | ||
"""Return the bool value for `pred`, or None if `pred` had a dynamic value. | ||
|
||
Arguments: | ||
pred: A scalar, either a Python bool or a TensorFlow boolean variable | ||
or tensor, or the Python integer 1 or 0. | ||
Arguments: | ||
pred: A scalar, either a Python bool or a TensorFlow boolean variable | ||
or tensor, or the Python integer 1 or 0. | ||
|
||
Returns: | ||
True or False if `pred` has a constant boolean value, None otherwise. | ||
Returns: | ||
True or False if `pred` has a constant boolean value, None otherwise. | ||
|
||
Raises: | ||
TypeError: If `pred` is not a Variable, Tensor or bool. | ||
""" | ||
Raises: | ||
TypeError: If `pred` is not a Variable, Tensor or bool, or Python | ||
interger 1 or 0. | ||
""" | ||
# Allow integer booleans. | ||
if pred == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be after the Variable check? There's a danger you'll get back a boolean Tensor if pred is a Variable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or alternatively, should we check for type(pred) is int before we do this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the original code during the review of the PR. I can do that as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I already did that. This PR took much more time than I thought before |
||
pred = False | ||
elif pred == 1: | ||
pred = True | ||
|
||
if isinstance(pred, bool): | ||
pred_value = pred | ||
elif isinstance(pred, variables.Variable): | ||
pred_value = None | ||
elif isinstance(pred, ops.Tensor): | ||
pred_value = tensor_util.constant_value(pred) | ||
else: | ||
raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') | ||
return pred_value | ||
if isinstance(pred, variables.Variable): | ||
return None | ||
return control_flow_ops.smart_constant_value(pred) | ||
|
||
|
||
def object_list_uid(object_list): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
@@no_op | ||
@@count_up_to | ||
@@cond | ||
@@smart_cond | ||
@@case | ||
@@while_loop | ||
@@logical_and | ||
|
@@ -2123,6 +2124,61 @@ def f2(): return tf.add(y, 23) | |
# pylint: enable=redefined-outer-name | ||
|
||
|
||
def smart_cond(pred, true_fn=None, false_fn=None, name=None): | ||
"""Return either `true_fn()` if predicate `pred` is true else `false_fn()`. | ||
|
||
If `pred` is a bool or has a constant value, we return either `true_fn()` | ||
or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. | ||
|
||
Arguments: | ||
pred: A scalar determining whether to return the result of `true_fn` or | ||
`false_fn`. | ||
true_fn: The callable to be performed if pred is true. | ||
false_fn: The callable to be performed if pred is false. | ||
name: Optional name prefix when using `tf.cond`. | ||
|
||
Returns: | ||
Tensors returned by the call to either `true_fn` or `false_fn`. | ||
|
||
Raises: | ||
TypeError: If `true_fn` or `false_fn` is not callable. | ||
""" | ||
if not callable(true_fn): | ||
raise TypeError('`true_fn` must be callable.') | ||
if not callable(false_fn): | ||
raise TypeError('`false_fn` must be callable.') | ||
|
||
pred_value = smart_constant_value(pred) | ||
if pred_value is not None: | ||
if pred_value: | ||
return true_fn() | ||
else: | ||
return false_fn() | ||
else: | ||
return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name) | ||
|
||
|
||
def smart_constant_value(pred): | ||
"""Return the bool value for `pred`, or None if `pred` had a dynamic value. | ||
|
||
Arguments: | ||
pred: A scalar, either a Python bool or tensor. | ||
|
||
Returns: | ||
True or False if `pred` has a constant boolean value, None otherwise. | ||
|
||
Raises: | ||
TypeError: If `pred` is not a Tensor or bool. | ||
""" | ||
if isinstance(pred, bool): | ||
pred_value = pred | ||
elif isinstance(pred, ops.Tensor): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. call this smart_constant_value and add it to tf.contrib.framework? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I agree with you that the name is better and will submit a commit later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can continue to sit here, but for the public API expose it in contrib. The core API has much stricter behavior requirements and functions cannot be easily modified once exposed in core. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I get to know what you mean and I agree. I tried to submit a commit for that. |
||
pred_value = tensor_util.constant_value(pred) | ||
else: | ||
raise TypeError('`pred` must be a Tensor or a Python bool.') | ||
return pred_value | ||
|
||
|
||
def _resource_safe_shape(t): | ||
"""Returns the shape of t or the variable it points to.""" | ||
if t.dtype == dtypes.resource: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can't you just do:
smart_cond = control_flow_ops.smart_cond
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did consider that before, but I think current way could make the user/developer access the function doc more easily. I also referred to some other wrapper codes in the project such as https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tpu/python/tpu/tpu.py#L499.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the doc be the same, regardless of whether you copy it over or just use a function alias? The docstring copy is a DRY violation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I got it. Thanks for pointing out. I have submitted a commit to change to function alias assignment.