
[Bug] Clip by norm NaN gradients #22048

Closed
octavian-ganea opened this issue Sep 4, 2018 · 15 comments
Assignees: alextp
Labels: stat:awaiting response (Status - Awaiting response from author)

Comments

@octavian-ganea

octavian-ganea commented Sep 4, 2018

import tensorflow as tf

a = tf.zeros([3], dtype=tf.float32)
b = tf.clip_by_norm(a, 1.)
c = tf.gradients(b, a)
s = tf.Session()
s.run(c)
# => [array([nan, nan, nan], dtype=float32)]

The gradient should obviously be [1, 1, 1] for any vector a with norm smaller than 1, since the function is the identity on those vectors.
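
(For contrast, a quick sanity check: a nonzero vector whose norm is below the threshold does get the expected all-ones gradient; it is the exactly-zero input that blows up.)

import tensorflow as tf

a = tf.constant([0.1, 0.2, 0.2])  # norm = 0.3 < 1, so clipping is the identity here
b = tf.clip_by_norm(a, 1.)
with tf.Session() as s:
    print(s.run(tf.gradients(b, a)))  # [array([1., 1., 1.], dtype=float32)]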

Have I written custom code:
OS Platform and Distribution: Ubuntu 14.10
TensorFlow installed from: pip3
TensorFlow version: 1.10.1
Bazel version:
CUDA/cuDNN version:
GPU model and memory:
Exact command to reproduce: see above
Mobile device:

@octavian-ganea octavian-ganea changed the title Bug: Clip by norm NaN gradients [Bug] Clip by norm NaN gradients Sep 4, 2018
@tensorflowbutler tensorflowbutler added the stat:awaiting response Status - Awaiting response from author label Sep 4, 2018
@tensorflowbutler
Member

Thank you for your post. We noticed you have not filled out the following fields in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
Have I written custom code
OS Platform and Distribution
TensorFlow installed from
TensorFlow version
Bazel version
CUDA/cuDNN version
GPU model and memory
Exact command to reproduce
Mobile device

@facaiy
Member

facaiy commented Sep 7, 2018

According to its API docs, the op computes

t * clip_norm / l2norm(t)

The gradient seems to be t / t * clip_norm, right? I think the result is undefined when t == 0.

@alextp Could you please comment or reassign?

@octavian-ganea
Author

No, this rescaling is only valid "If the L2-norm is greater than clip_norm" (see the API docs); otherwise t should remain unchanged, so the gradient should be all ones.
"if the L2-norm of t is already less than or equal to clip_norm, then t is not modified."

@facaiy
Member

facaiy commented Sep 7, 2018

I think you're right.

@octavian-ganea
Author

This is very similar to #20091, and it's essentially caused by the fact that the implementation of clip_by_norm computes the gradient w.r.t. tf.norm(t) even when ||t|| < clip_norm. I am not sure how this can be solved, since TF uses static graphs. It would have been trivial in a framework that uses dynamic graphs...
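
(To isolate the culprit: the gradient of tf.norm alone is already NaN at zero, since d||a||/da = a / ||a|| is 0/0 there. A minimal TF 1.x check:)

import tensorflow as tf

a = tf.zeros([3], dtype=tf.float32)
n = tf.norm(a)              # forward value: 0.0
g = tf.gradients(n, a)      # d||a||/da = a / ||a||  ->  0/0 at a = 0
with tf.Session() as s:
    print(s.run(g))         # [array([nan, nan, nan], dtype=float32)]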

@facaiy
Member

facaiy commented Sep 7, 2018

In [20]: a = tf.constant(1.0)
In [21]: b = tf.constant(2.0)
In [22]: c = tf.maximum(a, b)
In [23]: sess.run(tf.gradients(c, [a, b]))
Out[23]: [0.0, 1.0]

Because tf.maximum cannot block gradient propagation to the smaller argument (note that the gradient w.r.t. a is 0.0 rather than None; in the clip_by_norm case the chain rule then multiplies that 0.0 by the NaN local gradient of the norm, and 0.0 * nan is still nan), I have no idea how to solve the problem. Let's wait for a reply from @alextp.

@octavian-ganea
Author

Yes, and similarly:

a = tf.zeros([3], dtype=tf.float32)
b = tf.maximum(5., tf.norm(a))
c = tf.gradients(b, [a])
s = tf.Session()
s.run(c)

Gives
[array([nan, nan, nan], dtype=float32)]

@alextp alextp assigned alextp and unassigned jart Sep 7, 2018
@octavian-ganea
Author

I am not sure this solves the issue. The code still fails for very small vectors, e.g.:
a = tf.ones([3], dtype=tf.float32) * 1e-20

@alextp
Contributor

alextp commented Sep 13, 2018 via email

@octavian-ganea
Author

Yes, I think it's called "dynamic graphs" :)

@alextp
Contributor

alextp commented Sep 13, 2018 via email

@facaiy
Member

facaiy commented Sep 13, 2018

I think the root cause is that TensorFlow, in graph mode, cannot block gradient backpropagation for a slice of a Tensor (e.g. through the tf.maximum op here). NaN propagation is really annoying, and the two-"where" trick hurts performance. Is there any way to solve the problem completely?
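
(For reference, a minimal sketch of the two-"where" workaround mentioned above, applied to a global-norm-only clip; clip_by_norm_safe is a hypothetical name, not the TF API, and this mirrors the idea rather than the library's actual implementation:)

import tensorflow as tf

def clip_by_norm_safe(t, clip_norm):
    l2sum = tf.reduce_sum(t * t)
    pred = l2sum > 0.
    # "where" #1: feed sqrt a safe input so its local gradient stays finite.
    l2sum_safe = tf.where(pred, l2sum, tf.ones_like(l2sum))
    # "where" #2: take the real sqrt only where it is valid; elsewhere fall
    # back to l2sum itself (0.0), whose gradient 2*t is well defined.
    l2norm = tf.where(pred, tf.sqrt(l2sum_safe), l2sum)
    # maximum keeps the denominator >= clip_norm: identity when the norm is
    # small, rescaling when it is large, and never a division by zero.
    return t * clip_norm / tf.maximum(l2norm, clip_norm)

With this, tf.gradients(clip_by_norm_safe(tf.zeros([3], dtype=tf.float32), 1.), ...) comes out as ones instead of NaN, because every branch that backpropagation touches has a finite local gradient.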

@octavian-ganea
Author

It is especially annoying since it is usually very hard to debug and doesn't have a "mathematical" cause. It would be nice to have a tool that would help TF users understand when this issue is the cause of their "NaN problem".

@alextp
Contributor

alextp commented Sep 13, 2018 via email

@Nitinsiwach

Nitinsiwach commented Apr 24, 2019

Any updates or suggestions here? I am running into this and do not know how to solve it.
