Error in gradient of reduce_prod #2641
The gradient of `reduce_prod` yields `[ 2., 1.]` for the input `[1., 2.]`, which is correct. But for an input containing a zero, such as `[0., 2.]`, it yields `[ nan, 0.]`, which is incorrect. The correct gradient is `[ 2., 0.]`.

Note that this is not an issue with regular TensorFlow multiplication. Multiplying the entries together directly and taking gradients works correctly: it prints `[ 2., 0.]`.
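A minimal sketch of the reproduction, assuming the inputs `[1., 2.]` and `[0., 2.]` (inferred from the gradients quoted above):

```python
import tensorflow as tf

sess = tf.InteractiveSession()

# Without zeros the gradient of reduce_prod is fine:
x = tf.constant([1., 2.])
p = tf.reduce_prod(x)
print(tf.gradients(p, x)[0].eval())   # [ 2.  1.] -- correct

# With a zero it produces nan instead of the true gradient:
z = tf.constant([0., 2.])
q = tf.reduce_prod(z)
print(tf.gradients(q, z)[0].eval())   # [ nan  0.] -- should be [ 2.  0.]

# Plain multiplication has no such problem:
a, b = tf.constant(0.), tf.constant(2.)
ga, gb = tf.gradients(a * b, [a, b])
print(ga.eval(), gb.eval())           # 2.0 0.0
```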
I have reproduced this.
Thanks @aselle. Idea: this could be caused if the gradients of `prod` are computed by dividing the full product by each entry, which produces nan wherever an entry is zero.
The gradients are computed by computing the full prod and then dividing, which is broken as you point out: for a length-two product `[x0, x1]` the gradient comes out as `[p/x0, p/x1]` with `p = x0*x1`, which is nan wherever an entry is zero.

@ibab, @benoitsteiner: The easiest way to fix this would be to do two scan products to get the sequence of partial products from both directions, then multiply them together to get all products with one element removed. Unfortunately, we don't yet have scan.
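To illustrate the two-scan idea, here's a NumPy sketch (illustrative only, not the eventual TensorFlow code; `prod_grad_1d` is a made-up name):

```python
import numpy as np

def prod_grad_1d(x, upstream=1.0):
    """Gradient of np.prod(x) via two exclusive scan products.

    left[i]  = x[0] * ... * x[i-1]
    right[i] = x[i+1] * ... * x[-1]
    left[i] * right[i] is the product of all entries except x[i],
    so zeros in x cause no division and no nan.
    """
    x = np.asarray(x, dtype=float)
    left = np.concatenate(([1.0], np.cumprod(x[:-1])))
    right = np.concatenate((np.cumprod(x[::-1][:-1])[::-1], [1.0]))
    return upstream * left * right

print(prod_grad_1d([0., 2.]))  # [2. 0.], no nan
```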
Adding the contributions welcome label, but note that this will have to wait until after #2711.
I've had a shot at solving this today using two exclusive cumprod scans, as suggested above. Then the remaining question is how to handle reductions over arbitrary axes rather than the whole tensor.
@ibab It does seem like it has to be transpose + reshape + stuff + reshape + transpose. I don't think a custom op is warranted.
I think

```python
idx = tf.range(0, dims)
nonreduced, _ = tf.listdiff(idx, reduced)
```

If I concatenate both of these, I can use them as the permutation in `tf.transpose`.
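Here's a shape-level sketch of that permutation trick in NumPy (illustrative only; variable names mirror the snippet above): move the reduced axes to the front, flatten to a matrix with one scan axis, then invert the permutation to get back.

```python
import numpy as np

x = np.arange(24.).reshape(2, 3, 4)
reduced = [0, 2]                                         # axes being reduced
other = [a for a in range(x.ndim) if a not in reduced]   # [1]
perm = reduced + other                                   # reduced axes first
permuted = np.transpose(x, perm)                         # shape (2, 4, 3)
other_num = int(np.prod([x.shape[a] for a in other]))    # 3
flat = permuted.reshape(-1, other_num)                   # (8, 3): scan axis 0
# After scanning, reshape back and undo the permutation:
inv_perm = np.argsort(perm)
restored = flat.reshape(permuted.shape).transpose(inv_perm)
assert (restored == x).all()
```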
@ibab Have you considered extending the cumsum operation to take a list of indices instead of a single index over which to sum? That should make the gradient computation for reduce_prod a lot simpler.
@benoitsteiner The awkward thing about that is that cumprod on multiple axes at a time is a very strange operation. It would implicitly flatten and expand, which would mean you have to do the same implicit flattening at the Eigen level.
@benoitsteiner: That's a good idea! I'm open to using either solution (making the gradient more complicated, or making cumsum/prod more complicated).
For reference, here's the gradient implementation I came up with:

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
  """Gradient for Prod."""
  input_shape = array_ops.shape(op.inputs[0])
  # Expand grad to full input shape.  _safe_shape_div is the elementwise
  # shape-division helper defined alongside the other gradients in
  # math_grad.py.
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)
  grad = array_ops.tile(grad, tile_scaling)

  # Move the reduced dimensions to the front and flatten the input into a
  # matrix whose first axis runs over the reduced dimensions.
  # If the reduction index list is empty, it defaults to float32, hence the cast.
  reduced = math_ops.cast(op.inputs[1], dtypes.int32)
  idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
  other, _ = array_ops.listdiff(idx, reduced)
  perm = array_ops.concat(0, [reduced, other])
  reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
  other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate the product leaving out the current entry: the exclusive scans
  # from the left and from the right multiply to the product of all other
  # entries, so no division is needed and zeros are handled correctly.
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  y = array_ops.reshape(left * right, permuted_shape)

  # Undo the permutation and reset statically known shape information.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None
```

This makes the existing tests pass and also works if there are zeros in the input array.
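For completeness, a quick check one can run against a build that includes the fix (a sketch; the expected outputs are the mathematically correct gradients):

```python
import tensorflow as tf

sess = tf.InteractiveSession()

x = tf.constant([0., 2.])
p = tf.reduce_prod(x)
print(tf.gradients(p, x)[0].eval())   # [ 2.  0.] -- no nan

# Zeros on a reduced axis of a matrix also work:
m = tf.constant([[0., 2.], [3., 4.]])
q = tf.reduce_prod(m, 0)              # column products: [0., 8.]
print(tf.gradients(q, m)[0].eval())   # [[3., 4.], [0., 2.]]
```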
This has been fixed by #3351, so this issue can be closed.
The gradient of `reduce_prod` also crashes when the axis is negative:

```python
import tensorflow as tf

vars = tf.Variable([[1., 2.], [3., 4.]])
prod = tf.reduce_prod(vars, -1)  # Negative axis here
tf.InteractiveSession()
tf.global_variables_initializer().run()
print(prod.eval())  # Works fine
print(tf.gradients(prod, vars)[0].eval())  # Crashes
```
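Until that's fixed, one possible workaround (my assumption, not something suggested in the thread) is to normalize the axis to a non-negative value when the rank is statically known, since the gradient's permutation logic indexes dimensions with non-negative positions:

```python
import tensorflow as tf

vars = tf.Variable([[1., 2.], [3., 4.]])
rank = vars.get_shape().ndims        # static rank: 2
axis = -1 % rank                     # normalize -1 to 1
prod = tf.reduce_prod(vars, axis)

tf.InteractiveSession()
tf.global_variables_initializer().run()
print(tf.gradients(prod, vars)[0].eval())  # [[2., 1.], [4., 3.]] -- no crash
```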
@pvanhaes Thanks for the bug report! Can you file it as a separate issue, since it's unrelated to the current thread? It helps us keep GitHub issues organized.