diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 74f7994447a1..4e0beaae41f0 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -276,6 +276,7 @@ Loss Functions
     nn.CTCLoss
     nn.NLLLoss
     nn.PoissonNLLLoss
+    nn.GaussianNLLLoss
    nn.KLDivLoss
     nn.BCELoss
     nn.BCEWithLogitsLoss
diff --git a/test/test_nn.py b/test/test_nn.py
index ae9dd8d23d9f..0512c55bb86f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4788,6 +4788,34 @@ def test_poisson_nll_loss_reduction_modes(self):
         with self.assertRaisesRegex(ValueError, 'is not valid'):
             F.poisson_nll_loss(input, target, reduction='total')
 
+    def test_gaussian_nll_loss_reduction_modes(self):
+        input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
+        target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
+        var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])
+        component_wise_loss = 0.5 * (torch.sum(torch.log(var) + (input - target)**2 / var, dim=1))
+        self.assertEqual(component_wise_loss,
+                         F.gaussian_nll_loss(input, target, var, reduction='none'))
+        self.assertEqual(torch.sum(component_wise_loss),
+                         F.gaussian_nll_loss(input, target, var, reduction='sum'))
+        self.assertEqual(torch.mean(component_wise_loss),
+                         F.gaussian_nll_loss(input, target, var, reduction='mean'))
+        with self.assertRaisesRegex(ValueError, 'is not valid'):
+            F.gaussian_nll_loss(input, target, var, reduction='total')
+
+    def test_gaussian_nll_loss_args(self):
+        input = torch.randn(3, 5)
+        with self.assertRaisesRegex(ValueError, 'input and target must have same size'):
+            target = torch.randn(3, 6)
+            var = torch.ones(3, 5)
+            torch.nn.functional.gaussian_nll_loss(input, target, var)
+        with self.assertRaisesRegex(ValueError, 'var is of incorrect size'):
+            target = torch.randn(3, 5)
+            var = torch.ones(3, 3)
+            torch.nn.functional.gaussian_nll_loss(input, target, var)
+        with self.assertRaisesRegex(ValueError, 'var has negative entry/entries'):
+            var = -1 * torch.ones(3, 5)
+            torch.nn.functional.gaussian_nll_loss(input, target, var)
+
     def test_KLDivLoss_batch_mean(self):
         input_shape = (2, 5)
         log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
@@ -11356,6 +11384,7 @@ def test_invalid_reduction_strings(self, device):
         input = torch.randn(3, 5, requires_grad=True, device=device)
         cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat)
         target = torch.tensor([1, 0, 4], device=device)
+        var = torch.ones(size=input.size(), requires_grad=True, device=device)
 
         for reduction in ['none', 'invalid']:
             def v(fn):
@@ -11375,6 +11404,7 @@ def v(fn):
             v(lambda: F.mse_loss(input, input, reduction=reduction))
             v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
             v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
+            v(lambda: F.gaussian_nll_loss(input, input, var, reduction=reduction))
             v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))
 
             zeros = torch.zeros_like(input).to(torch.int64)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index ca2aaa5f9a40..7491ccf0384a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -2485,6 +2485,72 @@ def poisson_nll_loss(
     return ret
 
 
+def gaussian_nll_loss(input, target, var, *, full=False, eps=1e-6, reduction='mean'):
+    r"""Gaussian negative log likelihood loss.
+
+    See :class:`~torch.nn.GaussianNLLLoss` for details.
+
+    Args:
+        input: expectation of the Gaussian distribution.
+        target: sample from the Gaussian distribution.
+        var: tensor of positive variance(s), one for each of the expectations
+            in the input (heteroscedastic), or a single one (homoscedastic).
+        full: ``True``/``False`` (bool), include the constant term in the loss
+            calculation. Default: ``False``.
+        eps: value used to clamp ``var``, for stability. Default: 1e-6.
+        reduction: specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the output is the average of all batch member losses,
+            ``'sum'``: the output is the sum of all batch member losses.
+            Default: ``'mean'``.
+    """
+    if not torch.jit.is_scripting():
+        tens_ops = (input, target, var)
+        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+            return handle_torch_function(
+                gaussian_nll_loss, tens_ops, input, target, var, full=full, eps=eps, reduction=reduction)
+
+    # Inputs and targets must have the same shape
+    input = input.view(input.size(0), -1)
+    target = target.view(target.size(0), -1)
+    if input.size() != target.size():
+        raise ValueError("input and target must have same size")
+
+    # Second dim of var must match that of input or be equal to 1
+    var = var.view(input.size(0), -1)
+    if var.size(1) != input.size(1) and var.size(1) != 1:
+        raise ValueError("var is of incorrect size")
+
+    # Check validity of reduction mode
+    if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
+        raise ValueError(reduction + " is not valid")
+
+    # Entries of var must be non-negative
+    if torch.any(var < 0):
+        raise ValueError("var has negative entry/entries")
+
+    # Clamp for stability
+    var = var.clone()
+    with torch.no_grad():
+        var.clamp_(min=eps)
+
+    # Calculate loss (without constant)
+    loss = 0.5 * (torch.log(var) + (input - target)**2 / var).view(input.size(0), -1).sum(dim=1)
+
+    # Add constant to loss term if required
+    if full:
+        D = input.size(1)
+        loss = loss + 0.5 * D * math.log(2 * math.pi)
+
+    # Apply reduction
+    if reduction == 'mean':
+        return loss.mean()
+    elif reduction == 'sum':
+        return loss.sum()
+    else:
+        return loss
+
+
 def kl_div(
     input: Tensor,
     target: Tensor,
diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in
index 4a38dca9bdb8..9794d34eec24 100644
--- a/torch/nn/functional.pyi.in
+++ b/torch/nn/functional.pyi.in
@@ -229,6 +229,10 @@ def poisson_nll_loss(input: Tensor, target: Tensor, log_input: bool = ..., full:
                      reduction: str = ...) -> Tensor: ...
 
 
+def gaussian_nll_loss(input: Tensor, target: Tensor, var: Tensor, full: Optional[bool] = ...,
+                      eps: Optional[float] = ..., reduction: Optional[str] = ...) -> Tensor: ...
+
+
 def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., reduce: Optional[bool] = ...,
            reduction: str = ..., log_target: bool = ...) -> Tensor: ...
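As a quick sanity check on the new functional, its output can be reproduced from the closed-form expression in the docstring, mirroring the reduction-mode test above. This is a minimal sketch with illustrative tensors and values; it is not part of the patch itself:

    import torch
    import torch.nn.functional as F

    input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
    target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])

    # per-sample loss from the docstring formula (constant term omitted, i.e. full=False)
    expected = 0.5 * (torch.log(var) + (input - target) ** 2 / var).sum(dim=1)

    assert torch.allclose(F.gaussian_nll_loss(input, target, var, reduction='none'), expected)
    assert torch.allclose(F.gaussian_nll_loss(input, target, var, reduction='sum'), expected.sum())
    assert torch.allclose(F.gaussian_nll_loss(input, target, var, reduction='mean'), expected.mean())
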
diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py
index 4911d4bef38f..e6a7543bd9ba 100644
--- a/torch/nn/modules/__init__.py
+++ b/torch/nn/modules/__init__.py
@@ -10,7 +10,7 @@
 from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
     CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
     MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
-    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss
+    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
 from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
 from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
     MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -41,7 +41,7 @@
     'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
     'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'NLLLoss2d',
     'PoissonNLLLoss', 'CosineEmbeddingLoss', 'CTCLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss',
-    'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss',
+    'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', 'GaussianNLLLoss',
     'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
     'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
     'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index a2642fb4f149..000ae49e4910 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -299,6 +299,85 @@ def forward(self, log_input: Tensor, target: Tensor) -> Tensor:
                              eps=self.eps, reduction=self.reduction)
 
 
+class GaussianNLLLoss(_Loss):
+    r"""Gaussian negative log likelihood loss.
+
+    The targets are treated as samples from Gaussian distributions with
+    expectations and variances predicted by the neural network. For a
+    D-dimensional ``target`` tensor modelled as having heteroscedastic Gaussian
+    distributions with a D-dimensional tensor of expectations ``input`` and a
+    D-dimensional tensor of positive variances ``var``, the loss is:
+
+    .. math::
+        \text{loss} = \frac{1}{2}\sum_{i=1}^D \left(\log\left(\text{max}\left(\text{var}[i],
+        \ \text{eps}\right)\right) + \frac{\left(\text{input}[i] - \text{target}[i]\right)^2}
+        {\text{max}\left(\text{var}[i], \ \text{eps}\right)}\right) + \text{const.}
+
+    where :attr:`eps` is used for stability. By default, the constant term of
+    the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is
+    a scalar (implying ``target`` tensor has homoscedastic Gaussian
+    distributions), it is broadcast to be the same size as the input.
+
+
+    Args:
+        full (bool, optional): include the constant term in the loss
+            calculation. Default: ``False``.
+        eps (float, optional): value used to clamp ``var`` (see note below), for
+            stability. Default: 1e-6.
+        reduction (string, optional): specifies the reduction to apply to the
+            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the output is the average of all batch
+            member losses, ``'sum'``: the output is the sum of all batch member
+            losses. Default: ``'mean'``.
+
+    Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional
+          dimensions
+        - Target: :math:`(N, *)`, same shape as the input
+        - Var: :math:`(N, 1)` or :math:`(N, *)`, same shape as the input
+        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
+          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)`
+
+    Examples::
+
+        >>> loss = nn.GaussianNLLLoss()
+        >>> input = torch.randn(5, 2, requires_grad=True)
+        >>> target = torch.randn(5, 2)
+        >>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
+        >>> output = loss(input, target, var)
+        >>> output.backward()
+
+
+        >>> loss = nn.GaussianNLLLoss()
+        >>> input = torch.randn(5, 2, requires_grad=True)
+        >>> target = torch.randn(5, 2)
+        >>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
+        >>> output = loss(input, target, var)
+        >>> output.backward()
+
+    Note:
+        The clamping of ``var`` is ignored with respect to autograd, and so the
+        gradients are unaffected by it.
+
+    Reference:
+        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
+        target probability distribution", Proceedings of 1994 IEEE International
+        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
+        vol.1, doi: 10.1109/ICNN.1994.374138.
+    """
+    __constants__ = ['full', 'eps', 'reduction']
+    full: bool
+    eps: float
+
+    def __init__(self, *, full: bool = False, eps: float = 1e-6, reduction: str = 'mean') -> None:
+        super(GaussianNLLLoss, self).__init__(None, None, reduction)
+        self.full = full
+        self.eps = eps
+
+    def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor:
+        return F.gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction)
+
+
 class KLDivLoss(_Loss):
     r"""The Kullback-Leibler divergence loss measure
 
diff --git a/torch/overrides.py b/torch/overrides.py
index 0238077e7be5..1b4d3a83774e 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -621,6 +621,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.nn.functional.fractional_max_pool3d_with_indices: (
             lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
             _random_samples=None: -1),
+        torch.nn.functional.gaussian_nll_loss: (lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1),
         torch.nn.functional.gelu: lambda input: -1,
         torch.nn.functional.glu: lambda input, dim=-1: -1,
         torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
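For context on how the new module is meant to be wired up in practice, here is a hedged sketch of a regressor that predicts both a mean and a positive variance, in the spirit of the Nix and Weigend reference cited in the docstring. The network layout, layer sizes, and the softplus variance head are illustrative assumptions, not something the patch prescribes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Toy two-headed regressor (names and sizes are illustrative): one head
    # predicts the mean, the other the variance. Softplus keeps the predicted
    # variance positive, since the loss rejects negative entries in var.
    class MeanVarianceNet(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.hidden = nn.Linear(in_features, 16)
            self.mean_head = nn.Linear(16, out_features)
            self.var_head = nn.Linear(16, out_features)

        def forward(self, x):
            h = torch.relu(self.hidden(x))
            return self.mean_head(h), F.softplus(self.var_head(h))

    model = MeanVarianceNet(3, 2)
    criterion = nn.GaussianNLLLoss()

    x = torch.randn(8, 3)
    y = torch.randn(8, 2)
    mean, var = model(x)
    loss = criterion(mean, y, var)  # arguments are (input, target, var)
    loss.backward()

Any other transform that guarantees a non-negative variance (for instance exponentiating a predicted log-variance) would serve equally well; the loss only requires that ``var`` has no negative entries.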