
incorrect gradient of reduce_prod(tf.complex*) #12514

Closed
kjslag opened this issue Aug 23, 2017 · 12 comments
@kjslag
Contributor

kjslag commented Aug 23, 2017

Describe the problem

TensorFlow computes the wrong result for the following gradient:

import tensorflow as tf
x = tf.Variable(1.0)
E = tf.real(tf.reduce_prod(tf.complex( [x,x], [2*x,2*x] )))
sess = tf.Session()
sess.run(tf.variables_initializer([x]))
sess.run(tf.gradients(E,x))

TensorFlow returns 10.0.
The correct result is -6, since:

E = real((x+2i*x)^2) = real((1+2i)^2) * x^2 = real(1+4i-4) * x^2 = -3*x^2
dE/dx = -6*x = -6 for x=1
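The derivation above can be checked numerically without TensorFlow at all. This is a minimal standalone sketch (my own illustration, not part of the report) that evaluates E(x) = Re((x + 2i*x)^2) with plain Python complex numbers and estimates dE/dx at x = 1 by central finite differences:

```python
# Standalone check of dE/dx = -6 at x = 1, using plain Python complex
# arithmetic and a central finite difference (no TensorFlow needed).

def E(x):
    z = complex(x, 2 * x)      # z = x + 2i*x
    return (z * z).real        # Re(z^2) = -3*x^2

h = 1e-6
grad = (E(1.0 + h) - E(1.0 - h)) / (2 * h)
print(round(grad, 4))  # -6.0
```

This confirms that -6, not 10, is the correct gradient value.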

Below are two mathematically equivalent expressions for E, for both of which TensorFlow returns the correct result of -6.0:

E = tf.real( tf.complex(x,2*x) * tf.complex(x,2*x) )
E = tf.real(tf.exp(tf.reduce_sum(tf.log(tf.complex( [x,x], [2*x,2*x] )))))
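The second rewrite relies on the identity prod(z_k) = exp(sum(log(z_k))) for complex factors. A quick plain-Python sanity check (my own illustration, using the standard-library cmath module rather than TensorFlow) that the two formulations agree at x = 1:

```python
# Sanity check that the direct product and the exp(sum(log(...))) rewrite
# give the same complex value for the factors used in the bug report.
import cmath

x = 1.0
zs = [complex(x, 2 * x), complex(x, 2 * x)]  # [x + 2i*x, x + 2i*x]

direct = zs[0] * zs[1]
via_log = cmath.exp(sum(cmath.log(z) for z in zs))

print(direct.real)    # -3.0
print(via_log.real)   # -3.0 (up to floating-point rounding)
```

Since both routes compute the same E, the discrepancy in the gradients points at the Prod gradient itself, not at the forward computation.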

System information

Linux distribution = Arch Linux (up to date)
TensorFlow was installed from the Arch Linux package python-tensorflow
I'm using an x86_64 CPU. I'm not using my GPU.
numpy (1.13.1)
protobuf (3.3.2)
tensorflow (1.3.0)
python (3.6.2)

@tensorflowbutler
Member

It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@woodshop

woodshop commented Apr 24, 2018

@rmlarsen Any timeline on when this might be looked at?
@brianwa84 - you might be interested in this

For others who run into this issue, you can work around it with gradient_override_map. Example (note that this has not been unit tested):

import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops.math_grad import _safe_shape_div

@tf.RegisterGradient("ModifiedProdGrad")
def _ModifiedProdGrad(op, grad):
    """Gradient for Prod."""
    # The gradient can be expressed by dividing the product by each entry of the
    # input tensor, but this approach can't deal with zeros in the input.
    # Here, we avoid this problem by composing the output as a product of two
    # cumprod operations.

    input_shape = array_ops.shape(op.inputs[0])
    # Reshape reduction indices for the case where the parameter is a scalar
    reduction_indices = array_ops.reshape(op.inputs[1], [-1])

    # Expand grad to full input shape
    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
    grad = array_ops.reshape(grad, output_shape_kept_dims)
    grad = array_ops.tile(grad, tile_scaling)

    # Pack all reduced dimensions into a single one, so we can perform the
    # cumprod ops. If the reduction dims list is empty, it defaults to float32,
    # so we need to cast here.  We put all the shape-related ops on CPU to avoid
    # copying back and forth, and since listdiff is CPU only.
    with ops.device("/cpu:0"):
        rank = array_ops.rank(op.inputs[0])
        reduction_indices = (reduction_indices + rank) % rank
        reduced = math_ops.cast(reduction_indices, dtypes.int32)
        idx = math_ops.range(0, rank)
        other, _ = array_ops.setdiff1d(idx, reduced)
        perm = array_ops.concat([reduced, other], 0)
        reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
        other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
    permuted = array_ops.transpose(op.inputs[0], perm)
    permuted_shape = array_ops.shape(permuted)
    reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

    # Calculate product, leaving out the current entry
    left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
    right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
    y = array_ops.reshape(tf.conj(left) * tf.conj(right), permuted_shape)

    # Invert the transpose and reshape operations.
    # Make sure to set the statically known shape information through a reshape.
    out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
    return array_ops.reshape(out, input_shape), None
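The core of the fix is the `tf.conj` call: the gradient is the leave-one-out product of the inputs (built from two exclusive cumprods so that zeros in the input are handled), and for complex inputs that product must be conjugated. A minimal NumPy sketch (my own illustration, not the TensorFlow op) of that trick on the values from the bug report:

```python
# Leave-one-out product via exclusive cumprods, with conjugation.
# For each entry z[k], the product of all *other* entries equals
# (exclusive cumprod from the left) * (exclusive cumprod from the right).
import numpy as np

z = np.array([1 + 2j, 1 + 2j])  # the factors from the bug report at x = 1

# left[k] = prod(z[:k]), right[k] = prod(z[k+1:])  (exclusive cumprods)
left = np.concatenate(([1.0], np.cumprod(z[:-1])))
right = np.concatenate((np.cumprod(z[::-1][:-1])[::-1], [1.0]))

leave_one_out = left * right          # product of all entries except z[k]
grad_factor = np.conj(leave_one_out)  # conjugated, as in the patched gradient

print(leave_one_out)  # [1.+2.j 1.+2.j]
print(grad_factor)    # [1.-2.j 1.-2.j]
```

Without the conjugation, the gradient of a real-valued loss with respect to complex inputs comes out wrong, which is exactly the behavior reported above.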

With TF gradient:

with tf.Graph().as_default() as g:
    x = tf.Variable(1.0)
    E = tf.real(tf.reduce_prod(tf.complex( [x,x], [2*x,2*x] )))
    with tf.Session() as sess:
        sess.run(tf.variables_initializer([x]))
        print(sess.run(tf.gradients(E,x)))

>>> [10.0]

With modified gradient:

with tf.Graph().as_default() as g:
    with g.gradient_override_map({"Prod": "ModifiedProdGrad"}):
        x = tf.Variable(1.0)
        E = tf.real(tf.reduce_prod(tf.complex( [x,x], [2*x,2*x] )))
        with tf.Session() as sess:
            sess.run(tf.variables_initializer([x]))
            print(sess.run(tf.gradients(E,x)))

>>> [-6.0]

@woodshop

Also note that the title of this issue should be "incorrect gradient of reduce_prod"


@brianwa84 brianwa84 changed the title incorrect gradient of real(reduce_prod(complex(...))) incorrect gradient of reduce_prod(tf.complex*) May 16, 2018
@brianwa84 brianwa84 assigned brianwa84 and unassigned rmlarsen May 16, 2018