
GaussianNLLLoss doesn't support the usual reduction='none' #53964

Closed
almson opened this issue Mar 13, 2021 · 13 comments
Labels
module: loss (Problem is related to loss function), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@almson

almson commented Mar 13, 2021

The new Gaussian NLL Loss behaves differently from the other losses by not supporting the usual none mode. Instead, it still does a reduction over all dimensions except batch.

What I expect: gaussian_nll_loss(..., reduction='none') should return a tensor with the same shape as input, target, and var. The loss is essentially just 0.5 * (torch.log(var) + (input - target)**2 / var), which is what I want returned.

What happens: a scalar or a tensor of shape (N,) is returned. The implementation does .view(input.size(0), -1).sum(dim=1)

Why this matters: I use reduction='none' for custom masking, weighting, etc.
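
For concreteness, here is roughly what I'd like to be able to write (the helper below is just my own sketch, not the library function):

import torch

def gaussian_nll_elementwise(input, target, var):
    # Hypothetical helper: the per-element loss, with no reduction at all.
    return 0.5 * (torch.log(var) + (input - target) ** 2 / var)

# Custom masking/weighting then becomes trivial:
# loss = (gaussian_nll_elementwise(input, target, var) * mask).sum() / mask.sum()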

P.S. gaussian_nll_loss is missing from the nn.functional documentation.

cc @albanD @mruberry @jbschlosser

@almson
Author

almson commented Mar 13, 2021

P.P.S. The full parameter is useless and not theoretically grounded.

The main problem with a Gaussian negative log likelihood that the full parameter was perhaps trying to solve is that the loss goes negative (which looks weird but doesn't actually matter). The underlying reason is that the likelihood (of which you're supposedly taking the negative log) is undefined: you simply can't get a probability by evaluating a gaussian at a point. You get a probability density, which is very often greater than 1. Scaling that density by 1/sqrt(2*pi) (which is what full does) doesn't help. The real solution is to use a PDF for the target, not a point, and integrate over it. However, we don't know this true target PDF. We can assign an arbitrary uniform distribution with a small variance, and that would translate into an arbitrary constant. Therefore, you might as well replace the full parameter with let_me_add_an_arbitrary_constant_for_you. It's better to just remove it.
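
To make the "density greater than 1" point concrete (just a toy check):

import math
import torch
import torch.nn.functional as F

# A confident prediction: tiny variance, target exactly on the mean.
input = torch.tensor([0.0])
target = torch.tensor([0.0])
var = torch.tensor([0.01])

# The density at the mean is 1/sqrt(2*pi*0.01) ~= 3.99 > 1, so even the "full"
# NLL is negative: -log(3.99) ~= -1.38.
print(F.gaussian_nll_loss(input, target, var, full=True, reduction='sum'))
print(-math.log(1.0 / math.sqrt(2 * math.pi * 0.01)))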

@almson
Author

almson commented Mar 13, 2021

P.P.P.S. The var is clamped but the clamping doesn't mask the gradients, which seems wrong to me. It will cause divergence because var keeps being updated even if it's not being used.

@nailimixaM

nailimixaM commented Mar 13, 2021

Thanks for your comments @almson, I'll look into fixing that reduction mode for you and into the missing documentation.

The constant added by setting the full parameter is useless with respect to optimisation (because it is a constant, which is why it's not included by default) but it is theoretically grounded: a k-dimensional Gaussian pdf has the term sqrt[(2pi)^k x det(var)] in the denominator. Taking the log of that gives you a (constant) (k/2)*log(2pi) term plus the (non-constant) variance term. Setting full just adds that constant.
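
In the univariate case (k = 1) the split is simply:

\[
-\log \mathcal{N}(x \mid \mu, \sigma^2)
  = \frac{1}{2}\left(\log \sigma^2 + \frac{(x-\mu)^2}{\sigma^2}\right)
  + \frac{1}{2}\log(2\pi).
\]

full=False keeps only the first bracket; full=True also adds the constant 0.5*log(2*pi).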

The var clamping business was discussed here. I refer to @albanD @jbschlosser if you want to discuss in more detail.

@nailimixaM

@almson After checking the clamping again, I don't think your problem will happen as a local copy of var is made to which the clamping happens:

var = var.clone()
with torch.no_grad():
    var.clamp_(min=eps)

The original var is therefore unchanged by using the loss function.

@almson
Author

almson commented Mar 13, 2021

@nailimixaM What I understand will happen is this: assume eps is 1 and var is 0.1. The gradient will be evaluated as if var=1, which, for example, will result in var.grad=1. That gradient flows back to the real var, and after an update var might be 0.09 while var.grad will continue to be 1. Eventually, var will be driven to 0. Using wrong gradients will result in divergence.

Here is the doc:

> The clamping of var is ignored with respect to autograd, and so the gradients are unaffected by it.
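
A minimal sketch of what I mean (eps deliberately huge so the clamp actually kicks in; the second variant is the in-graph clamp I'd propose instead):

import torch

eps = 1.0
input, target = torch.tensor([0.0]), torch.tensor([0.0])

# Current behaviour: clamp a copy under no_grad, so autograd never sees the clamp.
var_a = torch.tensor([0.1], requires_grad=True)
v = var_a.clone()
with torch.no_grad():
    v.clamp_(min=eps)
(0.5 * (torch.log(v) + (input - target) ** 2 / v)).sum().backward()
print(var_a.grad)  # tensor([0.5000]) -- nonzero, keeps pushing var further below eps

# In-graph clamp: clamped elements get zero gradient.
var_b = torch.tensor([0.1], requires_grad=True)
v = var_b.clamp(min=eps)
(0.5 * (torch.log(v) + (input - target) ** 2 / v)).sum().backward()
print(var_b.grad)  # tensor([0.]) -- no gradient once the element is clamped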

@almson
Author

almson commented Mar 13, 2021

I found the discussion. Original implementation uses var = var + eps. I guess it would have a similar problem. The safest alternative is to simply do var.clamp(min=eps) which will zero out the gradient elements which got clamped.

@almson
Author

almson commented Mar 13, 2021

@nailimixaM

> The constant added by setting the full parameter is useless with respect to optimisation (because it is a constant, which is why it's not included by default) but it is theoretically grounded: a k-dimensional Gaussian pdf has the term sqrt[(2pi)^k x det(var)] in the denominator. Taking the log of that gives you a (constant) (k/2)*log(2pi) term plus the (non-constant) variance term. Setting full just adds that constant.

I realize where that term comes from, but it's there to normalize the gaussian when integrating over it. But the integration isn't being done and would be arbitrary anyway, so including that term doesn't get you any closer to true correctness. It just leads to noise in the API and documentation.

Moreover, the term disappears if the target PDF is assumed to be gaussian (an excellent assumption!). In that case, the loss becomes the KL loss between two gaussians, which doesn't actually have a sqrt(2pi) term.

The Gaussian KL reduces to the Gaussian (pseudo-)NLL (plus a constant) in the limit of target variance going to 0, but assuming non-negligible target variance results in an interesting K/var term. It's this term that might be interesting to include. Of note is that Gaussian KL never produces negative values (because it doesn't hack around prob densities).
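
For reference, the univariate version of that KL (target distribution N(mu_t, sigma_t^2), predicted N(mu, sigma^2)) is:

\[
\mathrm{KL}\left(\mathcal{N}(\mu_t,\sigma_t^2)\,\|\,\mathcal{N}(\mu,\sigma^2)\right)
  = \log\frac{\sigma}{\sigma_t}
  + \frac{\sigma_t^2 + (\mu_t-\mu)^2}{2\sigma^2}
  - \frac{1}{2}.
\]

Dropping the terms that don't depend on the prediction leaves 0.5*(log sigma^2 + (mu_t - mu)^2 / sigma^2) plus the sigma_t^2 / (2*sigma^2) term I mentioned, and the whole expression is always non-negative.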

I'll also add that support for full-covariance Gaussians would be a welcome (and straightforward) addition. (Formula: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Multivariate_normal_distributions)

@nailimixaM

> I found the discussion. Original implementation uses var = var + eps. I guess it would have a similar problem. The safest alternative is to simply do var.clamp(min=eps) which will zero out the gradient elements which got clamped.

I don't think the original implementation would have the problem as the addition is visible to autograd. I know that adding var = var + eps is one option used in industry, for ex. at DeepMind where a friend of mine works. However, I'd rather leave this autograd discussion to the PyTorch moderators who will know much more than I do about it - maybe open a PyTorch discussion about it?

@nailimixaM

> I realize where that term comes from, but it's there to normalize the gaussian when integrating over it. But the integration isn't being done and would be arbitrary anyway, so including that term doesn't get you any closer to true correctness. It just leads to noise in the API and documentation.

It's useful when comparing different losses, or when you want to report the actual value of the Gaussian likelihood and don't want to think about having to add (k/2)*log(2pi) terms everywhere.

> Moreover, the term disappears if the target PDF is assumed to be gaussian (an excellent assumption!). In that case, the loss becomes the KL loss between two gaussians, which doesn't actually have a sqrt(2pi) term.

Sure, that's a different loss function, which doesn't have such a constant. That's not what this loss function is, though.

> Of note is that Gaussian KL never produces negative values (because it doesn't hack around prob densities).

This loss function doesn't try to avoid positive/negative values.

> I'll also add that support for full-covariance Gaussians would be a welcome (and straightforward) addition. (Formula: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Multivariate_normal_distributions)

If you mean this for the (non-KL) Gaussian NLL, then I'd argue this is a very rare case, as it involves matrix inversion, which is usually avoided by assuming diagonal covariances. I can't comment on how common/rare it is for the KL Gaussian loss.

@ngimel ngimel added module: loss Problem is related to loss function module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Mar 15, 2021
@albanD
Collaborator

albanD commented Mar 15, 2021

For the var question, the idea here is that this is done only for numerical stability. Of course, if you clamp with a very large eps value, you can get very wrong values.
But in essence, what the current version does is the same (gradient-wise) as var = var + eps. The benefit is that we don't distort the var value when it is already large enough.
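
A quick sketch of what I mean (toy numbers, not the actual library code):

import torch

eps = 1e-6
input, target = torch.tensor([0.0]), torch.tensor([1.0])

# (a) clamp the value outside autograd, as gaussian_nll_loss does
var_a = torch.tensor([1e-9], requires_grad=True)
v = var_a.clone()
with torch.no_grad():
    v.clamp_(min=eps)
(0.5 * (torch.log(v) + (input - target) ** 2 / v)).sum().backward()

# (b) add eps inside the graph
var_b = torch.tensor([1e-9], requires_grad=True)
v = var_b + eps
(0.5 * (torch.log(v) + (input - target) ** 2 / v)).sum().backward()

print(var_a.grad, var_b.grad)  # essentially the same gradient when var << eps
# For var well above eps, (a) leaves the value untouched while (b) still shifts it by eps.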

@nailimixaM

Hi @albanD I think I've got a fix for this ready (sorry I only just got round to it) - but when I fetch from remote/master I don't seem to have any of the code for the loss function, am I missing something? Do I need to re-paste it all in?

@albanD
Collaborator

albanD commented Apr 19, 2021

Hi,

It might have been moved recently?
Looking at it, it looks like this is already fixed, no?

pytorch/torch/nn/functional.py

Lines 2637 to 2643 in 3fe4718

# Apply reduction
if reduction == 'mean':
    return loss.mean()
elif reduction == 'sum':
    return loss.sum()
else:
    return loss

@nailimixaM

That's still the old buggy version (the .view(...).sum(dim=1) business means the shape of loss isn't the expected shape when no reduction is desired) - I'll post a PR to fix it.

krshrimali pushed a commit to krshrimali/pytorch that referenced this issue May 19, 2021
Summary:
Fixes pytorch#53964. cc albanD almson

## Major changes:
- Overhauled the actual loss calculation so that the shapes are now correct (in functional.py)
- added the missing doc in nn.functional.rst

## Minor changes (in functional.py):
- I removed the previous check on whether input and target were the same shape. This is to allow for broadcasting, say when you have 10 predictions that all have the same target.
- I added some comments to explain each shape check in detail. Let me know if these should be shortened/cut.

Screenshots of updated docs attached.
Let me know what you think, thanks!

## Edit: Description of change of behaviour (affecting BC):
The backwards-compatibility is only affected for the `reduction='none'` mode. This was the source of the bug. For tensors with size (N, D), the old returned loss had size (N), as incorrect summation was happening. It will now have size (N, D) as expected.

### Example
Define input tensors, all with size (2, 3).
`input = torch.tensor([[0., 1., 3.], [2., 4., 0.]], requires_grad=True)`
`target = torch.tensor([[1., 4., 2.], [-1., 2., 3.]])`
`var = 2*torch.ones(size=(2, 3), requires_grad=True)`

Initialise loss with reduction mode 'none'. We expect the returned loss to have the same size as the input tensors, (2, 3).
`loss = torch.nn.GaussianNLLLoss(reduction='none')`

Old behaviour:
`print(loss(input, target, var)) `
`# Gives tensor([3.7897, 6.5397], grad_fn=<MulBackward0>). This has size (2).`

New behaviour:
`print(loss(input, target, var)) `
`# Gives tensor([[0.5966, 2.5966, 0.5966], [2.5966, 1.3466, 2.5966]], grad_fn=<MulBackward0>)`
`# This has the expected size, (2, 3).`

To recover the old behaviour, sum along all dimensions except for the 0th:
`print(loss(input, target, var).sum(dim=1))`
`# Gives tensor([3.7897, 6.5397], grad_fn=<SumBackward1>).`

![doc1](https://user-images.githubusercontent.com/26558092/115391089-f7f47b00-a1d6-11eb-8726-e4da9057aee0.png)
![doc2](https://user-images.githubusercontent.com/26558092/115391094-f925a800-a1d6-11eb-954b-afd187f42bc7.png)

Pull Request resolved: pytorch#56469

Reviewed By: jbschlosser, agolynski

Differential Revision: D27894170

Pulled By: albanD

fbshipit-source-id: 197890189c97c22109491c47f469336b5b03a23f