Implement NLLLossNd (#4035)
* Implement NLLLossNd

* Fix tests and typos

* Fix tests
zou3519 authored and soumith committed Dec 18, 2017
1 parent 7f41149 commit 30e6898
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 23 deletions.
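
In effect, the commit lets NLLLoss consume inputs with any number of trailing dimensions, not just the 2-d (N, C, H, W) case. A quick usage sketch (shapes are illustrative; written against the Variable API current at the time of this commit):

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    loss_fn = nn.NLLLoss()
    # input holds (unnormalized) log-probabilities over C = 3 classes
    input = Variable(torch.rand(2, 3, 5, 5, 2).log(), requires_grad=True)
    # target holds a class index in [0, C-1] at every spatial location
    target = Variable(torch.rand(2, 5, 5, 2).mul(3).floor().long())
    loss = loss_fn(input, target)   # scalar Variable
    loss.backward()
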
31 changes: 17 additions & 14 deletions test/common_nn.py
@@ -257,24 +257,27 @@ def kldivloss_reference(input, target, size_average=True, reduce=True):
     return result
 
 
-def nllloss2d_reference(input, target, weight=None, ignore_index=-100,
+def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
                         size_average=True, reduce=True):
-    N, C, H, W = input.size()
-    output = torch.zeros(N, H, W).type_as(input)
+    assert input.dim() >= 4
+    N = input.size(0)
+    C = input.size(1)
+    out_size = (N,) + input.size()[2:]
+    output = torch.zeros(out_size).type_as(input)
     if isinstance(target, Variable):
         target = target.data
 
     if weight is None:
         weight = torch.ones(C).type_as(input)
 
     total_weight_data = 0
-    for n in range(0, N):
-        for h in range(0, H):
-            for w in range(0, W):
-                t_nhw = target[n][h][w]
-                norm = 0. if ignore_index == t_nhw else weight[t_nhw]
-                output[n][h][w] = -input[n][t_nhw][h][w] * norm
-                total_weight_data += norm
+    for tup in product(*[range(size) for size in out_size]):
+        t_nx = target[tup]
+        norm = 0. if ignore_index == t_nx else weight[t_nx]
+        input_index = list(tup)
+        input_index.insert(1, t_nx)
+        output[tup] = -input[tuple(input_index)] * norm
+        total_weight_data += norm
 
     if reduce and size_average:
         return output.sum() / total_weight_data
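
The rewritten reference generalizes the old fixed (n, h, w) triple loop by iterating itertools.product over every non-class dimension and splicing the target class into position 1 of the input index. A stripped-down sketch of that indexing trick (weight and ignore_index handling omitted; shapes are illustrative):

    from itertools import product
    import torch

    input = torch.rand(2, 3, 4, 5).log()                 # (N, C, d1, d2)
    target = torch.rand(2, 4, 5).mul(3).floor().long()   # (N, d1, d2)
    out_size = (input.size(0),) + input.size()[2:]       # (N, d1, d2)
    output = torch.zeros(out_size)
    for tup in product(*[range(size) for size in out_size]):
        t = target[tup]            # class index at this location
        index = list(tup)
        index.insert(1, t)         # (n, t, d1, d2) indexes into input
        output[tup] = -input[tuple(index)]
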
@@ -322,7 +325,7 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True):
 loss_reference_fns = {
     'KLDivLoss': kldivloss_reference,
     'NLLLoss': nllloss_reference,
-    'NLLLoss2d': nllloss2d_reference,
+    'NLLLossNd': nlllossNd_reference,
     'SmoothL1Loss': smoothl1loss_reference,
 }

@@ -424,7 +427,7 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True):
         input_size=(2, 3, 5, 5),
         target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
         reference_fn=lambda i, t, m:
-            nllloss2d_reference(i, t, size_average=get_size_average(m)),
+            nlllossNd_reference(i, t, size_average=get_size_average(m)),
         check_no_size_average=True,
     ),
     dict(
@@ -433,7 +436,7 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True):
         input_size=(2, 3, 5, 5),
         target=torch.rand(2, 5, 5).mul(3).floor().long(),
         reference_fn=lambda i, t, m:
-            nllloss2d_reference(i, t, weight=get_weight(m)),
+            nlllossNd_reference(i, t, weight=get_weight(m)),
         desc='weights',
     ),
     dict(
@@ -442,7 +445,7 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True):
         input_size=(2, 3, 5, 5),
         target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
         reference_fn=lambda i, t, m:
-            nllloss2d_reference(i, t, ignore_index=1),
+            nlllossNd_reference(i, t, ignore_index=1),
         desc='ignore_index',
     ),
     dict(
63 changes: 59 additions & 4 deletions test/test_nn.py
@@ -28,7 +28,7 @@
 from torch.nn.parallel._functions import Broadcast
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
     module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
-    TEST_CUDNN_VERSION, loss_reference_fns
+    TEST_CUDNN_VERSION, loss_reference_fns, get_size_average
 from common import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, \
     TEST_SCIPY, download_file

@@ -3927,6 +3927,15 @@ def forward(self, *args):
         target_fn=lambda: torch.randn(15, 10).gt(0).double(),
         desc='weights'
     ),
+    dict(
+        module_name='NLLLoss',
+        input_size=(2, 3, 5, 5, 2, 2),
+        target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
+        check_no_size_average=True,
+        desc='higher_dim'
+    ),
     dict(
         module_name='PoissonNLLLoss',
         input_size=(2, 3, 4, 5),
@@ -4066,7 +4075,7 @@ def nllloss2d_no_reduce_test():
             lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
         input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
         reference_fn=lambda i, _:
-            loss_reference_fns['NLLLoss2d'](i, t.type_as(i).long(), **kwargs),
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
         pickle=False)


@@ -4079,7 +4088,7 @@ def nllloss2d_no_reduce_ignore_index_test():
             lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
         input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
         reference_fn=lambda i, _:
-            loss_reference_fns['NLLLoss2d'](i, t.type_as(i).long(), **kwargs),
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
         pickle=False)


@@ -4096,7 +4105,50 @@ def kwargs(i):
             lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
         input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
         reference_fn=lambda i, _:
-            loss_reference_fns['NLLLoss2d'](i, t.type_as(i).long(), **kwargs(i)),
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
         pickle=False)
 
 
+def nlllossNd_no_reduce_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs = {'reduce': False}
+    return dict(
+        fullname='NLLLossNd_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        reference_fn=lambda i, _:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False)
+
+
+def nlllossNd_no_reduce_ignore_index_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs = {'ignore_index': 1, 'reduce': False}
+    return dict(
+        fullname='NLLLossNd_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        reference_fn=lambda i, _:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False)
+
+
+def nlllossNd_no_reduce_weights_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    weight = torch.rand(3)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduce': False}
+
+    return dict(
+        fullname='NLLLossNd_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        reference_fn=lambda i, _:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False)


@@ -4124,6 +4176,9 @@ def smoothl1loss_no_reduce_test():
     nllloss2d_no_reduce_test(),
     nllloss2d_no_reduce_weights_test(),
     nllloss2d_no_reduce_ignore_index_test(),
+    nlllossNd_no_reduce_test(),
+    nlllossNd_no_reduce_weights_test(),
+    nlllossNd_no_reduce_ignore_index_test(),
     smoothl1loss_no_reduce_test(),
     dict(
         module_name='BatchNorm1d',
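
The three new no-reduce entries pin down the unreduced behaviour on a 6-d input: with reduce=False the loss comes back per element, the same size as the target. Roughly what they assert (a hand-rolled sketch, not the actual test harness):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    input = Variable(torch.rand(2, 3, 5, 5, 2, 2).log())
    target = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    out = F.nll_loss(input, target, reduce=False)
    assert out.size() == target.size()   # (2, 5, 5, 2, 2)
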
23 changes: 20 additions & 3 deletions torch/nn/functional.py
@@ -1072,8 +1072,12 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
     Args:
         input: :math:`(N, C)` where `C = number of classes` or `(N, C, H, W)`
-            in case of 2D - Loss
-        target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`
+            in case of 2D Loss, or `(N, C, *)` in the case of K-dimensional Loss,
+            where :math:`K > 2` and `*` is `K` extra dimensions.
+        target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`.
+            In the case of 2D Loss, then :math:`(N, H, W)`. For K-dimensional
+            Loss where :math:`K > 2`, then :math:`(N, *)`, where `*` is `K`
+            extra dimensions.
         weight (Tensor, optional): a manual rescaling weight given to each
             class. If given, has to be a Tensor of size `C`
         size_average (bool, optional): By default, the losses are averaged
@@ -1099,8 +1103,21 @@ def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, r
         return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce)
     elif dim == 4:
         return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
+    elif dim > 4:
+        n = input.size(0)
+        c = input.size(1)
+        out_size = (n,) + input.size()[2:]
+        if target.size()[1:] != input.size()[2:]:
+            raise ValueError('Expected target size {}, got {}'.format(
+                out_size, input.size()))
+        input = input.contiguous().view(n, c, 1, -1)
+        target = target.contiguous().view(n, 1, -1)
+        if reduce:
+            return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
+        out = torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
+        return out.view(out_size)
     else:
-        raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))
+        raise ValueError('Expected 2, 4, or more than 4 dimensions (got {})'.format(dim))
 
 
 def poisson_nll_loss(input, target, log_input=True, full=False, size_average=True, eps=1e-8):
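
The new dim > 4 branch reuses the existing nll_loss2d kernel by collapsing all trailing dimensions into a single "width" axis and, in the no-reduce case, viewing the result back to the original layout. A small equivalence check of that view trick (a sketch that takes the 4-d path as ground truth; shapes are illustrative):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    n, c = 2, 3
    input = Variable(torch.rand(n, c, 5, 5, 2, 2).log())
    target = Variable(torch.rand(n, 5, 5, 2, 2).mul(3).floor().long())

    direct = F.nll_loss(input, target)                        # dim > 4 path
    flat = F.nll_loss(input.contiguous().view(n, c, 1, -1),   # (N, C, 1, prod(d))
                      target.contiguous().view(n, 1, -1))     # (N, 1, prod(d))
    assert (direct - flat).abs().data[0] < 1e-6
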
10 changes: 8 additions & 2 deletions torch/nn/modules/loss.py
@@ -120,9 +120,15 @@ class NLLLoss(_WeightedLoss):
             Default: ``True``
 
     Shape:
-        - Input: :math:`(N, C)` where `C = number of classes`
-        - Target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`
+        - Input: :math:`(N, C)` where `C = number of classes`.
+            In the case of K-dimensional loss where :math:`K >= 2`, then
+            :math:`(N, C, *)` where `*` is `K` extra dimensions.
+        - Target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`.
+            In the case of K-dimensional loss, where :math:`K >= 2`, then
+            :math:`(N, C, *)` where `*` is `K` extra dimensions.
         - Output: scalar. If reduce is ``False``, then :math:`(N)` instead.
+            In the case of K-dimensional loss and reduce is ``False``, then
+            :math:`(N, C, *)`, the same size as the target.
 
     Examples::
