Why should we add a small number to the output and multiply a small number by the input of the softplus function? #703
Comments
I can't speak to this particular case in much detail, but a general reason
for the habit of adding a small constant like 1e-3 is to avoid numerical
issues that might occur if the scale becomes very close to zero (or even
zero exactly, e.g., softplus(-150.0) == 0.0 in float32). Adding the
constant prevents the optimizer from ever considering pathologically small
values.
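A minimal sketch of that failure mode (hypothetical values; it only assumes TensorFlow and TFP imported under their usual aliases):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

raw = tf.constant([-150.0])                 # a pre-activation the network could easily produce
scale = tf.math.softplus(raw)               # underflows to exactly 0.0 in float32
print(scale.numpy())                        # [0.]
print(tfd.Normal(loc=0., scale=scale).log_prob(0.).numpy())       # [nan]: log(0) and division by 0

safe_scale = 1e-3 + tf.math.softplus(raw)   # the added constant keeps the scale >= 1e-3
print(tfd.Normal(loc=0., scale=safe_scale).log_prob(0.).numpy())  # finite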
By contrast, multiplying the input by 0.05 doesn't change the space of
possible outputs; and (as you point out) it doesn't change the optimal
scale --- it's just a reparameterization. The effect is to precondition the
optimization: if y = 0.05 * x, then df(y)/dx = 0.05 * df(y)/dy, so gradients
with respect to x are divided by 20, while x also has a scale 20 times that
of y, so the *relative* change in gradient is a factor of 400.
That's equivalent to specifying that the optimizer's step size for the
scale param should be 1/400 of the step size it takes for the loc.
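A small sketch of that preconditioning effect (hypothetical values; plain TensorFlow):

import tensorflow as tf

t = tf.Variable(2.0)  # stand-in for the dense layer's scale pre-activation

with tf.GradientTape(persistent=True) as tape:
  scale_direct = tf.math.softplus(t)          # parameterize the scale by t directly
  scale_scaled = tf.math.softplus(0.05 * t)   # parameterize it by 0.05 * t instead

# Chain rule: d softplus(0.05 * t)/dt = 0.05 * sigmoid(0.05 * t), so the gradient
# reaching t carries the extra 0.05 factor, and a fixed step in t moves the scale
# far less than it would under the direct parameterization.
print(tape.gradient(scale_direct, t).numpy())
print(tape.gradient(scale_scaled, t).numpy())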
I can't say exactly why that was done in this example; it might not be
*that* necessary since you'd expect an adaptive optimizer like Adam to work
out reasonable step sizes on its own, eventually. I'd guess that it helps
speed up the optimization, or avoid local minima, or both. It's also
possible that it avoids numerical issues by preventing the optimizer from
considering extreme values for the scale before it has a reasonable idea of
the loc. Perhaps one of the authors of that post can say more.
On Sat, Dec 28, 2019 at 1:16 PM nbro wrote:
In the article
https://blog.tensorflow.org/2019/03/regression-with-probabilistic-layers-in.html,
you have the following code
model = tfk.Sequential([
  tf.keras.layers.Dense(1 + 1),
  tfp.layers.DistributionLambda(
      lambda t: tfd.Normal(loc=t[..., :1],
                           scale=1e-3 + tf.math.softplus(0.05 * t[..., 1:]))),
])
where you multiply 0.05 by the input to the softplus function. You also
add 1e-3 to its output. We want the scale (standard deviation) to be non-negative.
However, the softplus never produces a negative number, so there should be
no need for adding 1e-3 to the output of the softplus. Similarly, I don't
see the need for multiplying t[..., 1:] by 0.05. Of course, you
apparently want to make the input to the softplus smaller than the output
of the previous dense layer, but why? A smaller input will make the
softplus produce a smaller output, so an output closer to zero, but, at the
same time, you add 1e-3 to the output of the softplus, so, in a way,
these two operations are cancelling each other.
@davmre You will get numeric errors whether or not you perform these two operations. If you look at the plot of the softplus function, it is roughly zero whenever the input is a sufficiently large negative number.

I don't understand what you mean by "optimal scale". If, by reparameterization, you mean something similar to the reparameterization trick, then, yes, at first glance, it seems similar. I am not sure I follow your reasoning, though: the derivatives are taken with respect to the parameters of the model.
I just mean 'reparameterization' in the ordinary mathematical sense of the term: writing the same quantity as a function of a different variable. Here you might ordinarily define a normal RV in terms of its scale directly; in the code above, the scale is instead defined as

scale = 1e-3 + softplus(0.05 * t(θ))

where t(θ) are the activations defining the scale from the previous layer. The effect of doing this is that optimization wrt the scale will move 400 times more slowly than it would under the direct parameterization, as described above.
Yup, exactly. The bias is that the optimizer won't consider values of the scale less than 1e-3. That matters because, e.g., evaluating the log density of a Normal whose scale is exactly zero will yield NaN, because it's both dividing by and trying to take the log of zero.
I've noticed that this trick is also used in Keras and TensorFlow. You may want to have a look at https://github.com/tensorflow/tensorflow/blob/7fda1add7cc637693781f4967ca290b6b659072b/tensorflow/python/keras/backend_config.py#L33, where the following function is defined
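Roughly, that function looks like this (paraphrasing the linked backend_config.py; the exact file has moved around between versions):

# Keras's global fuzz factor, used to keep quantities away from exactly zero.
_EPSILON = 1e-7

def epsilon():
  """Returns the value of the fuzz factor used in numeric expressions."""
  return _EPSILON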
which is, for example, used in the calculation of the binary cross-entropy loss https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py
The TensorFlow Probability creators may be interested in implementing a similar thing in TFP.
Great answer. I understand that the gradient changes when a constant is multiplied inside the softplus. Nevertheless, I don't understand why, in this example, a constant is added inside the softplus: that should not affect the optimization, since it is just another reparameterization that, by itself, does nothing. Copying it here in case someone cannot go to the link:

# Specify the surrogate posterior over `keras.layers.Dense` `kernel` and `bias`.
def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
  n = kernel_size + bias_size
  c = np.log(np.expm1(1.))
  return tf.keras.Sequential([
      tfp.layers.VariableLayer(2 * n, dtype=dtype),
      tfp.layers.DistributionLambda(lambda t: tfd.Independent(
          tfd.Normal(loc=t[..., :n],
                     scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
          reinterpreted_batch_ndims=1)),
  ])
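For what it's worth, a quick numerical check (plain NumPy, independent of the code above) of what the added constant c does at initialization:

import numpy as np

c = np.log(np.expm1(1.0))            # the constant from posterior_mean_field
softplus = lambda x: np.log1p(np.exp(x))

print(softplus(0.0))                 # ~0.693: the scale a zero-initialized t would give without c
print(softplus(c))                   # 1.0: with c, a zero-initialized t gives scale ~1e-5 + 1.0
# So the shift mainly chooses where on the softplus curve the optimization starts;
# it doesn't change which scales are reachable.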
This is great work. To address this, one of the papers employed posterior sharpening. I also experienced that having variable sequence lengths caused issues with backprop through time; this was in stock TensorFlow. I am not sure whether that has any relationship to the performance issue here, but I thought I'd mention it in case folks have similar observations.
Closing this as I believe davmre has answered the issue. Basically, we want to avoid bad regions of parameter space for our optimization, and by using transformations like these (a small shift on the output and a rescaling of the input to the softplus) we can keep the optimizer away from them. Another example is when you train a Gaussian process with a kernel whose hyperparameters (such as amplitude and length scale) are constrained to be positive through a softplus; a small shift keeps them from collapsing to exactly zero.
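As a sketch of the same idea in bijector form (a hypothetical positivity constraint, e.g. for a GP amplitude; not code from the thread):

import tensorflow_probability as tfp
tfb = tfp.bijectors

# Softplus followed by a small shift, so the constrained value can never hit 0 exactly.
constrain_positive = tfb.Chain([tfb.Shift(1e-3), tfb.Softplus()])
amplitude = tfp.util.TransformedVariable(1.0, bijector=constrain_positive, name='amplitude')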
In the article https://blog.tensorflow.org/2019/03/regression-with-probabilistic-layers-in.html, you have the following code

model = tfk.Sequential([
  tf.keras.layers.Dense(1 + 1),
  tfp.layers.DistributionLambda(
      lambda t: tfd.Normal(loc=t[..., :1],
                           scale=1e-3 + tf.math.softplus(0.05 * t[..., 1:]))),
])

where you multiply the input to the softplus function by 0.05 (operation 1) and you add 1e-3 to its output (operation 2).

We want the scale (standard deviation) to be non-negative. However, the softplus never produces a negative number, so there should be no need for adding 1e-3 to the output of the softplus. Similarly, I don't see the need for multiplying t[..., 1:] by 0.05.

I tried to train a network that attempts to model aleatoric uncertainty with and without operations 1 and 2. The results are indeed slightly different in the two cases. Without operations 1 and 2, the uncertainty (i.e. variance) doesn't seem to be modeled correctly, apparently in the cases where the points lie in a small range.

(Plot: with operations 1 and 2.)

(Plot: without operations 1 and 2.)