
Add broadcasting support for tf.where #15982

Merged

Conversation

yongtang
Member

@yongtang yongtang commented Jan 9, 2018

Adds where_v2 (which will be where in TF 2.0), which has numpy's broadcasting semantics.

Fixes #9284.

Signed-off-by: Yong Tang yong.tang.github@outlook.com
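A minimal numpy-only sketch of the broadcasting semantics `where_v2` is meant to follow (illustrative; this is plain NumPy, not TF code):

```python
import numpy as np

# NumPy broadcasts condition, x, and y to a common shape,
# aligning dimensions from the trailing axis.
cond = np.array([True, False])   # shape (2,)
x = np.zeros((3, 2))             # shape (3, 2)
y = np.ones((3, 2))

out = np.where(cond, x, y)       # cond broadcasts to (3, 2)
# Every row is [0., 1.]: column 0 takes x, column 1 takes y.
print(out)
```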

@drpngx
Contributor

drpngx commented Jan 9, 2018

@aselle WDYT?

@@ -2515,13 +2539,24 @@ def where(condition, x=None, y=None, name=None):
has the same shape as `x` and `y`, then it chooses which element to copy from
`x` and `y`.

If `broadcast` is True, then values of `x`, `y` and `condition` are
Contributor

It sounds like the new behavior is backwards compatible? Why hide it behind a flag if it doesn't break existing usage?

Member Author

Thanks @ebrevdo. Initially I thought the behaviors differed with respect to broadcasting. Thinking about it again, it might be possible to avoid breaking the old behavior while still extending broadcasting 👍 . Let me take a look and update the PR.

@yongtang
Member Author

@ebrevdo The PR has been updated with broadcast attribute removed. Please take a look.

y = np.ones((7, 11))
np_val = np.where(f < 0, x, y)
with self.test_session(use_gpu=True):
tf_val = array_ops.where(constant_op.constant(f) < 0,
Contributor

You should be able to use:

self.evaluate(f < 0, x, y)

same below.

BCast bcast(BCast::FromShape(cond->shape()),
BCast::FromShape(then->shape()));

if (bcast.IsValid()) {
Contributor

add a comment about where this kicks in. we currently have scalar broadcasting and primary-dimension vector broadcasting and i'd like to know if this will kick in in one of those existing cases -- because it may affect performance.

@drpngx drpngx added the kokoro:force-run Tests on submitted change label Feb 11, 2018
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Feb 11, 2018
@yongtang
Member Author

@ebrevdo The PR has been rebased to resolve the merge conflict.

However, after reviewing the code again, I realized that there is still one scenario where the existing tf.where is not compatible with np.where (below is the output before this PR):

$ python
Python 2.7.12 (default, Dec  4 2017, 14:50:18)
[GCC 5.4.0 20160609] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> import numpy as np
>>>
>>> x = np.arange(4)
>>> y = np.zeros((4, 4))
>>> z = np.ones((4, 4))
>>>
>>> np.where(x > 1, y, z)
array([[1., 1., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 0., 0.]])
>>> v = tf.where(x > 1, y, z)
>>> tf.Session().run(v)
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])
>>>

As shown in the example above, when x has shape (4,) and y/z have shape (4, 4), the broadcast orientation differs.

Because of that, the current PR will fail several test cases.

Unfortunately, I couldn't think of a way to make the proposed broadcasting changes in tf.where backward-compatible.

I am wondering maybe it would be better to name a new op (e.g., tf.where_v2) so that it is compatible with numpy and not breaking existing users?
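To make the orientation difference concrete, here is a numpy-only sketch reproducing both behaviors without a TF 1.x session (the `[:, None]` expansion is an illustrative stand-in for what the old kernel effectively does with a vector condition):

```python
import numpy as np

cond = np.arange(4) > 1          # [False, False, True, True]
y = np.zeros((4, 4))
z = np.ones((4, 4))

# NumPy semantics: shapes align from the trailing axis,
# so the vector condition selects per column.
numpy_style = np.where(cond, y, z)            # each row is [1, 1, 0, 0]

# Old tf.where semantics for a vector condition: select whole rows,
# equivalent to expanding the condition along the trailing axes.
old_tf_style = np.where(cond[:, None], y, z)  # rows 0-1 all ones, rows 2-3 all zeros
```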

@carlthome
Contributor

Status?

@velicue

velicue commented Apr 19, 2018

Any updates? Also as in the issue someone mentioned:

There is also not full support for broadcasting in the other way, when condition is smaller than x or y; as indicated in the docs it only works if condition is a vector, but it could be extended to work with condition matching an arbitrary number of first dimensions of x and y (I'm not sure if this is considered also "broadcasting") and/or broadcasting singleton dimensions.

Would this also be implemented?

@ebrevdo
Contributor

ebrevdo commented Apr 19, 2018

I think if you'd like the interface to exactly match np.where, it does make sense to create a new op + kernel, v2. you would want to expose it in tf.contrib somewhere (not in core).

@shoyer
Contributor

shoyer commented Apr 20, 2018

"If x and y are vectors of higher rank, then condition must be either a vector with size matching the first dimension of x, or must have the same shape as x."

Could we deprecate this behavior from tf.where (start issuing FutureWarning) and change it in some future API breaking release?

To ease the transition, we could safely add broadcasting support when the total number of dimensions match between all arguments, or add a function to contrib (temporarily) with the appropriate behavior.

Deviating from NumPy's broadcasting rules feels like a design mistake to me, and I suspect this will be a repeated source of confusion in the future.

@yongtang yongtang force-pushed the 9284-tf.where-broadcasting branch 2 times, most recently from 67ddf92 to f97061b Compare May 29, 2018 00:03
@yongtang
Member Author

Sorry for the delay. The PR has been updated. Now a new op tf.contrib.framework.where is exposed so that the broadcast rule follows numpy conventions. The original tf.where remains intact. Please take a look.

@yongtang yongtang force-pushed the 9284-tf.where-broadcasting branch from 73aa8da to 06037ac Compare July 2, 2018 18:33
@yongtang
Member Author

yongtang commented Jul 3, 2018

I rebased the PR to resolve the merge conflict though it looks like there are some build failures after that. Will take a look and update the PR shortly to fix the build.

@charan223

Status guys?

@yongtang yongtang force-pushed the 9284-tf.where-broadcasting branch 2 times, most recently from 5a80bd1 to c62c8ff Compare July 12, 2018 17:32
@yongtang
Member Author

The PR has been rebased with the build error fixed. All tests pass now. Sorry for the long wait.

@rmlarsen
Member

rmlarsen commented May 1, 2019

Looking now.

// 2-ary broadcast.

// Combine `then` and `else`.
BCast elem_bcast(BCast::FromShape(then->shape()),
Member

maybe call this then_else_bcast?

Member Author

Thanks @rmlarsen, the name has been changed.

@@ -324,10 +444,43 @@ struct BatchSelectFunctor<CPUDevice, T> {
}
};

template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS> {
Member

@rmlarsen rmlarsen May 1, 2019

Why is this defined in multiple places? Can you just define this once in the header file and template it on device type as well?

Member Author

Thanks @rmlarsen, the PR has been updated with two definitions (one for CPU and one for SYCL) consolidated into one.

…pecification

based on review comment.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@rthadur rthadur requested a review from rmlarsen May 2, 2019 22:00
@yongtang yongtang added the kokoro:force-run Tests on submitted change label May 3, 2019
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label May 3, 2019
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@yongtang yongtang force-pushed the 9284-tf.where-broadcasting branch from 4d710d9 to 33cd7b8 Compare May 3, 2019 02:13
@yongtang yongtang added the kokoro:force-run Tests on submitted change label May 3, 2019
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label May 3, 2019
@yongtang
Member Author

yongtang commented May 3, 2019

@rmlarsen Thanks for the review. The PR has been updated. Please take a look and let me know if there are any issues.

@tensorflow-bot tensorflow-bot bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels May 3, 2019
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label May 3, 2019
@martinwicke
Member

martinwicke commented May 3, 2019 via email

@tensorflow-copybara tensorflow-copybara merged commit 33cd7b8 into tensorflow:master May 7, 2019
PR Queue automation moved this from Reviewer Requested Changes to Merged May 7, 2019
@yongtang yongtang deleted the 9284-tf.where-broadcasting branch May 7, 2019 05:02
@alextp
Contributor

alextp commented May 7, 2019

We had to roll this back because the tests lacked adequate coverage and missed issues such as the broadcasting SelectV2 op having no gradient defined for it.

tensorflow-copybara pushed a commit that referenced this pull request May 7, 2019
@yongtang
Member Author

yongtang commented May 7, 2019

@alextp sorry about that. I will take a look to add grad and resubmit the PR later.

@brianwa84
Contributor

I have a suggestion for the gradient; I haven't tested it, but maybe it gets you started.

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


@ops.RegisterGradient("SelectV2")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  y = op.inputs[2]
  zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)

  gx = array_ops.where_v2(c, grad, zeros)
  gx_shape = array_ops.shape(gx)
  x_shape = array_ops.shape(x)
  rankdiff_x = array_ops.rank(gx) - array_ops.rank(x)
  # Reduce away broadcasted leading dims.
  gx = math_ops.reduce_sum(gx, axis=math_ops.range(rankdiff_x))
  # Reduce but keep x's 1-valued dims which were broadcast.
  gx = math_ops.reduce_sum(
      gx, keepdims=True, axis=array_ops.where(gx_shape[rankdiff_x:] > x_shape))

  gy = array_ops.where_v2(c, zeros, grad)
  gy_shape = array_ops.shape(gy)
  y_shape = array_ops.shape(y)
  rankdiff_y = array_ops.rank(gy) - array_ops.rank(y)
  # Reduce away broadcasted leading dims.
  gy = math_ops.reduce_sum(gy, axis=math_ops.range(rankdiff_y))
  # Reduce but keep y's 1-valued dims which were broadcast.
  gy = math_ops.reduce_sum(
      gy, keepdims=True, axis=array_ops.where(gy_shape[rankdiff_y:] > y_shape))

  return (None, gx, gy)
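The two reductions in the sketch above follow the usual gradient-of-broadcasting pattern: sum away leading dimensions the input never had, then sum (keeping dims) over axes that were broadcast from size 1. A numpy-only illustration, with `reduce_to_shape` as a hypothetical helper name:

```python
import numpy as np

def reduce_to_shape(grad, target_shape):
    """Sum `grad` down to `target_shape`, undoing NumPy-style broadcasting.

    Hypothetical helper mirroring the two reduce_sum calls in the
    gradient sketch above.
    """
    # 1. Sum away extra leading dimensions.
    rankdiff = grad.ndim - len(target_shape)
    grad = grad.sum(axis=tuple(range(rankdiff)))
    # 2. Sum (keeping dims) over axes that were broadcast from size 1.
    axes = tuple(i for i, (g, t) in enumerate(zip(grad.shape, target_shape))
                 if t == 1 and g != 1)
    return grad.sum(axis=axes, keepdims=True)

g = np.ones((2, 3, 4))                  # upstream gradient shape
print(reduce_to_shape(g, (3, 1)).shape)  # (3, 1)
```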

yongtang added a commit to yongtang/tensorflow that referenced this pull request May 10, 2019
Credit to @brianwa84:

tensorflow#15982 (comment)

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@tejaslodaya

@yongtang any update on this? This was rolled back

@brianwa84
Contributor

brianwa84 commented Dec 11, 2019 via email

@yongtang
Member Author

@tejaslodaya It was rolled back, then rolled forward (with the help from @brianwa84 for providing the gradient op 👍 ❤️ ). It is now available in 2.0.

@tejaslodaya

I tried and it works! For anyone coming to this PR, here's how you do it

Before (TF 1.x)-

with tf.Session() as sess:  
    col = tf.convert_to_tensor([1,2,3,4,5,6,7,8,9,10,11,12])    
    print(tf.where(tf.math.greater(col, 10),
                  tf.zeros_like(col),
                  tf.ones_like(col)).eval())

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]

After (TF 2.x)-

import tensorflow as tf
col = [1,2,3,4,5,6,7,8,9,10,11,12]
print(tf.where(tf.math.greater(col, 10),
               tf.zeros([1]),
               tf.ones([1])))

tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.], shape=(12,), dtype=float32)

Notice that in 1.x I had to use zeros_like/ones_like to broadcast to the shape of the column to make it work.

Thanks @yongtang , great work!

Labels
cla: yes · contrib (Anything that comes under contrib directory) · ready to pull (PR ready for merge process) · size:L (CL Change Size: Large)

Projects
PR Queue: Merged

Development
Successfully merging this pull request may close these issues: Broadcasting support in tf.where