Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions aten/src/ATen/native/Loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool lo
} else {
loss = input - target * at::log(input + eps);
}

if (full) {
auto mask1 = (target > 1);
loss.masked_select(mask1) += (target * at::log(target) - target + 0.5 * at::log(2 * M_PI * target)).masked_select(mask1);
auto stirling_term = target * at::log(target) - target + 0.5 * at::log(2 * M_PI * target);
loss += stirling_term.masked_fill(target <= 1, 0);
}

return apply_loss_reduction(loss, reduction);
Expand Down
30 changes: 26 additions & 4 deletions test/common_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from itertools import product
from functools import reduce
from operator import mul
from math import pi


import torch
Expand Down Expand Up @@ -280,10 +281,11 @@ def forward(self, *args):
def poissonnllloss_no_reduce_test():
t = torch.randn(10, 10)
return dict(
fullname='PoissonNLLLLoss_no_reduce',
fullname='PoissonNLLLoss_no_reduce',
constructor=wrap_functional(
lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(10, 10),
reference_fn=lambda i, *_: i.exp() - t.mul(i),
pickle=False)


Expand Down Expand Up @@ -3136,18 +3138,38 @@ def padding3d_circular(input, pad):
check_sum_reduction=True,
desc='dim_is_3',
),
dict(
module_name='PoissonNLLLoss', # Default is log_input=True, full=False
input_size=(2, 3, 4, 5),
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
reference_fn=lambda i, t, _: (i.exp() - t.mul(i)).mean(),
desc='no_full_loss',
),
dict(
module_name='PoissonNLLLoss',
constructor_args=(False, False), # log_input=False, full=False
input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
reference_fn=lambda i, t, _: (i - t.mul((i + 1e-8).log())).mean(),
desc='no_full_loss_no_log_input',
),
dict(
module_name='PoissonNLLLoss',
constructor_args=(True, True), # log_input=True, full=True
input_size=(2, 3, 4, 5),
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
desc='no_full_loss', # without Stirling approx
reference_fn=lambda i, t, _:
(i.exp() - t.mul(i) + (t.mul(t.log()) - t + 0.5 * (2. * pi * t).log()).masked_fill(t <= 1, 0)).mean(),
desc='full_loss',
),
dict(
module_name='PoissonNLLLoss',
constructor_args=(False,),
constructor_args=(False, True), # log_input=False, full=True
input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
desc='full_loss', # with Stirling approx
reference_fn=lambda i, t, _:
(i - t.mul((i + 1e-8).log()) + (t.mul(t.log()) - t + 0.5 * (2. * pi * t).log()).masked_fill(t <= 1, 0)).mean(),
desc='full_loss_no_log_input',
),
dict(
module_name='L1Loss',
Expand Down