
possible unsafety in torch.distributions.kl_divergence for Bernoullis #15288

Open
marikgoldstein opened this issue Dec 16, 2018 · 2 comments
Labels
module: distributions Related to torch.distributions · module: NaNs and Infs Problems related to NaN and Inf handling in floating point · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


marikgoldstein commented Dec 16, 2018

Note

This might be my first pytorch issue.

🐛 Bug

torch.distributions.kl_divergence appears to be numerically unsafe for Bernoullis. In the script below, I compare it against a hand-written KL divergence between Bernoullis that adds an epsilon before each log(). The torch KL and the hand-written version agree down to the fourth decimal place, but torch's KL produces a nan gradient while mine does not. I found the nan using the torch anomaly detector (very useful), whose traceback is included below the example code.

To Reproduce

For simplicity, the following code learns a vector of Bernoulli logits for a distribution q_z so as to minimize the KL between q_z and a prior p_z. Switch between loss = kl_pytorch and loss = kl_custom to observe the difference in behavior:

import torch
import torch.nn as nn
from torch.distributions import Bernoulli
from torch.distributions import kl_divergence


EPS = 1e-16 
torch.set_anomaly_enabled(True)

def custom_bernoulli_kl(q_logits, p_probs):
    # KL(q || p) between Bernoullis, adding EPS before each log for safety
    q1 = torch.sigmoid(q_logits)
    q0 = 1 - q1
    p1 = p_probs
    p0 = 1 - p1

    logq1 = (q1 + EPS).log()
    logq0 = (q0 + EPS).log()
    logp1 = (p1 + EPS).log()
    logp0 = (p0 + EPS).log()

    kldiv_1 = q1 * (logq1 - logp1)
    kldiv_0 = q0 * (logq0 - logp0)
    return (kldiv_1 + kldiv_0).sum()


q_logits = torch.tensor([10.3, -6.0, 30.0], requires_grad=True)
optimizer = torch.optim.Adam([q_logits], lr=0.001)

for i in range(10):

    p_probs = torch.tensor([0.5, 0.5, 0.5])
    q_z = Bernoulli(logits=q_logits)
    p_z = Bernoulli(probs=p_probs)

    kl_pytorch = torch.distributions.kl_divergence(q_z, p_z).sum()
    kl_custom = custom_bernoulli_kl(q_logits, p_probs)

    print("---")
    print("KL Pytorch:", kl_pytorch)
    print("KL Custom:", kl_custom)
    print("---")

    loss = kl_pytorch
    # loss = kl_custom  # doesn't break
    optimizer.zero_grad()  # (added: grads otherwise accumulate across iterations)
    loss.backward()
    optimizer.step()

torch anomaly detector traceback:

The following shows the line of torch.distributions.kl_divergence for Bernoullis that is causing the nan:

sys:1: RuntimeWarning: Traceback of forward call that caused the error:
  File "kl_test.py", line 38, in <module>
    kl_pytorch = torch.distributions.kl_divergence(q_z,p_z).sum()
  File "/Users/torch/anaconda3/envs/python3/lib/python3.6/site-packages/torch/distributions/kl.py", line 166, in kl_divergence
    return fun(p, q)
  File "/Users/torch/anaconda3/envs/python3/lib/python3.6/site-packages/torch/distributions/kl.py", line 183, in _kl_bernoulli_bernoulli
    t2 = (1 - p.probs) * ((1 - p.probs) / (1 - q.probs)).log()

Traceback (most recent call last):
  File "kl_test.py", line 48, in <module>
    loss.backward()
  File "/Users/torch/anaconda3/envs/python3/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/Users/torch/anaconda3/envs/python3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.
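For what it's worth, the trace points at the masked computation in _kl_bernoulli_bernoulli. If I read the 1.0 implementation correctly, indexed assignments like t2[p.probs == 1] = 0 repair the forward value, but autograd still differentiates through the 0 * log(0) branch, which poisons the gradient. A minimal sketch of the same failure mode (using torch.where instead of indexed assignment, purely for illustration):

```python
import torch

# At logit magnitude ~30, sigmoid saturates to exactly 1.0 in float32,
# so (1 - probs) is exactly zero and 0 * log(0) gives nan in the forward pass.
logit = torch.tensor(30.0, requires_grad=True)
probs = torch.sigmoid(logit)
assert probs.item() == 1.0  # saturated

t2 = (1 - probs) * torch.log(1 - probs)  # nan: 0 * (-inf)

# Masking repairs the forward value but not the gradient: autograd still
# backpropagates through the nan-producing branch.
t2_safe = torch.where(probs == 1, torch.zeros_like(t2), t2)
t2_safe.backward()
print(t2_safe.item())  # 0.0 -- forward looks fine
print(logit.grad)      # tensor(nan) -- backward is contaminated
```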

Expected behavior

I would expect no nan gradients, even when the Bernoulli logits have magnitude around 40. This can arise, e.g., in a deep generative model where a matrix transformation yields a layer of real numbers to be interpreted as Bernoulli logits. Please let me know if one should instead reduce logit magnitudes before constructing a Bernoulli distribution.
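As a workaround sketch (not an endorsed fix), clamping the logits before constructing the distribution keeps probs strictly inside (0, 1) in float32. The bound MAX_LOGIT here is arbitrary, my choice for illustration:

```python
import torch
from torch.distributions import Bernoulli, kl_divergence

MAX_LOGIT = 15.0  # hypothetical bound; sigmoid(15) ~= 1 - 3e-7, still < 1 in float32

q_logits = torch.tensor([10.3, -6.0, 30.0], requires_grad=True)
# Clamp before building the distribution so probs never saturate to 0 or 1.
q_z = Bernoulli(logits=q_logits.clamp(-MAX_LOGIT, MAX_LOGIT))
p_z = Bernoulli(probs=torch.tensor([0.5, 0.5, 0.5]))

kl = kl_divergence(q_z, p_z).sum()
kl.backward()
print(q_logits.grad)  # finite (the clamped element gets zero gradient)
```

Note that clamp zeroes the gradient for logits outside the bound, so a smooth squashing of the logits may be preferable in practice.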

If something should be fixed, I'm not sure which of the following is better:

  • "distributions should always have probs that are safe for KL calculation"
  • "KL should perform additional numerical safety routines"

Environment

PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.11.6
GCC version: Could not collect
CMake version: version 3.11.4

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] Could not collect
[conda] blas 1.0 mkl
[conda] mkl 2018.0.3 1
[conda] mkl_fft 1.0.6 py36hb8a8100_0
[conda] mkl_random 1.0.1 py36h5d10147_1
[conda] pytorch 1.0.0 py3.6_1 pytorch
[conda] torchvision 0.2.1 py_2 pytorch

Additional context

Thank you!

cc @fritzo @neerajprad @alicanb @nikitaved

@marikgoldstein
Author

Tagging some pytorch distributions contributors I know, in case they may be quick to help say whether this is an issue or non-issue: @fritzo @rachtsingh

@rachtsingh
Contributor

I think this is because 0/0 == nan (whereas 1/0 == inf), which I would guess is intended behavior (though a link or explanation would be great):

In [1]: tensor(1.)/tensor(0.)
Out[1]: tensor(inf)
In [2]: tensor(0.)/tensor(0.)
Out[2]: tensor(nan)

Still, semantically speaking, I imagine we'd like the derivative of the kl_divergence between two Bernoullis with p = 1 with respect to their logits to be inf rather than nan:

In [3]: x = torch.tensor(40., requires_grad=True)
In [4]: y = torch.tensor(40., requires_grad=True)
In [5]: l = kl_divergence(Bernoulli(logits=y), Bernoulli(logits=x))
In [6]: l.backward()
In [7]: x.grad
Out[7]: tensor(nan)

In practice, while training a generative model you would never want inf gradients anyway, so you'd want to constrain the parameters. inf is probably the right answer from PyTorch as an autograd system, but the right modeling choice is not to let the probability reach 1.
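Another option is to compute the closed form entirely in logit space via log-sigmoid, which never forms 1 - probs explicitly and so stays finite for large logits. A sketch (bernoulli_kl_logits is my name for illustration, not a PyTorch API):

```python
import torch
import torch.nn.functional as F

def bernoulli_kl_logits(a, b):
    """KL(Bernoulli(logits=a) || Bernoulli(logits=b)) computed in logit space.

    Uses log sigmoid(x) = -softplus(-x), which is evaluated stably, instead of
    forming probs and 1 - probs explicitly.
    """
    return (torch.sigmoid(a) * (F.logsigmoid(a) - F.logsigmoid(b))
            + torch.sigmoid(-a) * (F.logsigmoid(-a) - F.logsigmoid(-b)))

a = torch.tensor([10.3, -6.0, 40.0], requires_grad=True)
b = torch.zeros(3)  # logits of p = 0.5

kl = bernoulli_kl_logits(a, b).sum()
kl.backward()
print(kl)      # finite, even with a logit of 40
print(a.grad)  # finite, no nan
```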

@zou3519 zou3519 added module: distributions Related to torch.distributions triage review labels Dec 17, 2018
@zou3519 zou3519 added module: NaNs and Infs Problems related to NaN and Inf handling in floating point triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jun 17, 2021