possible unsafety in torch.distributions.kl_divergence for Bernoullis #15288
Labels
- module: distributions (Related to torch.distributions)
- module: NaNs and Infs (Problems related to NaN and Inf handling in floating point)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Note
This might be my first pytorch issue.
🐛 Bug
torch.distributions.kl_divergence seems numerically unsafe for Bernoullis. In the script below, I compare it against a hand-written Bernoulli KL that makes sure to add an epsilon before every log(). The torch KL and the hand-written version compute the same value to four decimal places, but torch's KL produces a nan gradient while mine does not. I found the nan using the torch anomaly catcher (so useful), whose traceback I include below the example code.
To Reproduce
For simplicity, the following code seeks to learn a vector of Bernoulli logits for a distribution q_z so as to minimize the KL between q_z and a prior p_z. Switch between loss = kl_pytorch and loss = kl_custom to observe the difference in behavior:
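The original script did not survive the paste, so here is a minimal sketch of the setup described above. The helper name `bernoulli_kl_custom`, the eps value of 1e-6, and the vector size are my guesses; the logit magnitude of 40 matches the report:

```python
import torch
import torch.distributions as dist

eps = 1e-6

def bernoulli_kl_custom(q, p):
    # Hand-written KL(q || p) between Bernoullis, adding eps before every log()
    qp, pp = q.probs, p.probs
    return (qp * ((qp + eps).log() - (pp + eps).log())
            + (1 - qp) * ((1 - qp + eps).log() - (1 - pp + eps).log()))

# Logits of large magnitude, as produced e.g. by an unconstrained linear layer
logits = torch.full((4,), 40.0, requires_grad=True)
q_z = dist.Bernoulli(logits=logits)
p_z = dist.Bernoulli(probs=torch.full((4,), 0.5))

kl_pytorch = dist.kl_divergence(q_z, p_z).sum()
kl_custom = bernoulli_kl_custom(q_z, p_z).sum()

# torch.autograd.set_detect_anomaly(True)  # the anomaly catcher used to find the nan
loss = kl_custom  # switch to kl_pytorch to observe the nan gradient on PyTorch 1.0
loss.backward()
print(logits.grad)  # finite on the kl_custom path
```

The forward values of `kl_pytorch` and `kl_custom` agree closely, as reported; only the backward pass differs.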
torch anomaly detector traceback:
The following shows the line of torch.distributions.kl_divergence for Bernoullis that is causing the nan:
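For reference, the Bernoulli-Bernoulli KL registered in `torch/distributions/kl.py` around 1.0 looked roughly like the following (quoted from memory, so treat it as approximate). The masking fixes the forward value when probs saturate to exactly 0 or 1, but the `0 * log(0)` nan still poisons the backward pass:

```python
import torch
from torch.distributions import Bernoulli

inf = float("inf")

def _kl_bernoulli_bernoulli(p, q):
    # probs-based formula: the .log() calls see exact 0s once probs saturate
    t1 = p.probs * (p.probs / q.probs).log()
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * ((1 - p.probs) / (1 - q.probs)).log()
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2

logits = torch.tensor([40.0], requires_grad=True)  # sigmoid(40) rounds to 1.0 in float32
kl = _kl_bernoulli_bernoulli(Bernoulli(logits=logits),
                             Bernoulli(probs=torch.tensor([0.5])))
kl.sum().backward()
print(kl, logits.grad)  # kl is ~log(2), but the gradient is nan
```

The in-place masking zeroes the gradient flowing into the masked positions, but the multiplication backward still evaluates `0 * (-inf)`, which is nan.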
Expected behavior
I would expect not to have nan gradients, even when the logits for the Bernoulli are around magnitude 40. This might arise e.g. in a deep generative model, where a matrix transformation yields a layer of real numbers to be interpreted as Bernoulli logits. Please let me know if one should instead reduce logit magnitudes before initializing a Bernoulli distribution.
If something should be fixed, I'm not sure which of the following is better:
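For what it's worth, here is one shape a fix could take (my sketch, not an official patch): evaluate the KL directly from logits using the identities log sigmoid(x) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x), so no log ever sees an exact 0:

```python
import torch
import torch.nn.functional as F

def kl_bernoulli_logits(a, b):
    """KL( Bernoulli(logits=a) || Bernoulli(logits=b) ).

    Rewrites p*log(p/q) + (1-p)*log((1-p)/(1-q)) using
    log sigmoid(x) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x),
    so the log terms stay finite even when sigmoid saturates.
    """
    p = torch.sigmoid(a)
    return (p * (F.softplus(-b) - F.softplus(-a))
            + (1 - p) * (F.softplus(b) - F.softplus(a)))

logits = torch.full((3,), 40.0, requires_grad=True)
kl = kl_bernoulli_logits(logits, torch.zeros(3))  # prior logits 0, i.e. p_z = 0.5
kl.sum().backward()
print(logits.grad)  # finite, no anomaly
```

The softplus terms are bounded for any finite logit, so both the forward value and the gradient stay finite even at magnitude 40.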
Environment
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: None
OS: Mac OSX 10.11.6
GCC version: Could not collect
CMake version: version 3.11.4
Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] Could not collect
[conda] blas 1.0 mkl
[conda] mkl 2018.0.3 1
[conda] mkl_fft 1.0.6 py36hb8a8100_0
[conda] mkl_random 1.0.1 py36h5d10147_1
[conda] pytorch 1.0.0 py3.6_1 pytorch
[conda] torchvision 0.2.1 py_2 pytorch
Additional context
Thank you!
cc @fritzo @neerajprad @alicanb @nikitaved