Derivative of pow(x, y) can give nan when x=0 for higher order derivatives #35011

Open

mjwatkins2 opened this issue Dec 11, 2019 · 10 comments

Labels: comp:ops (OPs related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.12 (For issues related to Tensorflow 2.12), type:bug (Bug)
@mjwatkins2

System information

  • Have I written custom code: Example provided below
  • OS Platform and Distribution: Both Windows 10 and Google Colab
  • TensorFlow installed from binary
  • TensorFlow version: Both 2.0.0 and 2.1.0-rc0
  • Python version: 3.7.4

Describe the current behavior
The output of pow(x, y) is not correct for higher-order derivatives when x is 0, and the problem is worse when y contains more than one exponent. For case 2 below I get:
[0 0 4]
[[-2 1 10]]
[[8 -nan 20]]
[[-18 -nan 30]]
[[24 -nan 24]]
[[0 -nan 0]]

For case 1 the nan only shows up in the 5th derivative, which is still not great, but not as bad:
[1 0 1]
[[-4 0 4]]
[[12 0 12]]
[[-24 0 24]]
[[24 24 24]]
[[-0 -nan 0]]

Describe the expected behavior
The output of this example for case 1 should be:
[1 0 1]
[[-4 0 4]]
[[12 0 12]]
[[-24 0 24]]
[[24 24 24]]
[[-0 0 0]]

The output of this example for case 2 should be:
[0 0 4]
[[-2 1 10]]
[[8 2 20]]
[[-18 6 30]]
[[24 24 24]]
[[0 0 0]]
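
For reference, the expected derivatives can be cross-checked symbolically; a minimal sketch with sympy, evaluating the 1st through 5th derivatives at the three sample points used above:

import sympy as sp

x = sp.symbols('x')
cases = {'case 1': x**4, 'case 2': x + x**2 + x**3 + x**4}
for name, f in cases.items():
    for n in range(1, 6):
        # nth derivative evaluated at x = -1, 0, 1
        vals = [sp.diff(f, x, n).subs(x, v) for v in (-1, 0, 1)]
        print(name, 'derivative', n, ':', vals)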

Code to reproduce the issue

import tensorflow as tf

x = tf.Variable([[-1.0], [0.0], [1.0]])

with tf.GradientTape(persistent=True) as t:  # persistent tape; gradients are taken inside it so each is itself differentiable
    t.watch(x)
    # case 1: y = x^4
    # y = tf.reduce_sum(tf.pow(x, 4), axis=1) # gives nan for 5th derivative at x=0
    # case 2: y = x + x^2 + x^3 + x^4
    y = tf.reduce_sum(tf.pow(x, [[1, 2, 3, 4]]), axis=1) # gives nan for 2nd to 5th derivative at x=0
    dy_dx = t.gradient(y, x)
    d2y_dx2 = t.gradient(dy_dx, x)
    d3y_dx3 = t.gradient(d2y_dx2, x)
    d4y_dx4 = t.gradient(d3y_dx3, x)
    d5y_dx5 = t.gradient(d4y_dx4, x)
del t

tf.print(y)
tf.print(tf.transpose(dy_dx)) # transpose only to fit on one line when printed
tf.print(tf.transpose(d2y_dx2))
tf.print(tf.transpose(d3y_dx3))
tf.print(tf.transpose(d4y_dx4))
tf.print(tf.transpose(d5y_dx5))
@oanush oanush self-assigned this Dec 12, 2019
@oanush oanush added comp:ops OPs related issues TF 2.1 for tracking issues in 2.1 release type:bug Bug labels Dec 12, 2019
@oanush

oanush commented Dec 12, 2019

Issue replicated for the given code; kindly find the gist of Colab. Thanks!

@oanush oanush assigned ymodak and unassigned oanush Dec 12, 2019
@allenlavoie
Member

Thank you for the report. I'm checking in a change that defines a gradient for 0 ** 0 wrt the base.

The higher-order gradient examples will still have NaNs. That's an interesting discussion we've had in the past about whether to "fix" backprop in cases like this by not propagating NaNs when multiplying by a zero derivative (so x ** 0 could "hide" NaNs from x ** -1). The conclusion so far has been that we should leave backprop dumb.
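
Roughly what happens, as a minimal sketch of the arithmetic rather than the actual gradient kernel: the gradient of pow with respect to the base scales the upstream gradient by y * x ** (y - 1), which is inf or NaN at x = 0 when y < 1, and a zero upstream gradient does not cancel it because 0 * NaN is NaN in IEEE arithmetic:

import tensorflow as tf

x = tf.constant(0.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.pow(x, 0.0)  # forward value is fine: pow(0, 0) == 1
# backprop evaluates roughly 0 * x ** (0 - 1) = 0 * inf at x = 0:
tf.print(tape.gradient(y, x))  # nan on affected versions, even though d(x**0)/dx == 0
# and a zero factor does not mask a NaN under IEEE rules:
tf.print(tf.constant(0.0) * float('nan'))  # nan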

@mjwatkins2
Author

I realize that it's preferable to avoid handling too many special cases, but it seems like some utility is lost here by not producing the correct higher-order gradients. The actual higher-order gradients are defined and continuous, but those computed by TF have a discontinuity at x = 0. It may be unlikely that an exact value of 0 will show up during typical training, but I ran into it, and it sounds like it's happened before. What's the rationale for leaving backprop "dumb"?
Thanks!
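
For what it's worth, for polynomial cases like case 2 a possible workaround (a sketch only, not a fix for tf.pow itself) is to build the polynomial from multiplications, e.g. with tf.math.polyval, so no x ** (y - 1) factor is ever formed and the higher-order gradients stay finite at x = 0:

import tensorflow as tf

x = tf.Variable([[-1.0], [0.0], [1.0]])

with tf.GradientTape(persistent=True) as t:
    # x^4 + x^3 + x^2 + x via Horner's rule; coefficients listed from the highest power down
    y = tf.math.polyval([1.0, 1.0, 1.0, 1.0, 0.0], x)
    dy_dx = t.gradient(y, x)
    d2y_dx2 = t.gradient(dy_dx, x)
del t

tf.print(tf.transpose(dy_dx))    # expected [[-2 1 10]]
tf.print(tf.transpose(d2y_dx2))  # expected [[8 2 20]], finite at x = 0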

@allenlavoie
Member

@rmlarsen do you have more background on why we decided not to hide NaNs in backprop?


@allenlavoie allenlavoie reopened this May 4, 2020
@allenlavoie
Member

Reopening since I'm rolling back the fix (needs investigation).


@allenlavoie allenlavoie reopened this May 4, 2020
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label May 19, 2021
@sushreebarsa
Contributor

I was able to replicate the issue in TF v2.5; please find the gist here. Thanks!

@Saduf2019 Saduf2019 self-assigned this Aug 16, 2021
@ymodak ymodak removed their assignment Aug 19, 2021
@kumariko

I could reproduce the issue with TF 2.6. Please find the gist here. Thanks!

@Saduf2019 Saduf2019 added the TF 2.7 Issues related to TF 2.7.0 label Dec 17, 2021
@Saduf2019 Saduf2019 removed their assignment Jan 10, 2022
@chunduriv chunduriv self-assigned this Jul 26, 2022
@chunduriv chunduriv removed TF 2.1 for tracking issues in 2.1 release TF 2.7 Issues related to TF 2.7.0 labels Jul 26, 2022
@chunduriv
Contributor

chunduriv commented Jul 26, 2022

I was able to replicate the issue in tf-nightly 2.13.0-dev20230417. Please find the gist for reference. Thank you.

@chunduriv chunduriv added the TF 2.9 Issues found in the TF 2.9 release (or RCs) label Jul 26, 2022
@tilakrayal tilakrayal removed the TF 2.9 Issues found in the TF 2.9 release (or RCs) label Mar 6, 2023
@tilakrayal tilakrayal added the TF 2.11 Issues related to TF 2.11 label Mar 6, 2023
@tilakrayal tilakrayal assigned tilakrayal and unassigned chunduriv Apr 18, 2023
@tilakrayal tilakrayal added TF 2.12 For issues related to Tensorflow 2.12 and removed TF 2.11 Issues related to TF 2.11 labels Apr 18, 2023