[Bug] Clip by norm NaN gradients #22048
Thank you for your post. We noticed you have not filled out the following fields in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
No, this rescaling is only valid "if the L2-norm is greater than clip_norm" (see the API docs); otherwise t should remain unchanged, so the gradient should be all ones.
I think you're right.
This is very similar to #20091, and it's essentially caused by the fact that the implementation of clip_by_norm requires computing the gradient w.r.t. tf.norm(t) even when |t| < clip_norm. I am not sure how this can be solved, since TF uses static graphs. It would have been trivial in a framework that uses dynamic graphs...
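A simplified sketch of the shape of that computation (my illustration, not the exact library source):

norm = tf.norm(t)  # this node is always part of the graph
clipped = t * clip_norm / tf.maximum(norm, clip_norm)
# Backprop traverses tf.norm(t) regardless of whether clipping
# actually happened in the forward pass.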
In [20]: a = tf.constant(1.0)
In [21]: b = tf.constant(2.0)
In [22]: c = tf.maximum(a, b)
In [23]: sess.run(tf.gradients(c, [a, b]))
Out[23]: [0.0, 1.0]

Because tf.maximum cannot block gradient propagation for the smaller value (note that the gradient w.r.t. a is 0.0, rather than None), I have no idea how to solve the problem. Let's wait for a reply from @alextp.
Yes, and similar:
Gives
I am not sure this solves the issue. The code still fails for very small vectors, e.g.:
a = tf.ones([3], dtype=tf.float32) * 1e-20
The general solution of this issue is somewhat harder.
Yes, I think it's called "dynamic graphs" :)
Not quite, as you will find that you can reproduce this issue even with eager execution enabled. Dynamic graphs would help if we only needed to support the case where the l2norm is a scalar (that is, when there is no axis argument).

The underlying issue is that an op like tf.maximum, which we use here to pick which coordinates of the tensor need to be divided by the norm, produces a gradient of 0 w.r.t. the inputs it did not use to compute the output. At the same time, an op like tf.sqrt() produces a gradient of upstream_gradient * 1/(2*sqrt(input)). If 1/(2*sqrt(input)) is inf or NaN, multiplying it by 0 (which is the upstream gradient for the coordinates which were not used) will result in a NaN, which is what you're seeing here.

We are looking into fixing this overall issue, but it's tricky to do so without slowing down all operations whose gradients boil down to upstream_gradient * f(x) when f(x) can be inf or NaN.
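A minimal demonstration of that 0 * inf interaction (my own sketch, not code from the thread):

import tensorflow as tf

x = tf.constant(0.0)
# Forward value is 1.0; tf.maximum did not use sqrt(x) to compute it,
# so the upstream gradient flowing into tf.sqrt is 0.
y = tf.maximum(tf.sqrt(x), 1.0)
grad = tf.gradients(y, x)

with tf.Session() as sess:
    # d(sqrt)/dx at 0 is inf, and 0 * inf = NaN, so this prints [nan].
    print(sess.run(grad))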
I think the root cause is that TensorFlow, in graph mode, cannot block gradient backpropagation through a slice of a tensor, as with the tf.maximum op here. NaN propagation is really annoying, and the two-tf.where trick (sketched below) hurts performance. Is there any way to solve the problem completely?
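For reference, a sketch of that two-tf.where trick (my illustration; safe_norm is a hypothetical helper): the first tf.where feeds tf.sqrt a safe input so its local gradient stays finite, and the second restores the correct forward value.

import tensorflow as tf

def safe_norm(t):
    norm_sq = tf.reduce_sum(t * t)
    is_zero = tf.equal(norm_sq, 0.0)
    # where #1: substitute a safe input so tf.sqrt's gradient is finite.
    safe_sq = tf.where(is_zero, tf.ones_like(norm_sq), norm_sq)
    # where #2: restore the correct forward value for the zero case.
    return tf.where(is_zero, tf.zeros_like(norm_sq), tf.sqrt(safe_sq))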
It is especially annoying since it is usually very hard to debug and doesn't have a "mathematical" cause. It would be nice to have a tool that would help TF users understand when this issue is generating their "NaN problem".
The tool I used to debug this and partially fix it, at least for the zeros case, tf.add_check_numerics_ops, works pretty well for identifying these issues.

Note that you can get around this using a tf.cond-based version of clip_by_norm (which behaves the same as a dynamic graph) if you only care about a scalar norm.
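A minimal sketch of such a tf.cond-based workaround (my illustration; clip_by_norm_cond is a hypothetical name, and it assumes a scalar norm, i.e. no axes argument):

import tensorflow as tf

def clip_by_norm_cond(t, clip_norm):
    norm = tf.norm(t)
    # The gradient of tf.cond only traverses the branch that was taken,
    # so tf.norm's gradient is never evaluated when no clipping happens.
    return tf.cond(norm > clip_norm,
                   lambda: t * (clip_norm / norm),
                   lambda: tf.identity(t))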
Any updates or suggestions here? I am running into this and do not know how to solve it.
The gradient should obviously be [1,1,1] for all vectors a of norm smaller than 1, since this function should be the identity for those vectors.
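A minimal sketch of a reproduction consistent with this description (assuming clip_norm=1.0 and TF 1.x graph mode; not necessarily the original snippet):

import tensorflow as tf

a = tf.zeros([3], dtype=tf.float32)          # norm(a) = 0 < clip_norm
clipped = tf.clip_by_norm(a, clip_norm=1.0)  # should act as the identity
grad = tf.gradients(clipped, a)

with tf.Session() as sess:
    print(sess.run(grad))  # expected [1., 1., 1.]; TF 1.10 yields NaNs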
Have I written custom code:
OS Platform and Distribution: Ubuntu 14.10
TensorFlow installed from: pip3
TensorFlow version: 1.10.1
Bazel version:
CUDA/cuDNN version:
GPU model and memory:
Exact command to reproduce: see above
Mobile device: