
Wrong derivatives for complex second order derivatives. #27845

Closed
proteneer opened this issue Apr 15, 2019 · 5 comments
Labels: comp:ops, stat:awaiting tensorflower, type:bug

Comments

@proteneer

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): OSX
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): tensorflow==1.12.0
  • Python version: 3.6.8

You can collect some of this information using our environment capture script. You can also obtain the TensorFlow version with:
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"

Describe the current behavior

Derivatives of non-holomorphic functions are incorrect when compared against both automatic differentiation (via autograd) and finite differences.

Describe the expected behavior

Derivatives of non-holomorphic functions should be correct.

Code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem.

import numpy as onp
import autograd as ag
import autograd.numpy as anp
import tensorflow as tf

inp = anp.array(2.0)

print("input", inp)

def ag_fn(x):
    real = anp.cos(x+2)
    imag = anp.sin(x-1)
    return anp.abs(real+1j*imag)

ag_hess = ag.hessian(ag_fn)

print("ag val:", ag_fn(inp))
print("ag hess:", ag_hess(inp))

def tf_fn(x):
    real = tf.cos(x+2)
    imag = tf.sin(x-1)
    return tf.abs(tf.complex(real, imag))

# tf_inp = tf.convert_to_tensor(inp)
tf_inp = tf.placeholder(shape=tuple(), dtype=onp.float64)

out_op = tf_fn(tf_inp)

tf_grad = tf.gradients(out_op, tf_inp)[0]
tf_hess = tf.hessians(out_op, tf_inp)[0]

sess = tf.Session()
delta = 1e-7

_, d0, tf_ad = sess.run([out_op, tf_grad, tf_hess], feed_dict={tf_inp: inp})
_, d1, _ = sess.run([out_op, tf_grad, tf_hess], feed_dict={tf_inp: inp+delta})

print("tf_numerical derivative:", (d1-d0)/delta)
print("tf_autodiff derivative:", tf_ad)

Output:

input 2.0
ag val: 1.0655155566059393
ag hess: -0.25533014019223726
2019-04-14 22:55:43.481283: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
tf_numerical derivative: -0.25533013481293665
tf_autodiff derivative: -1.0655155566059389
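
The autograd hessian value can also be cross-checked without either framework: a central-difference second derivative of the same scalar function, written directly in NumPy (a quick sanity check, not part of the original report), lands near the -0.2553 value above rather than TensorFlow's -1.0655:

```python
import numpy as np

def fn(x):
    # |cos(x+2) + i*sin(x-1)| written out as a real-valued expression
    return np.sqrt(np.cos(x + 2) ** 2 + np.sin(x - 1) ** 2)

x0, h = 2.0, 1e-5
# Central-difference second derivative: (f(x+h) - 2 f(x) + f(x-h)) / h^2
d2 = (fn(x0 + h) - 2 * fn(x0) + fn(x0 - h)) / h ** 2
print(d2)  # about -0.2553, matching autograd and the finite-difference check
```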

Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

Additional information: google/jax#603

@Squadrick
Member

Ran this on tensorflow==2.0.0-dev20190327 and I get the same incorrect output.

@ymodak
Contributor

ymodak commented Apr 20, 2019

Thanks for the minimal code snippet to reproduce the issue. I was able to reproduce the behavior in TF 1.13 and latest nightly build.

@ymodak ymodak assigned alextp and unassigned ymodak Apr 20, 2019
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 20, 2019
@alextp
Contributor

alextp commented Apr 22, 2019

Thanks for filing the issue!

If I replace tf.abs in your example with a manual implementation (tf.sqrt(real(x)*real(x) + imag(x)*imag(x))), the values are identical, so I think this is a problem with the gradient for the ComplexAbs op.
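
The real-valued formulation alextp describes can be sketched in plain NumPy (function names here are illustrative, not TensorFlow APIs): writing |z| as sqrt(re^2 + im^2) makes the chain-rule derivative straightforward, and it agrees with finite differences.

```python
import numpy as np

def abs_manual(x):
    # |real + i*imag| written without a complex-abs op
    real = np.cos(x + 2)
    imag = np.sin(x - 1)
    return np.sqrt(real * real + imag * imag)

def abs_manual_grad(x):
    # Chain rule on sqrt(re^2 + im^2): (re*re' + im*im') / |z|
    real, imag = np.cos(x + 2), np.sin(x - 1)
    dreal, dimag = -np.sin(x + 2), np.cos(x - 1)
    return (real * dreal + imag * dimag) / abs_manual(x)

x0, h = 2.0, 1e-7
fd = (abs_manual(x0 + h) - abs_manual(x0 - h)) / (2 * h)
print(abs_manual_grad(x0), fd)  # the analytic and finite-difference values agree
```

This is the same decomposition as the tf.sqrt(real*real + imag*imag) replacement above, which is why autodiff through it produces the correct value while ComplexAbs did not.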


@proteneer
Author

Thanks for fixing this, guys!
