Implement MarginRankingLoss as native function and add reduce=True arg to it (#5346)

* add reduce=True arg to MarginRankingLoss

* make default margin arg match for legacy

* remove accidentally added test

* fix test

* fix native_functions.yaml alphabetical order
li-roy authored and soumith committed Mar 21, 2018
1 parent a3bd7b2 commit e4eee7c
Showing 9 changed files with 72 additions and 59 deletions.
11 changes: 11 additions & 0 deletions aten/src/ATen/native/Loss.cpp
@@ -60,4 +60,15 @@ Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const T
   }
   return output;
 }
+
+Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, bool size_average, bool reduce) {
+  auto output = (-target * (input1 - input2) + margin).clamp_min_(0);
+
+  if (reduce && size_average) {
+    return output.sum() / output.numel();
+  } else if (reduce) {
+    return output.sum();
+  }
+  return output;
+}
 }} // namespace at::native
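For readers skimming the diff, the new native function's reduction semantics are easy to restate in Python. The following is a rough sketch of the same logic, not the shipped kernel, and the function name is ours:

    import torch

    def margin_ranking_loss_sketch(input1, input2, target, margin=0.0,
                                   size_average=True, reduce=True):
        # Element-wise hinge: max(0, -y * (x1 - x2) + margin), as in the C++ above.
        output = (-target * (input1 - input2) + margin).clamp(min=0)
        if reduce and size_average:
            return output.sum() / output.numel()  # averaged over all elements
        elif reduce:
            return output.sum()                   # summed to a scalar
        return output                             # per-element losses, unreduced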
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -324,6 +324,9 @@
 - func: logspace_out(Tensor result, Scalar start, Scalar end, int64_t steps=100) -> Tensor
   variants: function
 
+- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin, bool size_average, bool reduce) -> Tensor
+  variants: function
+
 - func: matmul(Tensor self, Tensor other) -> Tensor
 
 - func: max_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor)
16 changes: 15 additions & 1 deletion test/common_nn.py
@@ -439,6 +439,15 @@ def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps
     return output
 
 
+def marginrankingloss_reference(input1, input2, target, margin=0, size_average=True, reduce=True):
+    output = (-target * (input1 - input2) + margin).clamp(min=0)
+    if reduce and size_average:
+        return output.mean()
+    elif reduce:
+        return output.sum()
+    return output
+
+
 loss_reference_fns = {
     'KLDivLoss': kldivloss_reference,
     'NLLLoss': nllloss_reference,
@@ -450,6 +459,7 @@ def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps
     'MultiMarginLoss': multimarginloss_reference,
     'CosineEmbeddingLoss': cosineembeddingloss_reference,
     'TripletMarginLoss': tripletmarginloss_reference,
+    'MarginRankingLoss': marginrankingloss_reference,
 }


@@ -685,13 +695,17 @@ def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps
         module_name='MarginRankingLoss',
         input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
         target_fn=lambda: torch.randn(50).sign(),
+        reference_fn=lambda i, t, m:
+            marginrankingloss_reference(i[0], i[1], t, size_average=get_size_average(m)),
         check_no_size_average=True,
     ),
     dict(
         module_name='MarginRankingLoss',
-        constructor_args=(2,),
+        constructor_args=(0.5,),
         input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
         target_fn=lambda: torch.randn(50).sign(),
+        reference_fn=lambda i, t, m:
+            marginrankingloss_reference(i[0], i[1], t, margin=0.5, size_average=get_size_average(m)),
         desc='margin',
         check_no_size_average=True,
     ),
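As a worked example of the reference semantics (ours, not part of the commit): with x1 = (1, 2), x2 = (3, 1), y = (1, -1) and the default margin of 0, the element-wise hinge max(0, -y * (x1 - x2)) gives (2, 1), so averaging yields 1.5:

    import torch

    input1 = torch.Tensor([1.0, 2.0])
    input2 = torch.Tensor([3.0, 1.0])
    target = torch.Tensor([1.0, -1.0])

    losses = (-target * (input1 - input2)).clamp(min=0)  # margin=0
    print(losses)         # 2.0, 1.0 -- per-element hinge values (reduce=False)
    print(losses.mean())  # 1.5 -- the reduce=True, size_average=True result
    print(losses.sum())   # 3.0 -- the reduce=True, size_average=False result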
18 changes: 18 additions & 0 deletions test/test_nn.py
@@ -3896,6 +3896,24 @@ def test_cosine_embedding_loss_margin_no_reduce(self):
         self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduce=False),
                          loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, margin=0.5, reduce=False))
 
+    def test_margin_ranking_loss_no_reduce(self):
+        input1 = Variable(torch.randn(15).mul(10), requires_grad=True)
+        input2 = Variable(torch.randn(15).mul(10), requires_grad=True)
+        target = Variable(torch.randn(15).sign())
+        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
+            x, y, z, reduce=False), (input1, input2, target)))
+        self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduce=False),
+                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduce=False))
+
+    def test_margin_ranking_loss_margin_no_reduce(self):
+        input1 = Variable(torch.randn(15).mul(10), requires_grad=True)
+        input2 = Variable(torch.randn(15).mul(10), requires_grad=True)
+        target = Variable(torch.randn(15).sign())
+        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
+            x, y, z, margin=0.5, reduce=False), (input1, input2, target)))
+        self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduce=False),
+                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduce=False))
+
     def test_triplet_margin_loss(self):
         input1 = Variable(torch.randn(5, 10), requires_grad=True)
         input2 = Variable(torch.randn(5, 10), requires_grad=True)
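These tests lean on torch.autograd.gradcheck to compare analytic gradients against finite differences. A standalone equivalent might look like the sketch below; we use double precision, the usual recommendation for gradcheck since float32 finite differences are noisy (our assumption, not something this commit changes):

    import torch
    from torch.autograd import Variable, gradcheck
    import torch.nn.functional as F

    input1 = Variable(torch.randn(15).double(), requires_grad=True)
    input2 = Variable(torch.randn(15).double(), requires_grad=True)
    target = Variable(torch.randn(15).sign().double())

    # Fails if analytic and numerical gradients of the un-reduced loss disagree.
    assert gradcheck(lambda x, y, z: F.margin_ranking_loss(x, y, z, reduce=False),
                     (input1, input2, target))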
2 changes: 1 addition & 1 deletion torch/legacy/nn/MarginRankingCriterion.py
@@ -4,7 +4,7 @@
 
 class MarginRankingCriterion(Criterion):
 
-    def __init__(self, margin=1, sizeAverage=True):
+    def __init__(self, margin=0, sizeAverage=True):
         super(MarginRankingCriterion, self).__init__()
         self.margin = margin
         self.sizeAverage = sizeAverage
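With this change the legacy criterion's default margin matches the modern module's margin=0. A quick illustrative check (our snippet, assuming the era's torch.legacy package):

    import torch.legacy.nn as legacy_nn
    import torch.nn as nn

    assert legacy_nn.MarginRankingCriterion().margin == 0
    assert nn.MarginRankingLoss().margin == 0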
42 changes: 0 additions & 42 deletions torch/nn/_functions/loss.py

This file was deleted.

2 changes: 0 additions & 2 deletions torch/nn/backends/thnn.py
@@ -23,7 +23,6 @@ def _initialize_backend():
     from .._functions.rnn import RNN, \
         RNNTanhCell, RNNReLUCell, GRUCell, LSTMCell
     from .._functions.dropout import Dropout, FeatureDropout
-    from .._functions.loss import MarginRankingLoss
 
     backend.register_function('RNN', RNN)
     backend.register_function('RNNTanhCell', RNNTanhCell)
@@ -33,7 +32,6 @@ def _initialize_backend():
     backend.register_function('Dropout', Dropout)
     backend.register_function('Dropout2d', FeatureDropout)
     backend.register_function('Dropout3d', FeatureDropout)
-    backend.register_function('MarginRankingLoss', MarginRankingLoss)
     for cls in _thnn_functions:
         name = cls.__name__
         backend.register_function(name, cls)
6 changes: 3 additions & 3 deletions torch/nn/functional.py
@@ -1611,15 +1611,15 @@ def mse_loss(input, target, size_average=True, reduce=True):
                        input, target, size_average, reduce)
 
 
-def margin_ranking_loss(input1, input2, target, margin=0, size_average=True):
-    r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=True) -> Tensor
+def margin_ranking_loss(input1, input2, target, margin=0, size_average=True, reduce=True):
+    r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=True, reduce=True) -> Tensor
 
     See :class:`~torch.nn.MarginRankingLoss` for details.
     """
     if input1.dim() == 0 or input2.dim() == 0 or target.dim() == 0:
         raise RuntimeError(("margin_ranking_loss does not support scalars, got sizes: "
                             "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size())))
-    return _functions.loss.MarginRankingLoss.apply(input1, input2, target, margin, size_average)
+    return torch._C._VariableFunctions.margin_ranking_loss(input1, input2, target, margin, size_average, reduce)
 
 
 def hinge_embedding_loss(input, target, margin=1.0, size_average=True, reduce=True):
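With the functional now dispatching to the native op, all three reduction modes are reachable from Python. A usage sketch (tensor values illustrative):

    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F

    x1 = Variable(torch.randn(5))
    x2 = Variable(torch.randn(5))
    y = Variable(torch.randn(5).sign())  # ranking labels in {1, -1}

    mean_loss = F.margin_ranking_loss(x1, x2, y)                     # scalar, averaged
    sum_loss = F.margin_ranking_loss(x1, x2, y, size_average=False)  # scalar, summed
    per_pair = F.margin_ranking_loss(x1, x2, y, reduce=False)        # shape (5,)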
31 changes: 21 additions & 10 deletions torch/nn/modules/loss.py
@@ -838,7 +838,7 @@ def forward(self, input1, input2, target):
                                        self.reduce)
 
 
-class MarginRankingLoss(Module):
+class MarginRankingLoss(_Loss):
     r"""Creates a criterion that measures the loss given
     inputs `x1`, `x2`, two 1D mini-batch `Tensor`s,
     and a label 1D mini-batch tensor `y` with values (`1` or `-1`).
@@ -851,20 +851,31 @@ class MarginRankingLoss(Module):
 
     .. math::
         \text{loss}(x, y) = \max(0, -y * (x1 - x2) + \text{margin})
 
-    if the internal variable `size_average = True`,
-    the loss function averages the loss over the batch samples;
-    if `size_average = False`, then the loss function sums over the batch
-    samples.
-
-    By default, `size_average` equals to ``True``.
+    Args:
+        margin (float, optional): Has a default value of `0`.
+        size_average (bool, optional): By default, the losses are averaged over
+            observations for each minibatch. However, if the field :attr:`size_average`
+            is set to ``False``, the losses are instead summed for each minibatch.
+            Default: ``True``
+        reduce (bool, optional): By default, the losses are averaged or summed over
+            observations for each minibatch depending on :attr:`size_average`. When
+            :attr:`reduce` is ``False``, returns a loss per batch element instead and
+            ignores :attr:`size_average`. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, D)` where `N` is the batch size and `D` is the size of a sample.
+        - Target: :math:`(N)`
+        - Output: scalar. If `reduce` is False, then `(N)`.
     """
 
-    def __init__(self, margin=0, size_average=True):
-        super(MarginRankingLoss, self).__init__()
+    def __init__(self, margin=0, size_average=True, reduce=True):
+        super(MarginRankingLoss, self).__init__(size_average)
         self.margin = margin
-        self.size_average = size_average
+        self.reduce = reduce
 
     def forward(self, input1, input2, target):
-        return F.margin_ranking_loss(input1, input2, target, self.margin, self.size_average)
+        return F.margin_ranking_loss(input1, input2, target, self.margin, self.size_average,
+                                     self.reduce)
 
 
 class MultiMarginLoss(_WeightedLoss):
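At the module level, the new reduce flag lets callers keep per-sample losses and reduce manually before backprop. A sketch of typical usage (our example, not from the commit):

    import torch
    from torch.autograd import Variable
    import torch.nn as nn

    loss_fn = nn.MarginRankingLoss(margin=0.5, reduce=False)
    x1 = Variable(torch.randn(4), requires_grad=True)
    x2 = Variable(torch.randn(4), requires_grad=True)
    y = Variable(torch.Tensor([1, 1, -1, -1]))

    per_sample = loss_fn(x1, x2, y)  # shape (4,): one hinge value per pair
    per_sample.sum().backward()      # reduce manually before calling backward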
Expand Down

0 comments on commit e4eee7c

Please sign in to comment.