From cd906826b6aff2c6067a871de308129251b45de5 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Sat, 13 Jan 2018 12:19:31 +0100 Subject: [PATCH] current code works with dim = 3, so I added it to dim checks --- test/common_nn.py | 2 +- test/test_nn.py | 9 +++++++++ torch/nn/functional.py | 8 ++++---- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/test/common_nn.py b/test/common_nn.py index b54bdd525af01..0d422e2a7841f 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -259,7 +259,7 @@ def kldivloss_reference(input, target, size_average=True, reduce=True): def nlllossNd_reference(input, target, weight=None, ignore_index=-100, size_average=True, reduce=True): - assert input.dim() >= 4 + assert input.dim() >= 3 N = input.size(0) C = input.size(1) out_size = (N,) + input.size()[2:] diff --git a/test/test_nn.py b/test/test_nn.py index 786f7d36515d2..8eb8eb6737fde 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4266,6 +4266,15 @@ def forward(self, *args): check_no_size_average=True, desc='higher_dim' ), + dict( + module_name='NLLLoss', + input_size=(2, 3, 5), + target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)), + check_no_size_average=True, + desc='dim_is_3' + ), dict( module_name='PoissonNLLLoss', input_size=(2, 3, 4, 5), diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 53b051714f533..e82021d737002 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1162,10 +1162,10 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r Args: input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` - in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 2` + in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 1` in the case of K-dimensional loss. target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`, - or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 2` for + or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 1` for K-dimensional loss. weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size `C` @@ -1192,7 +1192,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce) elif dim == 4: return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce) - elif dim > 4: + elif dim == 3 or dim > 4: n = input.size(0) c = input.size(1) out_size = (n,) + input.size()[2:] @@ -1206,7 +1206,7 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r out = torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce) return out.view(out_size) else: - raise ValueError('Expected 2, 4, or more than 4 dimensions (got {})'.format(dim)) + raise ValueError('Expected 2 or more dimensions (got {})'.format(dim)) def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8, reduce=True):