
Error in gradient of reduce_prod #2641

Closed
RafaelCosman opened this issue Jun 3, 2016 · 15 comments
Labels
stat:contribution welcome (Status - Contributions welcome) · type:bug (Bug)

Comments

@RafaelCosman

import tensorflow as tf

tf.InteractiveSession()  # gives .run() and .eval() below a default session
vars = tf.Variable([1., 2.])
tf.initialize_all_variables().run()
tf.gradients(tf.reduce_prod(vars), vars)[0].eval()

yields [ 2., 1.], which is correct. But

vars = tf.Variable([0., 2.])
tf.initialize_all_variables().run()
tf.gradients(tf.reduce_prod(vars), vars)[0].eval()

yields [ nan, 0.], which is incorrect. The correct gradient is [ 2., 0.]: for p = x1 * x2, dp/dx1 = x2 = 2 and dp/dx2 = x1 = 0.

@RafaelCosman (Author)

Note that this is not an issue with regular TensorFlow multiplication. The following works correctly:

var1 = tf.Variable(0.)
var2 = tf.Variable(2.)
tf.initialize_all_variables().run()
print(tf.gradients(var1 * var2, var1)[0].eval())
print(tf.gradients(var1 * var2, var2)[0].eval())

It prints 2.0 and 0.0, which is correct.

aselle added the bug label Jun 3, 2016
@aselle (Contributor) commented Jun 3, 2016

I have reproduced this.

@RafaelCosman (Author) commented Jun 3, 2016

Thanks @aselle.

Idea: this could be caused if the gradients of reduce_prod(.) are computed like the gradients of tf.exp(tf.reduce_sum(tf.log(.))). That formulation works out to prod(x) / x, which is nan wherever the input contains a zero (log(0) = -inf).

aselle assigned aselle and girving, then unassigned aselle Jun 3, 2016
@girving (Contributor) commented Jun 3, 2016

The gradients are computed by computing the full prod and then dividing, which is broken as you point out. For example, for a length-two product x * y, the gradient w.r.t. x is computed as x * y / x, which is 0 / 0 = nan when x is zero. Indeed, there's an ancient TODO in the code noting that it is broken.
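
A minimal NumPy sketch of that scheme, assuming the gradient w.r.t. each entry is formed as prod(x) / x:

import numpy as np

x = np.array([0., 2.])
with np.errstate(invalid='ignore'):
    grad = np.prod(x) / x    # [0/0, 0/2] -> [nan, 0.]
print(grad)                  # matches the buggy output in the report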

@ibab, @benoitsteiner: The easiest way to fix this would be to do two scan products to get the sequence of partial products from both directions, then multiply them together to get all products with one element removed. Unfortunately, we don't yet have scan.
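
For illustration, a minimal NumPy sketch of the two-scan idea on a 1-D input, with np.cumprod standing in for the missing scan op:

import numpy as np

def prod_grad_1d(x):
    # left[i] = x[0] * ... * x[i-1] (exclusive prefix products)
    left = np.concatenate(([1.0], np.cumprod(x[:-1])))
    # right[i] = x[i+1] * ... * x[n-1] (exclusive suffix products)
    right = np.concatenate((np.cumprod(x[::-1][:-1])[::-1], [1.0]))
    # Product of every entry except x[i], computed without any division.
    return left * right

print(prod_grad_1d(np.array([0., 2.])))  # [ 2.  0.] -- the correct gradient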

girving added the stat:contribution welcome (Status - Contributions welcome) label Jun 7, 2016
@girving (Contributor) commented Jun 7, 2016

Adding contributions welcome, but note that it'll have to wait until after #2711.

@ibab (Contributor) commented Jun 29, 2016

I've had a shot at solving this today using cumprod, and it's working great for full reductions.
But if reduce_prod is performed with reduction_indices, things get more difficult.
In that case, the cumprod needs to be performed over the reduced dimensions, and not over the remaining ones.
Any ideas on how this could be solved?
Maybe it could work if there were a tf.reshape_selected op that groups all selected dimensions into a single dimension and all remaining ones into a second one, like this:

x.shape = (a, b, c, d, e)
reduction_indices = [1, 3]
y = tf.reshape_selected(x, reduction_indices)
y.shape ==> (b * d, a * c * e)

Then the cumprod could be performed over the first dimension.
This is essentially a transpose followed by a reshape.
I've tried to build this using existing ops, but it seems that tf.cond is required to act based on the contents of the reduction_indices tensor.
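
For concreteness, a NumPy sketch of that grouping (tf.reshape_selected is hypothetical; the axis sizes here are arbitrary):

import numpy as np

a, b, c, d, e = 2, 3, 4, 5, 6
x = np.zeros((a, b, c, d, e))
reduction_indices = [1, 3]   # the b and d axes
other = [0, 2, 4]            # the remaining axes

# Move the reduced axes to the front, then collapse each group.
y = x.transpose(reduction_indices + other).reshape(b * d, -1)
print(y.shape)  # (15, 48), i.e. (b * d, a * c * e)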

@girving (Contributor) commented Jun 29, 2016

@ibab It does seem like it has to be transpose + reshape + stuff + reshape + transpose. I don't think a custom tf.reshape_selected makes sense: separate transpose and reshape is cleaner especially since you have to invert it. The tf.cond is indeed necessary in general. This is getting ugly, but I don't know a cleaner way.

@ibab (Contributor) commented Jun 29, 2016

I think tf.listdiff can be used to separate the axis indices into reduced and non-reduced parts:

idx = tf.range(0, dims)
nonreduced, _ = tf.listdiff(idx, reduced)  # listdiff also returns positions

If I concatenate both of these, I can use them as the permutation in tf.transpose.
I can get the sizes of the two dimensions using tf.segment_prod or tf.gather and tf.reduce_prod.
So maybe this can be solved without resorting to tf.cond, but it's certainly not beautiful.
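
Filling that in as a rough sketch, assuming x is the input tensor and reduced the reduction indices (tf.concat takes the axis as its first argument in the current API):

input_shape = tf.shape(x)
idx = tf.range(0, tf.rank(x))
nonreduced, _ = tf.listdiff(idx, reduced)      # axes that are kept
perm = tf.concat(0, [reduced, nonreduced])     # reduced axes first
reduced_num = tf.reduce_prod(tf.gather(input_shape, reduced))
other_num = tf.reduce_prod(tf.gather(input_shape, nonreduced))
reshaped = tf.reshape(tf.transpose(x, perm), tf.pack([reduced_num, other_num]))

This is essentially what the gradient implementation further down ends up doing.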

@benoitsteiner (Contributor)

@ibab Have you considered extending the cumsum operation to take a list of indices instead of a single index over which to sum? That should make the gradient computation for reduce_prod a lot simpler.

@girving (Contributor) commented Jun 29, 2016

@benoitsteiner The awkward thing about that is that cumprod on multiple axes at a time is a very strange operation. It would implicitly flatten and expand, which would mean you have to do the same implicit flattening at the Eigen level.

@ibab (Contributor) commented Jun 29, 2016

@benoitsteiner: That's a good idea!
I just realized that I could make use of tensor.shuffle from Eigen to do the same operations as described above (permute the axes so that the scan axes are in the front, and then reshape).
This would also avoid copying the tensor, as shuffle is implemented through .coeff.
@girving: My first impression was also that it would be awkward, but now I'm thinking that it could make sense: The list of indices identifies the sequence in which cumsum should traverse each sub-tensor. It also plays well with the existing exclusive and reverse options.

I'm open to using either solution (making the gradient more complicated, or making cumsum/prod more complicated).

@ibab (Contributor) commented Jun 30, 2016

For reference, here's the gradient implementation I came up with:

@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
  """Gradient for Prod."""
  input_shape = array_ops.shape(op.inputs[0])

  # Expand grad to full input shape
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)
  grad = array_ops.tile(grad, tile_scaling)

  # op.inputs[1] holds the reduction indices; when the list is empty it
  # defaults to float32, so cast it to int32.
  reduced = math_ops.cast(op.inputs[1], dtypes.int32)
  idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
  other, _ = array_ops.listdiff(idx, reduced)
  # Move the reduced dimensions to the front and pack each group into a
  # single dimension, so the cumprods below can run along one axis.
  perm = array_ops.concat(0, [reduced, other])
  reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
  other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate the product of all entries except the current one:
  # left[i] is the product of entries before i, right[i] of entries after i.
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  y = array_ops.reshape(left * right, permuted_shape)

  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  # Reshape back so the output carries the statically known input shape.
  return array_ops.reshape(out, input_shape), None

This makes the existing tests pass and also works if there are zeros in the input array.
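
With this gradient registered, the repro from the top of the thread should produce the correct result:

vars = tf.Variable([0., 2.])
tf.initialize_all_variables().run()
print(tf.gradients(tf.reduce_prod(vars), vars)[0].eval())  # [ 2.  0.]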

@ibab (Contributor) commented Jul 18, 2016

This has been fixed by #3351, so this issue can be closed.

vrv closed this as completed Jul 18, 2016
aselle added the type:bug (Bug) label and removed the bug label Feb 9, 2017
@pvanhaes

The gradient of reduce_prod does not support a negative axis, unlike reduce_prod itself.
This is caused by gather not supporting negative axes.
The following code illustrates the problem:

import tensorflow as tf

vars = tf.Variable([[1., 2.], [3., 4.]])
prod = tf.reduce_prod(vars, -1) # Negative axis here

tf.InteractiveSession()
tf.global_variables_initializer().run()
print(prod.eval()) # Works fine
print(tf.gradients(prod, vars)[0].eval()) # Crashes
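
One possible workaround in the meantime (a user-side sketch, not an upstream fix) is to fold the negative axis into its non-negative equivalent before building the graph:

axis = -1
rank = 2                                    # vars has shape (2, 2)
prod = tf.reduce_prod(vars, axis % rank)    # -1 % 2 == 1 in Python
print(tf.gradients(prod, vars)[0].eval())   # no longer crashes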

@girving (Contributor) commented Jun 19, 2017

@pvanhaes Thanks for the bug report! Can you file it as a separate issue, since it's unrelated to the current thread? It helps us to keep GitHub issues organized.
