Update on "[pytorch] Add triplet margin loss with custom distance"
Summary: As discussed [here](#43342), this adds a Python-only implementation of
the triplet margin loss that takes a custom distance function. Whether it
belongs in PyTorch Core is still under discussion.
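For orientation, the loss in question is the usual triplet margin objective, `max(d(a, p) - d(a, n) + margin, 0)`, with the distance `d` supplied by the caller. A minimal sketch of that computation, not the code in this diff (the helper name, default distance, and reduction handling here are assumptions):

```python
import torch
import torch.nn.functional as F

def triplet_loss_sketch(anchor, positive, negative, distance_function=None,
                        margin=1.0, swap=False, reduction="mean"):
    # Sketch only -- illustrates the math, not the implementation in this diff.
    # Fall back to the Euclidean pairwise distance used by the classic loss.
    d = distance_function if distance_function is not None else F.pairwise_distance

    dist_pos = d(anchor, positive)   # d(a, p)
    dist_neg = d(anchor, negative)   # d(a, n)
    if swap:
        # "Distance swap": compare against the harder of d(a, n) and d(p, n).
        dist_neg = torch.min(dist_neg, d(positive, negative))

    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    return loss
```

With `distance_function=None` this should line up with `F.triplet_margin_loss` under default arguments, which is what the default-parity test below checks.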

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23363898](https://our.internmc.facebook.com/intern/diff/D23363898)

[ghstack-poisoned]
ethch18 committed Sep 3, 2020
1 parent e899aa9 commit 91ffeb2
Showing 8 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions docs/source/nn.functional.rst
@@ -483,10 +483,10 @@ Loss functions

.. autofunction:: triplet_margin_loss

- :hidden:`triplet_margin_loss_with_distance`
+ :hidden:`triplet_margin_with_distance_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- .. autofunction:: triplet_margin_loss_with_distance
+ .. autofunction:: triplet_margin_with_distance_loss

Vision functions
----------------
2 changes: 1 addition & 1 deletion docs/source/nn.rst
@@ -269,7 +269,7 @@ Loss Functions
nn.CosineEmbeddingLoss
nn.MultiMarginLoss
nn.TripletMarginLoss
- nn.TripletMarginLossWithDistance
+ nn.TripletMarginWithDistanceLoss

Vision Layers
----------------
24 changes: 12 additions & 12 deletions test/test_nn.py
@@ -12067,9 +12067,9 @@ def test_threshold_inplace_overlap(self, device):
F.threshold(x, 0.5, 0.5, inplace=True)
F.threshold_(x, 0.5, 0.5)

- def test_triplet_margin_loss_with_distance_default_parity(self, device):
- # Test for `nn.TripletMarginLossWithDistance` and
- # `F.triplet_margin_loss_with_distance`. Checks
+ def test_triplet_margin_with_distance_loss_default_parity(self, device):
+ # Test for `nn.TripletMarginWithDistanceLoss` and
+ # `F.triplet_margin_with_distance_loss`. Checks
# for parity against the respective non-distance-agnostic
# implementations of triplet margin loss (`nn.TripletMarginLoss`
# and `F.triplet_margin_loss`) under *default args*.
@@ -12084,25 +12084,25 @@ def test_triplet_margin_loss_with_distance_default_parity(self, device):

# Test forward, functional
expected = F.triplet_margin_loss(anchor, positive, negative, **kwargs)
- actual = F.triplet_margin_loss_with_distance(anchor, positive, negative, **kwargs)
+ actual = F.triplet_margin_with_distance_loss(anchor, positive, negative, **kwargs)
self.assertEqual(actual, expected, rtol=1e-6, atol=1e-6)

# Test forward, module
loss_ref = nn.TripletMarginLoss(**kwargs)
- loss_op = nn.TripletMarginLossWithDistance(**kwargs)
+ loss_op = nn.TripletMarginWithDistanceLoss(**kwargs)
self.assertEqual(loss_op(anchor, positive, negative),
loss_ref(anchor, positive, negative),
rtol=1e-6, atol=1e-6)

# Test backward
- self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
+ self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
a, p, n, **kwargs), (anchor, positive, negative)))
self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
(anchor, positive, negative)))

- def test_triplet_margin_loss_with_distance(self, device):
- # Test for parity between `nn.TripletMarginLossWithDistance` and
- # `F.triplet_margin_loss_with_distance`.
+ def test_triplet_margin_with_distance_loss(self, device):
+ # Test for parity between `nn.TripletMarginWithDistanceLoss` and
+ # `F.triplet_margin_with_distance_loss`.

def pairwise_similarity(x, y):
return 1.0 - F.pairwise_distance(x, y)
@@ -12125,11 +12125,11 @@ def cosine_distance(x, y):
negative = torch.randn(5, 10, device=device, requires_grad=True)

# Test backward
- self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
+ self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
a, p, n, distance_function=distance_fn, is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap),
(anchor, positive, negative)))
- loss_op = nn.TripletMarginLossWithDistance(distance_function=distance_fn,
+ loss_op = nn.TripletMarginWithDistanceLoss(distance_function=distance_fn,
is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap)
self.assertTrue(gradcheck(lambda a, p, n: loss_op(
@@ -12139,7 +12139,7 @@ def cosine_distance(x, y):
a, p, n), (anchor, positive, negative)))

# Test forward parity
- functional = F.triplet_margin_loss_with_distance(anchor, positive, negative,
+ functional = F.triplet_margin_with_distance_loss(anchor, positive, negative,
distance_function=distance_fn,
is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap)
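The gradcheck calls above can also be run standalone against a custom distance; a sketch assuming the `F.triplet_margin_with_distance_loss` signature shown in this diff (double-precision inputs are used here because gradcheck's finite differences are unreliable in float32):

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

def l1_distance(x, y):
    # Per-pair L1 distance over the feature dimension.
    return (x - y).abs().sum(dim=1)

anchor = torch.randn(5, 10, dtype=torch.double, requires_grad=True)
positive = torch.randn(5, 10, dtype=torch.double, requires_grad=True)
negative = torch.randn(5, 10, dtype=torch.double, requires_grad=True)

# gradcheck compares analytic gradients against finite differences.
assert gradcheck(
    lambda a, p, n: F.triplet_margin_with_distance_loss(
        a, p, n, distance_function=l1_distance, margin=1.0, swap=True),
    (anchor, positive, negative))
```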
8 changes: 4 additions & 4 deletions torch/nn/functional.py
@@ -3729,21 +3729,21 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum)


- def triplet_margin_loss_with_distance(anchor, positive, negative, distance_function=None, is_similarity_function=False,
+ def triplet_margin_with_distance_loss(anchor, positive, negative, distance_function=None, is_similarity_function=False,
margin=1.0, swap=False, reduction="mean"):
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], bool, float, bool, str) -> Tensor
r"""
- See :class:`~torch.nn.TripletMarginLossWithDistance` for details.
+ See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
Note: does not support JIT scripting.
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_loss_with_distance does not support JIT scripting: "
raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: "
"functions requiring Callables cannot be scripted.")

tens_ops = (anchor, positive, negative)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
return handle_torch_function(
- triplet_margin_loss_with_distance, tens_ops, anchor, positive, negative,
+ triplet_margin_with_distance_loss, tens_ops, anchor, positive, negative,
distance_function=distance_function, is_similarity_function=is_similarity_function,
margin=margin, swap=swap, reduction=reduction)
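The hunk above stops at the `__torch_function__` dispatch, before the body that computes the loss. In eager mode any `(Tensor, Tensor) -> Tensor` callable works as the distance, while scripted code hits the `NotImplementedError` raised above. A small usage sketch, assuming the signature shown in this hunk (the lambda distance and shapes are only illustrative):

```python
import torch
import torch.nn.functional as F

anchor = torch.randn(8, 16, requires_grad=True)
positive = torch.randn(8, 16, requires_grad=True)
negative = torch.randn(8, 16, requires_grad=True)

# Any Python callable mapping two tensors to per-pair distances is accepted.
loss = F.triplet_margin_with_distance_loss(
    anchor, positive, negative,
    distance_function=lambda x, y: (x - y).abs().sum(dim=1),
    margin=0.5, swap=True)
loss.backward()
# Calling the same functional from inside a torch.jit.script'ed function
# raises NotImplementedError, per the is_scripting() check above.
```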

2 changes: 1 addition & 1 deletion torch/nn/functional.pyi.in
@@ -311,7 +311,7 @@ def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, marg
reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...


- def triplet_margin_loss_with_distance(anchor: Tensor, positive: Tensor, negative: Tensor,
+ def triplet_margin_with_distance_loss(anchor: Tensor, positive: Tensor, negative: Tensor,
distance_function: Optional[Callable[[Tensor, Tensor], Tensor]]=...,
is_similarity_function: bool=..., margin: float=...,
swap: bool=..., reduction: str=...) -> Tensor: ...
4 changes: 2 additions & 2 deletions torch/nn/modules/__init__.py
@@ -9,7 +9,7 @@
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
- SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginLossWithDistance, PoissonNLLLoss
+ SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -54,5 +54,5 @@
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
- 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginLossWithDistance'
+ 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss'
]
14 changes: 7 additions & 7 deletions torch/nn/modules/loss.py
@@ -1192,7 +1192,7 @@ class TripletMarginLoss(_Loss):
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
- See also :class:`~torch.nn.TripletMarginLossWithDistance`, which computes the
+ See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the
triplet margin loss for input tensors using a custom distance function.
Args:
@@ -1251,7 +1251,7 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
eps=self.eps, swap=self.swap, reduction=self.reduction)


- class TripletMarginLossWithDistance(_Loss):
+ class TripletMarginWithDistanceLoss(_Loss):
r"""Creates a criterion that measures the triplet loss given input
tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
positive, and negative examples, respectively); and a nonnegative,
@@ -1329,20 +1329,20 @@ class TripletMarginLossWithDistance(_Loss):
>>> negative = embedding(negative_ids)
>>>
>>> # Built-in Distance Function
- >>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=nn.PairwiseDistance())
+ >>> triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Built-in Similarity Function
- >>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=nn.CosineSimilarity(), is_similarity_function=True)
+ >>> triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=nn.CosineSimilarity(), is_similarity_function=True)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # User-defined Similarity Function
>>> def l_infinity(x1, x2):
>>> return torch.max(torch.abs(x1 - x2), dim=1).values
>>>
- >>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=l_infinity, margin=1.5)
+ >>> triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
@@ -1357,14 +1357,14 @@ class TripletMarginLossWithDistance(_Loss):

def __init__(self, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, is_similarity_function: bool = False,
margin: float = 1.0, swap: bool = False, reduction: str = 'mean'):
- super(TripletMarginLossWithDistance, self).__init__(size_average=None, reduce=None, reduction=reduction)
+ super(TripletMarginWithDistanceLoss, self).__init__(size_average=None, reduce=None, reduction=reduction)
self.distance_function = distance_function if distance_function is not None else PairwiseDistance()
self.is_similarity_function = is_similarity_function
self.margin = margin
self.swap = swap

def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
- return F.triplet_margin_loss_with_distance(anchor, positive, negative,
+ return F.triplet_margin_with_distance_loss(anchor, positive, negative,
distance_function=self.distance_function,
is_similarity_function=self.is_similarity_function,
margin=self.margin, swap=self.swap, reduction=self.reduction)
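Since `forward` simply delegates to the functional, the module and functional forms should agree for any given distance function; a minimal parity check along the lines of the tests above (the `cosine_distance` helper mirrors the one defined in test_nn.py):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def cosine_distance(x, y):
    # Express cosine similarity as a distance so the default (distance)
    # interpretation applies without setting is_similarity_function.
    return 1.0 - F.cosine_similarity(x, y)

anchor = torch.randn(5, 10)
positive = torch.randn(5, 10)
negative = torch.randn(5, 10)

criterion = nn.TripletMarginWithDistanceLoss(
    distance_function=cosine_distance, margin=0.5, swap=True)
module_out = criterion(anchor, positive, negative)
functional_out = F.triplet_margin_with_distance_loss(
    anchor, positive, negative,
    distance_function=cosine_distance, margin=0.5, swap=True)

assert torch.allclose(module_out, functional_out)
```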
2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -615,7 +615,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
swap=False, size_average=None, reduce=None, reduction='mean': -1),
- torch.nn.functional.triplet_margin_loss_with_distance: (lambda anchor, positive, negative, distance_function=None,
+ torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, distance_function=None,
is_similarity_function=False, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
