Add Gaussian NLL Loss #50886

Closed
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -276,6 +276,7 @@ Loss Functions
    nn.CTCLoss
    nn.NLLLoss
    nn.PoissonNLLLoss
    nn.GaussianNLLLoss
    nn.KLDivLoss
    nn.BCELoss
    nn.BCEWithLogitsLoss
30 changes: 30 additions & 0 deletions test/test_nn.py
@@ -4788,6 +4788,34 @@ def test_poisson_nll_loss_reduction_modes(self):
        with self.assertRaisesRegex(ValueError, 'is not valid'):
            F.poisson_nll_loss(input, target, reduction='total')

    def test_gaussian_nll_loss_reduction_modes(self):
        input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
        target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
        var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])
        component_wise_loss = 0.5 * (torch.sum(torch.log(var) + (input - target)**2 / var, dim=1))
        self.assertEqual(component_wise_loss,
                         F.gaussian_nll_loss(input, target, var, reduction='none'))
        self.assertEqual(torch.sum(component_wise_loss),
                         F.gaussian_nll_loss(input, target, var, reduction='sum'))
        self.assertEqual(torch.mean(component_wise_loss),
                         F.gaussian_nll_loss(input, target, var, reduction='mean'))
        with self.assertRaisesRegex(ValueError, 'is not valid'):
            F.gaussian_nll_loss(input, target, var, reduction='total')

    def test_gaussian_nll_loss_args(self):
        input = torch.randn(3, 5)
        with self.assertRaisesRegex(ValueError, 'input and target must have same size'):
            target = torch.randn(3, 6)
            var = torch.ones(3, 5)
            torch.nn.functional.gaussian_nll_loss(input, target, var)
        with self.assertRaisesRegex(ValueError, 'var is of incorrect size'):
            target = torch.randn(3, 5)
            var = torch.ones(3, 3)
            torch.nn.functional.gaussian_nll_loss(input, target, var)
        with self.assertRaisesRegex(ValueError, 'var has negative entry/entries'):
            var = -1 * torch.ones(3, 5)
            torch.nn.functional.gaussian_nll_loss(input, target, var)

    def test_KLDivLoss_batch_mean(self):
        input_shape = (2, 5)
        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
@@ -11356,6 +11384,7 @@ def test_invalid_reduction_strings(self, device):
        input = torch.randn(3, 5, requires_grad=True, device=device)
        cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat)
        target = torch.tensor([1, 0, 4], device=device)
        var = torch.ones(size=input.size(), requires_grad=True, device=device)

        for reduction in ['none', 'invalid']:
            def v(fn):
@@ -11375,6 +11404,7 @@ def v(fn):
            v(lambda: F.mse_loss(input, input, reduction=reduction))
            v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
            v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
            v(lambda: F.gaussian_nll_loss(input, input, var, reduction=reduction))
            v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))

            zeros = torch.zeros_like(input).to(torch.int64)
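As a quick sanity check on the formula exercised by the new tests, here is a short sketch (not part of this diff, and assuming a build that includes this change): the per-sample loss with full=True should coincide with the exact Gaussian negative log likelihood obtained from torch.distributions.Normal.

import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal

input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])

# Exact per-sample NLL of independent Gaussians, summed over the feature dimension.
reference = -Normal(input, var.sqrt()).log_prob(target).sum(dim=1)

# full=True adds the 0.5 * D * log(2 * pi) constant term, so the two agree exactly.
loss = F.gaussian_nll_loss(input, target, var, full=True, reduction='none')
assert torch.allclose(loss, reference)

# Without the constant term, the results differ by exactly 0.5 * D * log(2 * pi).
loss_no_const = F.gaussian_nll_loss(input, target, var, reduction='none')
offset = 0.5 * input.size(1) * math.log(2 * math.pi)
assert torch.allclose(loss_no_const + offset, reference)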
66 changes: 66 additions & 0 deletions torch/nn/functional.py
@@ -2485,6 +2485,72 @@ def poisson_nll_loss(
    return ret


def gaussian_nll_loss(input, target, var, *, full=False, eps=1e-6, reduction='mean'):
    r"""Gaussian negative log likelihood loss.

    See :class:`~torch.nn.GaussianNLLLoss` for details.

    Args:
        input: expectation of the Gaussian distribution.
        target: sample from the Gaussian distribution.
        var: tensor of positive variance(s), one for each of the expectations
            in the input (heteroscedastic), or a single one (homoscedastic).
        full: ``True``/``False`` (bool), include the constant term in the loss
            calculation. Default: ``False``.
        eps: value used to clamp ``var``, for stability. Default: 1e-6.
        reduction: specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output is the average of all batch member losses,
            ``'sum'``: the output is the sum of all batch member losses.
            Default: ``'mean'``.
    """
    if not torch.jit.is_scripting():
        tens_ops = (input, target, var)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                gaussian_nll_loss, tens_ops, input, target, var, full=full, eps=eps, reduction=reduction)

    # Inputs and targets must have the same shape
    input = input.view(input.size(0), -1)
    target = target.view(target.size(0), -1)
    if input.size() != target.size():
        raise ValueError("input and target must have same size")

    # Second dim of var must match that of input or be equal to 1
    var = var.view(input.size(0), -1)
    if var.size(1) != input.size(1) and var.size(1) != 1:
        raise ValueError("var is of incorrect size")

    # Check validity of reduction mode
    if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
        raise ValueError(reduction + " is not valid")

    # Entries of var must be non-negative
    if torch.any(var < 0):
        raise ValueError("var has negative entry/entries")

    # Clamp for stability
    var = var.clone()
    with torch.no_grad():
        var.clamp_(min=eps)

    # Calculate loss (without constant)
    loss = 0.5 * (torch.log(var) + (input - target)**2 / var).view(input.size(0), -1).sum(dim=1)

    # Add constant to loss term if required
    if full:
        D = input.size(1)
        loss = loss + 0.5 * D * math.log(2 * math.pi)

    # Apply reduction
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:
        return loss


def kl_div(
    input: Tensor,
    target: Tensor,
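Because var is reshaped to (N, -1) and only required to have either one column or the full feature width, a single variance per batch element is broadcast across the feature dimension. A small illustrative sketch of that equivalence (not part of this diff, assuming a build with this change):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(4, 3)
target = torch.randn(4, 3)

# One variance per sample (homoscedastic) versus the same value repeated per dimension.
var_col = torch.rand(4, 1) + 0.1
var_full = var_col.repeat(1, 3)

out_col = F.gaussian_nll_loss(input, target, var_col, reduction='none')
out_full = F.gaussian_nll_loss(input, target, var_full, reduction='none')
assert torch.allclose(out_col, out_full)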
4 changes: 4 additions & 0 deletions torch/nn/functional.pyi.in
@@ -229,6 +229,10 @@ def poisson_nll_loss(input: Tensor, target: Tensor, log_input: bool = ..., full:
                     reduction: str = ...) -> Tensor: ...


def gaussian_nll_loss(input: Tensor, target: Tensor, var: Tensor, *, full: bool = ...,
                      eps: float = ..., reduction: str = ...) -> Tensor: ...


def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., reduce: Optional[bool] = ...,
           reduction: str = ..., log_target: bool = ...) -> Tensor: ...

4 changes: 2 additions & 2 deletions torch/nn/modules/__init__.py
@@ -10,7 +10,7 @@
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
    CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
    MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss
    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
    MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -41,7 +41,7 @@
    'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
    'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss',
    'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'CTCLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss',
    'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss',
    'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', 'GaussianNLLLoss',
    'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
    'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
    'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
79 changes: 79 additions & 0 deletions torch/nn/modules/loss.py
@@ -299,6 +299,85 @@ def forward(self, log_input: Tensor, target: Tensor) -> Tensor:
                                  eps=self.eps, reduction=self.reduction)


class GaussianNLLLoss(_Loss):
    r"""Gaussian negative log likelihood loss.

    The targets are treated as samples from Gaussian distributions with
    expectations and variances predicted by the neural network. For a
    D-dimensional ``target`` tensor modelled as having heteroscedastic Gaussian
    distributions with a D-dimensional tensor of expectations ``input`` and a
    D-dimensional tensor of positive variances ``var``, the loss is:

    .. math::
        \text{loss} = \frac{1}{2}\sum_{i=1}^D \left(\log\left(\text{max}\left(\text{var}[i],
        \ \text{eps}\right)\right) + \frac{\left(\text{input}[i] - \text{target}[i]\right)^2}
        {\text{max}\left(\text{var}[i], \ \text{eps}\right)}\right) + \text{const.}

    where :attr:`eps` is used for stability. The constant term of the loss is
    omitted unless :attr:`full` is ``True``. If ``var`` has a single entry per
    batch member (shape :math:`(N, 1)`, implying homoscedastic Gaussian
    distributions for the ``target``), it is broadcast to the same size as the
    input.

    Args:
        full (bool, optional): include the constant term in the loss
            calculation. Default: ``False``.
        eps (float, optional): value used to clamp ``var`` (see note below), for
            stability. Default: 1e-6.
        reduction (string, optional): specifies the reduction to apply to the
            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
            will be applied, ``'mean'``: the output is the average of all batch
            member losses, ``'sum'``: the output is the sum of all batch member
            losses. Default: ``'mean'``.

    Shape:
        - Input: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - Target: :math:`(N, *)`, same shape as the input
        - Var: :math:`(N, *)`, same shape as the input, or :math:`(N, 1)`
        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)`

    Examples::

        >>> loss = nn.GaussianNLLLoss()
        >>> input = torch.randn(5, 2, requires_grad=True)
        >>> target = torch.randn(5, 2)
        >>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
        >>> output = loss(input, target, var)
        >>> output.backward()

        >>> loss = nn.GaussianNLLLoss()
        >>> input = torch.randn(5, 2, requires_grad=True)
        >>> target = torch.randn(5, 2)
        >>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
        >>> output = loss(input, target, var)
        >>> output.backward()

    Note:
        The clamping of ``var`` is ignored with respect to autograd, and so the
        gradients are unaffected by it.

    Reference:
        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
        target probability distribution", Proceedings of 1994 IEEE International
        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
        vol. 1, doi: 10.1109/ICNN.1994.374138.
    """
    __constants__ = ['full', 'eps', 'reduction']
    full: bool
    eps: float

    def __init__(self, *, full: bool = False, eps: float = 1e-6, reduction: str = 'mean') -> None:
        super(GaussianNLLLoss, self).__init__(None, None, reduction)
        self.full = full
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor:
        return F.gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction)

class KLDivLoss(_Loss):
    r"""The Kullback-Leibler divergence loss measure

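In the spirit of the Nix & Weigend reference cited in the docstring, a typical use is a network that predicts both a mean and a positive variance for each target. A minimal training sketch follows; it is not part of this diff, MeanVarianceNet is a hypothetical example model, and the softplus parameterisation of the variance is an illustrative choice rather than something this PR prescribes.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MeanVarianceNet(nn.Module):
    """Hypothetical example model with a mean head and a variance head."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_features, 32), nn.ReLU())
        self.mean_head = nn.Linear(32, out_features)
        self.var_head = nn.Linear(32, out_features)

    def forward(self, x):
        h = self.backbone(x)
        mean = self.mean_head(h)
        var = F.softplus(self.var_head(h)) + 1e-6  # illustrative way to keep var positive
        return mean, var

model = MeanVarianceNet(10, 2)
criterion = nn.GaussianNLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(64, 10)
y = torch.randn(64, 2)
for _ in range(5):
    mean, var = model(x)
    loss = criterion(mean, y, var)  # forward(input, target, var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()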
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -621,6 +621,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
    torch.nn.functional.fractional_max_pool3d_with_indices: (
        lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
        _random_samples=None: -1),
    torch.nn.functional.gaussian_nll_loss: (lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1),
    torch.nn.functional.gelu: lambda input: -1,
    torch.nn.functional.glu: lambda input, dim=-1: -1,
    torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
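The entry above registers a dummy override used by the torch.overrides test machinery; together with the handle_torch_function branch added in torch/nn/functional.py, it lets F.gaussian_nll_loss dispatch through the __torch_function__ protocol for Tensor subclasses. A minimal sketch of such an interception (not part of this diff; LoggingTensor is a made-up subclass):

import torch
import torch.nn.functional as F

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted {func.__name__}")
        # Unwrap to plain tensors and re-dispatch to the regular implementation.
        plain = [a.as_subclass(torch.Tensor) if isinstance(a, torch.Tensor) else a
                 for a in args]
        return func(*plain, **kwargs)

input = torch.randn(3, 5).as_subclass(LoggingTensor)
target = torch.randn(3, 5)
var = torch.ones(3, 5)
out = F.gaussian_nll_loss(input, target, var)  # prints "intercepted gaussian_nll_loss"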