The TensorFlow Probability implementation of softplus leaks memory, and appears to no longer be needed. That is, I think the standard `tf.nn.softplus` implementation can be used instead, as the numerical stability issues appear to have been solved.

Currently the implementation of softplus is as follows (from here):
```python
# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:

  @tf.custom_gradient
  def _stable_grad_softplus(x):
    """A (more) numerically stable softplus than `tf.nn.softplus`."""
    x = tf.convert_to_tensor(x)
    if x.dtype == tf.float64:
      cutoff = -20
    else:
      cutoff = -9
    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))

    def grad_fn(dy):
      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))

    return y, grad_fn
```
This leaks memory (in non-JAX mode) due to a couple of issues:

- The `grad_fn` closure captures the tensor represented by `x`. This closure then ends up in the gradient registry, which is never cleared, so the tensor represented by `x` hangs around forever.
- For a similar reason, TensorFlow's `custom_gradient` implementation also leaks memory. See 97697 for more details.
Here is a Colab notebook to demonstrate the memory leak.
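For anyone who wants a quick local repro without opening the notebook, here is a minimal sketch of the pattern in question (the function name and loop count are mine, not from the notebook; it assumes the graph-mode path of `tf.custom_gradient`, which is where the gradient registry is used):

```python
import tensorflow as tf


@tf.custom_gradient
def wrapped_softplus(x):  # hypothetical stand-in for _stable_grad_softplus
  x = tf.convert_to_tensor(x)

  def grad_fn(dy):
    # This closure captures `x`; in graph mode the closure ends up in
    # TensorFlow's global gradient registry, which is never cleared.
    return dy * tf.nn.sigmoid(x)

  return tf.nn.softplus(x), grad_fn


graph = tf.Graph()
with graph.as_default():
  for _ in range(1000):
    # Each call registers a fresh gradient override; the registry entry
    # (and the tensor its closure captures) is never released.
    wrapped_softplus(tf.constant([0.0]))
```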
However, I believe that the numerical stability issues with `tf.nn.softplus` have been solved. Specifically:

- The `tf.nn.softplus` implementation now uses `log1p`, as of this commit on May 1 2020.
- The gradient computation for `tf.nn.softplus` now uses `math_ops.sigmoid`, as of this commit on April 4 2019.
- The Eigen implementation of sigmoid (which I think is here) computes it as `e^x / (1.0 + e^x)`, so the `e^x` approximation in `_stable_grad_softplus` seems unnecessary to me. If `e^x` is very small then `1.0 + e^x` will be exactly `1.0`, so the quotient is equivalent to `e^x`. If `e^x > 1.0` then the result of `e^x / (1.0 + e^x)` will be (I think) more accurate than just approximating the gradient as `e^x`. A quick sanity check is sketched below. But I am not a numerical stability expert, so I may be wrong.