Properly check that reduction strings are valid for l1_loss, smoothl1_loss, and mse_loss.

ghstack-source-id: 9751dfa1e5a2adf6d67319dcace1c937ad0b1705
Pull Request resolved: #43527
gchanan committed Aug 27, 2020
1 parent ad93092 commit 3eb25cf
Showing 2 changed files with 40 additions and 0 deletions.
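
In short: when `target.requires_grad` is true, these losses take a pure-Python path whose reduction handling previously treated any unrecognised string as `'sum'`; this commit makes that path validate the string up front via `_Reduction.get_enum` and raise `ValueError`. A minimal sketch of the behavior the new test exercises (the tensor shapes here are illustrative, not taken from the commit):

```python
import torch
import torch.nn.functional as F

# A target that requires grad exercises the pure-Python fallback patched below.
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5, requires_grad=True)

# Valid reduction strings behave exactly as before.
loss = F.mse_loss(input, target, reduction='mean')

# An invalid reduction string now raises ValueError instead of being
# silently reduced with torch.sum.
try:
    F.mse_loss(input, target, reduction='invalid')
except ValueError as err:
    print('rejected:', err)
```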
test/test_nn.py (37 additions, 0 deletions)
@@ -9736,6 +9736,43 @@ def verify_reduction_scalars(input, reduction, output):
            output = m(sigmoid(input), target)
            verify_reduction_scalars(input, reduction, output)

    # verify that bogus reduction strings are errors
    @onlyOnCPUAndCUDA
    def test_invalid_reduction_strings(self, device):
        input = torch.randn(3, 5, requires_grad=True, device=device)
        target = torch.tensor([1, 0, 4], device=device)

        for reduction in ['none', 'invalid']:
            def v(fn):
                if reduction == 'invalid':
                    self.assertRaises(ValueError, lambda: fn())
                else:
                    fn()

            v(lambda: F.nll_loss(input, target, reduction=reduction))
            v(lambda: F.cross_entropy(input, target, reduction=reduction))
            v(lambda: F.multi_margin_loss(input, target, reduction=reduction))

            v(lambda: F.kl_div(input, input, reduction=reduction))
            v(lambda: F.smooth_l1_loss(input, input, reduction=reduction))
            v(lambda: F.l1_loss(input, input, reduction=reduction))
            v(lambda: F.mse_loss(input, input, reduction=reduction))
            v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
            v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
            v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))

            zeros = torch.zeros_like(input).to(torch.int64)
            v(lambda: F.multilabel_soft_margin_loss(input, zeros, reduction=reduction))
            v(lambda: F.multilabel_margin_loss(input, zeros, reduction=reduction))

            v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
            v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
            v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))

            # FIXME: should we allow derivatives on these?
            v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input.detach(), reduction=reduction))
            v(lambda: F.soft_margin_loss(input, input.sign().detach(), reduction=reduction))

    # We don't want to make propagating NaN a hard requirement on ops, but for
    # these easy ones, we should make them do so.
    def test_nonlinearity_propagate_nan(self, device):
torch/nn/functional.py (3 additions, 0 deletions)
@@ -2602,6 +2602,7 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if target.requires_grad:
        _Reduction.get_enum(reduction)  # throw an error if reduction is invalid
        ret = _smooth_l1_loss(input, target, delta=delta)
        if reduction != 'none':
            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
@@ -2633,6 +2634,7 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if target.requires_grad:
        _Reduction.get_enum(reduction)  # throw an error if reduction is invalid
        ret = torch.abs(input - target)
        if reduction != 'none':
            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
@@ -2664,6 +2666,7 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if target.requires_grad:
        _Reduction.get_enum(reduction)  # throw an error if reduction is invalid
        ret = (input - target) ** 2
        if reduction != 'none':
            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
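The fix in each of the three functions is the same one-line call to `_Reduction.get_enum(reduction)` placed at the top of the differentiable-target branch, so an unrecognised string fails loudly instead of silently behaving like `reduction='sum'`. A simplified sketch of the kind of check that helper performs (an assumption about its shape, not the actual `torch.nn._reduction` source):

```python
# Simplified sketch (assumed, not the actual torch.nn._reduction source) of the
# validation _Reduction.get_enum performs: map a reduction string to its enum
# value and reject anything unrecognised.
def get_enum_sketch(reduction: str) -> int:
    mapping = {'none': 0, 'mean': 1, 'sum': 2}
    if reduction not in mapping:
        raise ValueError("{} is not a valid value for reduction".format(reduction))
    return mapping[reduction]
```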
