Skip to content

Commit

Permalink
Update on "[pytorch] Add triplet margin loss with custom distance"
Browse files Browse the repository at this point in the history
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23363898](https://our.internmc.facebook.com/intern/diff/D23363898)

[ghstack-poisoned]
  • Loading branch information
ethch18 committed Sep 18, 2020
1 parent 2c5b569 commit 4b34a03
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
3 changes: 3 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9812,6 +9812,7 @@ def v(fn):
v(lambda: F.multilabel_margin_loss(input, zeros, reduction=reduction))

v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
v(lambda: F.triplet_margin_with_distance_loss(input, input, input, reduction=reduction))
v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))

Expand Down Expand Up @@ -12134,6 +12135,7 @@ def test_threshold_inplace_overlap(self, device):
F.threshold(x, 0.5, 0.5, inplace=True)
F.threshold_(x, 0.5, 0.5)

@onlyOnCPUAndCUDA
def test_triplet_margin_with_distance_loss_default_parity(self, device):
# Test for `nn.TripletMarginWithDistanceLoss` and
# `F.triplet_margin_with_distance_loss`. Checks
Expand Down Expand Up @@ -12167,6 +12169,7 @@ def test_triplet_margin_with_distance_loss_default_parity(self, device):
self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
(anchor, positive, negative)))

@onlyOnCPUAndCUDA
def test_triplet_margin_with_distance_loss(self, device):
# Test for parity between `nn.TripletMarginWithDistanceLoss` and
# `F.triplet_margin_with_distance_loss`.
Expand Down
1 change: 0 additions & 1 deletion torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3730,7 +3730,6 @@ def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_fu
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], float, bool, str) -> Tensor
r"""
See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
Note: does not support JIT scripting.
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: "
Expand Down
17 changes: 8 additions & 9 deletions torch/nn/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,10 +1254,10 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
class TripletMarginWithDistanceLoss(_Loss):
r"""Creates a criterion that measures the triplet loss given input
tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
positive, and negative examples, respectively); and a nonnegative,
positive, and negative examples, respectively) and a nonnegative,
real-valued function ("distance function") used to compute the relationship
between the anchor and positive examples ("positive distance") and the
anchor and negative examples ("negative distance").
between the anchor and positive example ("positive distance") and the
anchor and negative example ("negative distance").
The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
can be described as:
Expand All @@ -1267,10 +1267,11 @@ class TripletMarginWithDistanceLoss(_Loss):
l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
quantifying the closeness of two tensors, referred to as :attr:`distance_function`;
and :math:`margin` is a non-negative margin between the positive and negative
distances that is required for the loss to be 0. The input tensors have :math:`N`
elements each and can be of any shape that the distance function can handle.
quantifying the closeness of two tensors, referred to as the :attr:`distance_function`;
and :math:`margin` is a non-negative margin representing the minimum difference
between the positive and negative distances that is required for the loss to
be 0. The input tensors have :math:`N` elements each and can be of any shape
that the distance function can handle.
If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then:
Expand All @@ -1285,8 +1286,6 @@ class TripletMarginWithDistanceLoss(_Loss):
See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
loss for input tensors using the :math:`l_p` distance as the distance function.
Note: does not support JIT scripting.
Args:
distance_function (callable, optional): A nonnegative, real-valued function that
quantifies the closeness of two tensors. If not specified,
Expand Down

0 comments on commit 4b34a03

Please sign in to comment.