Add broadcasting support for tf.where
#15982
Conversation
@aselle WDYT?
tensorflow/python/ops/array_ops.py (outdated diff)
@@ -2515,13 +2539,24 @@ def where(condition, x=None, y=None, name=None):
  has the same shape as `x` and `y`, then it chooses which element to copy from
  `x` and `y`.

  If `broadcast` is True, then values of `x`, `y` and `condition` are
It sounds like the new behavior is backwards compatible? Why hide it behind a flag if it doesn't break existing usage?
Thanks @ebrevdo. Initially I thought the behaviors were different with respect to `broadcast`. Thinking about it again, it might be possible to keep the old behavior intact while extending it with broadcasting 👍. Let me take a look and update the PR.
Force-pushed from 758804d to c06f5f3.
@ebrevdo The PR has been updated.
  y = np.ones((7, 11))
  np_val = np.where(f < 0, x, y)
  with self.test_session(use_gpu=True):
    tf_val = array_ops.where(constant_op.constant(f) < 0,
You should be able to use:
self.evaluate(f < 0, x, y)
same below.
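For reference, a minimal sketch of the pattern this presumably points at; self.evaluate is the tf.test.TestCase helper that evaluates a tensor without an explicit session, and the exact call below is a guess rather than the PR's actual test code:

  np_val = np.where(f < 0, x, y)
  # Inside a tf.test.TestCase method; array_ops / constant_op imported as in
  # the surrounding test file.
  tf_val = self.evaluate(
      array_ops.where(constant_op.constant(f) < 0, x, y))
  self.assertAllEqual(np_val, tf_val)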
  BCast bcast(BCast::FromShape(cond->shape()),
              BCast::FromShape(then->shape()));

  if (bcast.IsValid()) {
add a comment about where this kicks in. we currently have scalar broadcasting and primary-dimension vector broadcasting and i'd like to know if this will kick in in one of those existing cases -- because it may affect performance.
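For context, a small illustration of the primary-dimension vector broadcasting case mentioned above, written against TF 1.x-style semantics with made-up values (this is not code from the PR; the scalar case, where a rank-0 condition picks x or y wholesale, is analogous):

  import tensorflow as tf

  x = tf.constant([[1., 2., 3.],
                   [4., 5., 6.]])
  y = tf.zeros_like(x)
  cond = tf.constant([True, False])  # length matches x's first dimension

  # Pre-existing rule: a rank-1 condition selects whole rows,
  # here row 0 from x and row 1 from y.
  out = tf.compat.v1.where(cond, x, y)  # [[1., 2., 3.], [0., 0., 0.]]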
Force-pushed from c06f5f3 to 2cceb0f.
Force-pushed from 2cceb0f to 07782e7.
@ebrevdo The PR has been rebased to resolve the merge conflict. However, after reviewing the code again, I realized that there is still one scenario where the existing behavior and NumPy-style broadcasting disagree: when `x` and `y` have higher rank than a rank-1 `condition`, the existing op selects whole rows, while NumPy broadcasting aligns the condition with the trailing dimension. Because of that, the current PR fails several test cases. Unfortunately, I couldn't think of a way to make the proposed broadcasting changes backward compatible, so I am wondering whether it would be better to add a new op instead (e.g., `SelectV2`).
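As an illustrative sketch of that disagreement (my own values, not the example from the original comment):

  import numpy as np

  cond = np.array([True, False])
  x = np.array([[1, 2],
                [3, 4]])
  y = np.array([[5, 6],
                [7, 8]])

  # NumPy aligns cond with the trailing dimension, i.e. it selects column-wise:
  np.where(cond, x, y)  # -> [[1, 6], [3, 8]]

  # The existing tf.where treats a rank-1 cond as a row selector instead:
  # tf.compat.v1.where(cond, x, y)  -> [[1, 2], [7, 8]]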
Status?
Any updates? Also, as someone mentioned in the issue:

Would this also be implemented?
I think if you'd like the interface to exactly match np.where, it does make sense to create a new op + kernel, v2. You would want to expose it in tf.contrib somewhere (not in core).
"If x and y are vectors of higher rank, then condition must be either a vector with size matching the first dimension of x, or must have the same shape as x." Could we deprecate this behavior from To ease the transition, we could safely add broadcasting support when the total number of dimensions match between all arguments, or add a function to contrib (temporarily) with the appropriate behavior. Deviating from NumPy's broadcasting rules feels like a design mistake to me, and I suspect this will be a repeated source of confusion in the future. |
Force-pushed from 67ddf92 to f97061b.
Sorry for the delay. The PR has been updated: a new op, `SelectV2`, has been added.
Force-pushed from 73aa8da to 06037ac.
I rebased the PR to resolve the merge conflict, though it looks like there are some build failures after that. Will take a look and update the PR shortly to fix the build.
Status guys?
Force-pushed from 5a80bd1 to c62c8ff.
The PR has been rebased and the build error fixed. All tests pass now. Sorry for the long wait.
Looking now.
  // 2-ary broadcast.

  // Combine `then` and `else`.
  BCast elem_bcast(BCast::FromShape(then->shape()),
maybe call this then_else_bcast?
Thanks @rmlarsen, the name has been changed.
@@ -324,10 +444,43 @@ struct BatchSelectFunctor<CPUDevice, T> {
  }
};

template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS> {
Why is this defined in multiple places? Can you just define this once in the header file and template it on device type as well?
Thanks @rmlarsen, the PR has been updated with two definitions (one for CPU and one for SYCL) consolidated into one.
…pecification based on review comment. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Force-pushed from 4d710d9 to 33cd7b8.
@rmlarsen Thanks for the review. The PR has been updated. Please take a look and let me know if there are any issues.
Since this has been going on for so long, I'd favor a cherry pick for this.

On Fri, May 3, 2019, 16:01 Alexandre Passos wrote, commenting on this pull request in tensorflow/python/ops/array_ops.py (#15982 (comment)):

> @@ -3177,6 +3181,46 @@ def where(condition, x=None, y=None, name=None):
>     raise ValueError("x and y must both be non-None or both be None.")
>
> +@tf_export("where", v1=["where_v2"])
> +def where_v2(condition, x=None, y=None, name=None):

Since I wrote this message the tf v2 API has been frozen for the 1.14 release. This means we'll need to export this symbol as where_v2 both in tf1 and in tf2 :-/
We had to roll this back because the tests did not have adequate coverage and missed things such as the broadcasting SelectV2 op having no gradient defined for it.
@alextp Sorry about that. I will take a look, add the gradient, and resubmit the PR later.
I have a suggestion for the gradient. I haven't tested it, but maybe it gets you started:

@ops.RegisterGradient("SelectV2")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  y = op.inputs[2]
  zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)

  # Gradient w.r.t. x: keep grad where the condition is True, zero elsewhere.
  gx = array_ops.where_v2(c, grad, zeros)
  gx_shape = array_ops.shape(gx)
  x_shape = array_ops.shape(x)
  rankdiff_x = array_ops.rank(gx) - array_ops.rank(x)
  # Reduce away broadcasted leading dims.
  gx = math_ops.reduce_sum(gx, axis=math_ops.range(rankdiff_x))
  # Reduce but keep x's 1-valued dims which were broadcast.
  # tf.where returns indices with shape [n, 1]; flatten to a 1-D axis list.
  gx = math_ops.reduce_sum(
      gx, keepdims=True,
      axis=array_ops.reshape(
          array_ops.where(gx_shape[rankdiff_x:] > x_shape), [-1]))

  # Gradient w.r.t. y: keep grad where the condition is False, zero elsewhere.
  gy = array_ops.where_v2(c, zeros, grad)
  gy_shape = array_ops.shape(gy)
  y_shape = array_ops.shape(y)
  rankdiff_y = array_ops.rank(gy) - array_ops.rank(y)
  # Reduce away broadcasted leading dims.
  gy = math_ops.reduce_sum(gy, axis=math_ops.range(rankdiff_y))
  # Reduce but keep y's 1-valued dims which were broadcast.
  gy = math_ops.reduce_sum(
      gy, keepdims=True,
      axis=array_ops.reshape(
          array_ops.where(gy_shape[rankdiff_y:] > y_shape), [-1]))

  return (None, gx, gy)
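Not part of the suggestion above, but a quick sanity check one might run once the gradient is registered, assuming a TF 2.x build where tf.where dispatches to the broadcasting SelectV2:

  import tensorflow as tf

  c = tf.constant([True, False, True])  # shape (3,)
  x = tf.constant([[1., 2., 3.],
                   [4., 5., 6.]])       # shape (2, 3)
  y = tf.constant(0.)                   # scalar, broadcast everywhere

  with tf.GradientTape() as tape:
    tape.watch([x, y])
    loss = tf.reduce_sum(tf.where(c, x, y))  # c and y broadcast against x

  gx, gy = tape.gradient(loss, [x, y])
  print(gx.shape, gy.shape)  # (2, 3) and (): gradients match the input shapes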
Credit to @brianwa84: tensorflow#15982 (comment) Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@yongtang any update on this? This was rolled back.
It's supported in tf.compat.v2.where
@tejaslodaya It was rolled back, then rolled forward (with help from @brianwa84, who provided the gradient 👍 ❤️). It is now available in 2.0.
I tried and it works! For anyone coming to this PR, here's how you do it.

Before (TF 1.x):

  with tf.Session() as sess:
    col = tf.convert_to_tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    print(tf.where(tf.math.greater(col, 10),
                   tf.zeros_like(col),
                   tf.ones_like(col)).eval())

After (TF 2.x):

  import tensorflow as tf

  col = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
  print(tf.where(tf.math.greater(col, 10),
                 tf.zeros([1]),
                 tf.ones([1])))

Notice that in 1.x I had to use zeros_like/ones_like, broadcast to the shape of the column, to make it work. Thanks @yongtang, great work!
Adds where_v2 (which will be where in TF 2.0) with NumPy's broadcasting semantics.
This fixes #9284.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>