Add Gaussian NLL Loss (#50886)
Summary:
Fixes #48520.

cc albanD (this is a clean retry of PR #49807)

Pull Request resolved: #50886

Reviewed By: ejguan

Differential Revision: D26007435

Pulled By: albanD

fbshipit-source-id: 88fe91b40dea6f72e093e6301f0f04fcc842d2f0
M.L. Croci authored and facebook-github-bot committed Jan 22, 2021
1 parent e34992e commit 8eb90d4
Showing 7 changed files with 183 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -276,6 +276,7 @@ Loss Functions
    nn.CTCLoss
    nn.NLLLoss
    nn.PoissonNLLLoss
    nn.GaussianNLLLoss
    nn.KLDivLoss
    nn.BCELoss
    nn.BCEWithLogitsLoss
30 changes: 30 additions & 0 deletions test/test_nn.py
@@ -4788,6 +4788,34 @@ def test_poisson_nll_loss_reduction_modes(self):
        with self.assertRaisesRegex(ValueError, 'is not valid'):
            F.poisson_nll_loss(input, target, reduction='total')

    def test_gaussian_nll_loss_reduction_modes(self):
        input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
        target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
        var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])
        component_wise_loss = 0.5 * (torch.sum(torch.log(var) + (input - target)**2 / var, dim=1))
        self.assertEqual(component_wise_loss,
                         F.gaussian_nll_loss(input, target, var, reduction='none'))
        self.assertEqual(torch.sum(component_wise_loss),
                         F.gaussian_nll_loss(input, target, var, reduction='sum'))
        self.assertEqual(torch.mean(component_wise_loss),
                         F.gaussian_nll_loss(input, target, var, reduction='mean'))
        with self.assertRaisesRegex(ValueError, 'is not valid'):
            F.gaussian_nll_loss(input, target, var, reduction='total')

    def test_gaussian_nll_loss_args(self):
        input = torch.randn(3, 5)
        with self.assertRaisesRegex(ValueError, 'input and target must have same size'):
            target = torch.randn(3, 6)
            var = torch.ones(3, 5)
            torch.nn.functional.gaussian_nll_loss(input, target, var)
        with self.assertRaisesRegex(ValueError, 'var is of incorrect size'):
            target = torch.randn(3, 5)
            var = torch.ones(3, 3)
            torch.nn.functional.gaussian_nll_loss(input, target, var)
        with self.assertRaisesRegex(ValueError, 'var has negative entry/entries'):
            var = -1 * torch.ones(3, 5)
            torch.nn.functional.gaussian_nll_loss(input, target, var)

    def test_KLDivLoss_batch_mean(self):
        input_shape = (2, 5)
        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
@@ -11356,6 +11384,7 @@ def test_invalid_reduction_strings(self, device):
        input = torch.randn(3, 5, requires_grad=True, device=device)
        cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat)
        target = torch.tensor([1, 0, 4], device=device)
        var = torch.ones(size=input.size(), requires_grad=True, device=device)

        for reduction in ['none', 'invalid']:
            def v(fn):
@@ -11375,6 +11404,7 @@ def v(fn):
            v(lambda: F.mse_loss(input, input, reduction=reduction))
            v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
            v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
            v(lambda: F.gaussian_nll_loss(input, input, var, reduction=reduction))
            v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))

            zeros = torch.zeros_like(input).to(torch.int64)
66 changes: 66 additions & 0 deletions torch/nn/functional.py
@@ -2485,6 +2485,72 @@ def poisson_nll_loss(
    return ret


def gaussian_nll_loss(input, target, var, *, full=False, eps=1e-6, reduction='mean'):
    r"""Gaussian negative log likelihood loss.

    See :class:`~torch.nn.GaussianNLLLoss` for details.

    Args:
        input: expectation of the Gaussian distribution.
        target: sample from the Gaussian distribution.
        var: tensor of positive variance(s), one for each of the expectations
            in the input (heteroscedastic), or a single one (homoscedastic).
        full: ``True``/``False`` (bool), include the constant term in the loss
            calculation. Default: ``False``.
        eps: value used to clamp ``var``, for stability. Default: 1e-6.
        reduction: specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output is the average of all batch member losses,
            ``'sum'``: the output is the sum of all batch member losses.
            Default: ``'mean'``.
    """
    if not torch.jit.is_scripting():
        tens_ops = (input, target, var)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                gaussian_nll_loss, tens_ops, input, target, var, full=full, eps=eps, reduction=reduction)

    # Inputs and targets must have the same shape
    input = input.view(input.size(0), -1)
    target = target.view(target.size(0), -1)
    if input.size() != target.size():
        raise ValueError("input and target must have same size")

    # Second dim of var must match that of input or be equal to 1
    var = var.view(input.size(0), -1)
    if var.size(1) != input.size(1) and var.size(1) != 1:
        raise ValueError("var is of incorrect size")

    # Check validity of reduction mode
    if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
        raise ValueError(reduction + " is not valid")

    # Entries of var must be non-negative
    if torch.any(var < 0):
        raise ValueError("var has negative entry/entries")

    # Clamp for stability
    var = var.clone()
    with torch.no_grad():
        var.clamp_(min=eps)

    # Calculate loss (without constant)
    loss = 0.5 * (torch.log(var) + (input - target)**2 / var).view(input.size(0), -1).sum(dim=1)

    # Add constant to loss term if required
    if full:
        D = input.size(1)
        loss = loss + 0.5 * D * math.log(2 * math.pi)

    # Apply reduction
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:
        return loss


def kl_div(
    input: Tensor,
    target: Tensor,
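A quick sanity check of the functional form added above (not part of the commit; a minimal sketch assuming a PyTorch build that includes this change, with illustrative names such as manual, D and full_loss): with reduction='none' the result matches the hand-written per-sample formula, and full=True only adds the constant term.

import math
import torch
import torch.nn.functional as F

input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])

# Hand-written per-sample loss with the constant term omitted (full=False).
manual = 0.5 * (torch.log(var) + (input - target) ** 2 / var).sum(dim=1)
assert torch.allclose(manual, F.gaussian_nll_loss(input, target, var, reduction='none'))

# full=True only adds 0.5 * D * log(2 * pi), where D is the per-sample dimension.
D = input.size(1)
full_loss = F.gaussian_nll_loss(input, target, var, full=True, reduction='none')
assert torch.allclose(full_loss, manual + 0.5 * D * math.log(2 * math.pi))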
4 changes: 4 additions & 0 deletions torch/nn/functional.pyi.in
@@ -229,6 +229,10 @@ def poisson_nll_loss(input: Tensor, target: Tensor, log_input: bool = ..., full:
                     reduction: str = ...) -> Tensor: ...


def gaussian_nll_loss(input: Tensor, target: Tensor, var: Tensor, full: Optional[bool] = ...,
                      eps: Optional[float] = ..., reduction: Optional[str] = ...) -> Tensor: ...


def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., reduce: Optional[bool] = ...,
           reduction: str = ..., log_target: bool = ...) -> Tensor: ...

4 changes: 2 additions & 2 deletions torch/nn/modules/__init__.py
@@ -10,7 +10,7 @@
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
    CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
    MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
    MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -41,7 +41,7 @@
    'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
    'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss',
    'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'CTCLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss',
    'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', 'GaussianNLLLoss',
    'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
    'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
    'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
79 changes: 79 additions & 0 deletions torch/nn/modules/loss.py
@@ -299,6 +299,85 @@ def forward(self, log_input: Tensor, target: Tensor) -> Tensor:
                                  eps=self.eps, reduction=self.reduction)


class GaussianNLLLoss(_Loss):
    r"""Gaussian negative log likelihood loss.

    The targets are treated as samples from Gaussian distributions with
    expectations and variances predicted by the neural network. For a
    D-dimensional ``target`` tensor modelled as having heteroscedastic Gaussian
    distributions with a D-dimensional tensor of expectations ``input`` and a
    D-dimensional tensor of positive variances ``var`` the loss is:

    .. math::
        \text{loss} = \frac{1}{2}\sum_{i=1}^D \left(\log\left(\text{max}\left(\text{var}[i],
        \ \text{eps}\right)\right) + \frac{\left(\text{input}[i] - \text{target}[i]\right)^2}
        {\text{max}\left(\text{var}[i], \ \text{eps}\right)}\right) + \text{const.}

    where :attr:`eps` is used for stability. By default, the constant term of
    the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is
    a scalar (implying ``target`` tensor has homoscedastic Gaussian
    distributions) it is broadcasted to be the same size as the input.

    Args:
        full (bool, optional): include the constant term in the loss
            calculation. Default: ``False``.
        eps (float, optional): value used to clamp ``var`` (see note below), for
            stability. Default: 1e-6.
        reduction (string, optional): specifies the reduction to apply to the
            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
            will be applied, ``'mean'``: the output is the average of all batch
            member losses, ``'sum'``: the output is the sum of all batch member
            losses. Default: ``'mean'``.

    Shape:
        - Input: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - Target: :math:`(N, *)`, same shape as the input
        - Var: :math:`(N, 1)` or :math:`(N, *)`, same shape as the input
        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)`

    Examples::
        >>> loss = nn.GaussianNLLLoss()
        >>> input = torch.randn(5, 2, requires_grad=True)
        >>> target = torch.randn(5, 2)
        >>> var = torch.ones(5, 2, requires_grad=True)  #heteroscedastic
        >>> output = loss(input, target, var)
        >>> output.backward()

        >>> loss = nn.GaussianNLLLoss()
        >>> input = torch.randn(5, 2, requires_grad=True)
        >>> target = torch.randn(5, 2)
        >>> var = torch.ones(5, 1, requires_grad=True)  #homoscedastic
        >>> output = loss(input, target, var)
        >>> output.backward()

    Note:
        The clamping of ``var`` is ignored with respect to autograd, and so the
        gradients are unaffected by it.

    Reference:
        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
        target probability distribution", Proceedings of 1994 IEEE International
        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
        vol. 1, doi: 10.1109/ICNN.1994.374138.
    """
    __constants__ = ['full', 'eps', 'reduction']
    full: bool
    eps: float

    def __init__(self, *, full: bool = False, eps: float = 1e-6, reduction: str = 'mean') -> None:
        super(GaussianNLLLoss, self).__init__(None, None, reduction)
        self.full = full
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor:
        return F.gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction)


class KLDivLoss(_Loss):
    r"""The Kullback-Leibler divergence loss measure
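A short sketch of the homoscedastic case described in the docstring above (not part of the commit; loss_fn, manual and v are illustrative names, and it assumes a build containing this change). A var of shape (N, 1) carries one variance per sample and is broadcast across the target dimensions:

import torch
import torch.nn as nn

loss_fn = nn.GaussianNLLLoss(reduction='none')
input = torch.randn(5, 3)
target = torch.randn(5, 3)
var = torch.ones(5, 1) * 0.25          # one variance per sample, shape (N, 1)

out = loss_fn(input, target, var)      # per-sample losses, shape (N,)

# Equivalent hand-written homoscedastic form.
D = input.size(1)
v = var.squeeze(1)
manual = 0.5 * (D * torch.log(v) + ((input - target) ** 2).sum(dim=1) / v)
assert torch.allclose(out, manual)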
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -621,6 +621,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
    torch.nn.functional.fractional_max_pool3d_with_indices: (
        lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
        _random_samples=None: -1),
    torch.nn.functional.gaussian_nll_loss: (lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1),
    torch.nn.functional.gelu: lambda input: -1,
    torch.nn.functional.glu: lambda input, dim=-1: -1,
    torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
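Finally, circling back to the Note in the GaussianNLLLoss docstring that the clamping of var is ignored with respect to autograd: a minimal sketch (not part of the commit; names are illustrative, assuming a build with this change) showing that gradients still reach the predicted variances.

import torch
import torch.nn as nn

loss_fn = nn.GaussianNLLLoss()
input = torch.randn(4, 3, requires_grad=True)
target = torch.randn(4, 3)
var = torch.rand(4, 3) + 0.1           # strictly positive variances
var.requires_grad_()

loss_fn(input, target, var).backward()
print(var.grad.shape)                  # torch.Size([4, 3]): the eps clamp happens outside autograd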