Skip to content

Commit

Permalink
Update on "[pytorch] Add triplet margin loss with custom distance"
Browse files Browse the repository at this point in the history
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23363898](https://our.internmc.facebook.com/intern/diff/D23363898)

[ghstack-poisoned]
  • Loading branch information
ethch18 committed Aug 31, 2020
1 parent de25d9a commit b1a40fb
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 39 deletions.
9 changes: 4 additions & 5 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3737,15 +3737,14 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum)


def triplet_margin_loss_with_distance(anchor: Tensor, positive: Tensor, negative: Tensor,
distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
is_similarity_function: bool = False, margin: float = 1.0,
swap: bool = False, reduction: str = "mean"):
def triplet_margin_loss_with_distance(anchor, positive, negative, distance_function=None, is_similarity_function=False,
margin=1.0, swap=False, reduction="mean"):
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], bool, float, bool, str) -> Tensor
r"""
See :class:`~torch.nn.TripletMarginLossWithDistance` for details
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_loss_with_distance does not support JIT: "
raise NotImplementedError("F.triplet_margin_loss_with_distance does not support JIT scripting: "
"Callables cannot be scripted unless they are properties of "
"a module. Please use nn.TripletMarginLossWithDistance instead.")

Expand Down
111 changes: 77 additions & 34 deletions torch/nn/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,9 @@ class TripletMarginLoss(_Loss):
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
See also :class:`~torch.nn.TripletMarginLossWithDistance`, which computes the
triplet margin loss for input tensors using a custom distance function.
Args:
margin (float, optional): Default: :math:`1`.
p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
Expand All @@ -1226,7 +1229,8 @@ class TripletMarginLoss(_Loss):
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
- Output: If :attr:`reduction` is ``'none'``, then a tensor of shape :math:`(N)`,
or a scalar otherwise.
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = torch.randn(100, 128, requires_grad=True)
Expand Down Expand Up @@ -1258,61 +1262,100 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:


class TripletMarginLossWithDistance(_Loss):
r"""Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
:math:`(N, D)`.
The distance swap is described in detail in the paper `Learning shallow
convolutional feature descriptors with triplet losses`_ by
V. Balntas, E. Riba et al.
r"""Creates a criterion that measures the triplet loss given input
tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
positive, and negative examples, respectively); and a real-valued function
between them.
The loss function for each sample in the mini-batch is:
The unreduced loss (i.e., with `reduction` set to `'none'`)
can be described as:
.. math::
L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
\ell(a, p, n) = L = \{l_1,\dots,\l_N}^\top, \quad
l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
where :math:`N` is the batch size; :math:`d` is a real-valued function quantifying
the separation between two tensors, referred to as `distance_function`;
and :math:`margin` is a non-negative margin enforced between the positive and
negative distances. The input tensors have :math:`N` elements each and can be of
any shape that the distance function can handle.
If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then:
.. math::
\ell(x, y) =
\begin{cases}
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
\end{cases}
where :math:`d(x_i, y_i)` represents the output of `distance_function` on the two
inputs. See also :class:`~torch.nn.TripletMarginLoss`.
See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
loss for input tensors using the :math:`l_p` distance as the distance function.
Args:
distance_function (callable, optional): A distance function between two Tensors which,
if specified, will be used instead of the pairwise distance. If not specified,
`nn.PairwiseDistance` will be used. Default: ``None``
is_similarity_function (bool, optional): Whether `distance_function` represents a
similarity metrics. Default: ``False``
margin (float, optional): Default: :math:`1`.
swap (bool, optional): The distance swap is described in detail in the paper
similarity metric, i.e., larger is closer. If True, computes the difference of
distances as :math:`d(a_i, n_i) - d(a_i, p_i)` so that larger loss values occur
when the negative example is more similar to the anchor than the positive example
is. Default: ``False``
margin (float, optional): A non-negative margin enforced between the positive and
negative distances. Larger margins penalize cases where the negative examples
are not distant enough from the anchors, relative to the positives. Default: :math:`1`.
swap (bool, optional): Whether to use the distance swap described in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. Default: ``False``.
reduction (string, optional): Specifies the reduction to apply to the output:
V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
negative example than the anchor is, swaps the positive example and the anchor in
the loss computation. Default: ``False``.
reduction (string, optional): Specifies the (optional) reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
>>> distance_function = nn.PairwiseDistance(p=2)
>>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=distance_function, margin=1.0)
>>> anchor = torch.randn(100, 128, requires_grad=True)
>>> positive = torch.randn(100, 128, requires_grad=True)
>>> negative = torch.randn(100, 128, requires_grad=True)
- Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
as supported by the distance function.
- Output: If :attr:`reduction` is ``'none'``, then a tensor of shape :math:`(N)`,
or a scalar otherwise.
Example::
>>> # Initialize embeddings
>>> embedding = nn.Embedding(1000, 128)
>>> anchor_ids = torch.randint(0, 1000, (1,), requires_grad=True)
>>> positive_ids = torch.randint(0, 1000, (1,), requires_grad=True)
>>> negative_ids = torch.randint(0, 1000, (1,), requires_grad=True)
>>> anchor = embedding(anchor_ids)
>>> positive = embedding(positive_ids)
>>> negative = embedding(negative_ids)
>>>
>>> # Built-in Distance Function
>>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=nn.PairwiseDistance())
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Built-in Similarity Function
>>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=nn.CosineSimilarity(), is_similarity_function=True)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # User-defined Similarity Function
>>> def l_infinity(x1, x2):
... return torch.max(torch.abs(x1 - x2), dim=1).values
...
>>> triplet_loss = nn.TripletMarginLossWithDistance(distance_function=l_infinity, margin=1.5)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
.. _Learning shallow convolutional feature descriptors with triplet losses:
Reference:
V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
http://www.bmva.org/bmvc/2016/papers/paper119/index.html
"""
__constants__ = ['distance_function', 'is_similarity_function', 'margin', 'swap', 'reduction']
distance_function: Optional[Callable[[Tensor, Tensor], Tensor]]
__constants__ = ['is_similarity_function', 'margin', 'swap', 'reduction']
is_similarity_function: bool
margin: float
swap: bool
Expand Down

0 comments on commit b1a40fb

Please sign in to comment.