[pytorch] Add triplet margin loss with custom distance

Summary: As discussed [here](#43342), this adds a Python-only implementation of
the triplet margin loss that takes a custom distance function. Whether it
belongs in PyTorch Core is still under discussion.
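
For orientation, here is a minimal usage sketch of the API being added (the signature mirrors the diff below; the cosine-distance lambda is the same one exercised in the tests):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Any nonnegative, real-valued callable on pairs of tensors can serve as the distance.
triplet_loss = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y),
    margin=1.0, swap=False, reduction='mean')

anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)

loss = triplet_loss(anchor, positive, negative)  # scalar under the default 'mean' reduction
loss.backward()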

* Consolidate tests, clarify functional limitations

* Documentation updates

* Remove stray imports

* Fix CI

Test Plan: python test/run_tests.py


ghstack-source-id: 731708e0ced604dc46e1b2b81a91dcecd9607d8f
Pull Request resolved: #43680
ethch18 committed Sep 11, 2020
1 parent 30fccc5 commit a723924
Showing 8 changed files with 257 additions and 11 deletions.
7 changes: 5 additions & 2 deletions docs/source/nn.functional.rst
@@ -483,6 +483,11 @@ Loss functions

.. autofunction:: triplet_margin_loss

:hidden:`triplet_margin_with_distance_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: triplet_margin_with_distance_loss

Vision functions
----------------

@@ -533,5 +538,3 @@ DataParallel functions (multi-GPU, distributed)
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.parallel.data_parallel


3 changes: 2 additions & 1 deletion docs/source/nn.rst
@@ -10,7 +10,7 @@ These are the basic building block for graphs
:depth: 2
:local:
:backlinks: top


.. currentmodule:: torch.nn

@@ -269,6 +269,7 @@ Loss Functions
nn.CosineEmbeddingLoss
nn.MultiMarginLoss
nn.TripletMarginLoss
nn.TripletMarginWithDistanceLoss

Vision Layers
----------------
77 changes: 77 additions & 0 deletions test/test_nn.py
@@ -12134,6 +12134,83 @@ def test_threshold_inplace_overlap(self, device):
F.threshold(x, 0.5, 0.5, inplace=True)
F.threshold_(x, 0.5, 0.5)

def test_triplet_margin_with_distance_loss_default_parity(self, device):
# Test for `nn.TripletMarginWithDistanceLoss` and
# `F.triplet_margin_with_distance_loss`. Checks
# for parity against the respective non-distance-agnostic
# implementations of triplet margin loss (`nn.TripletMarginLoss`
# and `F.triplet_margin_loss`) under *default args*.

for extra_args in \
itertools.product((0.5, 1, 1.5), (True, False), ('none', 'mean', 'sum')):
kwargs = {'margin': extra_args[0], 'swap': extra_args[1], 'reduction': extra_args[2]}

anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# Test forward, functional
expected = F.triplet_margin_loss(anchor, positive, negative, **kwargs)
actual = F.triplet_margin_with_distance_loss(anchor, positive, negative, **kwargs)
self.assertEqual(actual, expected, rtol=1e-6, atol=1e-6)

# Test forward, module
loss_ref = nn.TripletMarginLoss(**kwargs)
loss_op = nn.TripletMarginWithDistanceLoss(**kwargs)
self.assertEqual(loss_op(anchor, positive, negative),
loss_ref(anchor, positive, negative),
rtol=1e-6, atol=1e-6)

# Test backward
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
a, p, n, **kwargs), (anchor, positive, negative)))
self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
(anchor, positive, negative)))

def test_triplet_margin_with_distance_loss(self, device):
# Test for parity between `nn.TripletMarginWithDistanceLoss` and
# `F.triplet_margin_with_distance_loss`.

pairwise_distance = nn.PairwiseDistance()

def cosine_distance(x, y):
return 1.0 - F.cosine_similarity(x, y)

distance_functions = (pairwise_distance, cosine_distance,
lambda x, y: 1.0 - F.cosine_similarity(x, y))

reductions = ('mean', 'none', 'sum')
margins = (1.0, 1.5, 0.5)
swaps = (True, False)

for distance_fn, reduction, margin, swap \
in itertools.product(distance_functions, reductions, margins, swaps):
anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# Test backward
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
a, p, n, distance_function=distance_fn, reduction=reduction, margin=margin, swap=swap),
(anchor, positive, negative)))
loss_op = nn.TripletMarginWithDistanceLoss(distance_function=distance_fn,
reduction=reduction, margin=margin, swap=swap)
self.assertTrue(gradcheck(lambda a, p, n: loss_op(
a, p, n), (anchor, positive, negative)))
traced_loss_op = torch.jit.trace(loss_op, (anchor, positive, negative))
self.assertTrue(gradcheck(lambda a, p, n: traced_loss_op(
a, p, n), (anchor, positive, negative)))

# Test forward parity
functional = F.triplet_margin_with_distance_loss(anchor, positive, negative,
distance_function=distance_fn,
reduction=reduction, margin=margin, swap=swap)
modular = loss_op(anchor, positive, negative)
traced = traced_loss_op(anchor, positive, negative)
self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)


class TestModuleGlobalHooks(TestCase):

def tearDown(self):
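The default-parity test above relies on the fact that, under default arguments, the new function's `pairwise_distance` default (p=2, eps=1e-6) matches what `F.triplet_margin_loss` computes internally. A compressed sketch of that check (shapes shrunk for illustration):

import torch
import torch.nn.functional as F

a, p, n = (torch.randn(5, 10, requires_grad=True) for _ in range(3))
ref = F.triplet_margin_loss(a, p, n, margin=1.0, swap=True)
new = F.triplet_margin_with_distance_loss(a, p, n, margin=1.0, swap=True)
assert torch.allclose(ref, new, atol=1e-6)  # parity under the default distance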
37 changes: 37 additions & 0 deletions torch/nn/functional.py
@@ -3725,6 +3725,43 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum)


def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_function=None,
margin=1.0, swap=False, reduction="mean"):
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], float, bool, str) -> Tensor
r"""
See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
Note: does not support JIT scripting.
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: "
"functions requiring Callables cannot be scripted.")

tens_ops = (anchor, positive, negative)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
return handle_torch_function(
triplet_margin_with_distance_loss, tens_ops, anchor, positive, negative,
distance_function=distance_function, margin=margin, swap=swap, reduction=reduction)

distance_function = distance_function if distance_function is not None else pairwise_distance

positive_dist = distance_function(anchor, positive)
negative_dist = distance_function(anchor, negative)

if swap:
swap_dist = distance_function(positive, negative)
negative_dist = torch.min(negative_dist, swap_dist)

output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)

reduction_enum = _Reduction.get_enum(reduction)
if reduction_enum == 1:
return output.mean()
elif reduction_enum == 2:
return output.sum()
else:
return output


def normalize(input, p=2, dim=1, eps=1e-12, out=None):
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
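To make the hinge semantics of the implementation above concrete, a small worked sketch (the tensors and hand-computed distances are illustrative, not part of the diff): with d(a, p) = 1 and d(a, n) = 5, each element's loss is max(1 - 5 + 1, 0) = 0.

import torch
import torch.nn.functional as F

a = torch.tensor([[0.0, 0.0]])
p = torch.tensor([[0.0, 1.0]])   # d(a, p) = 1
n = torch.tensor([[3.0, 4.0]])   # d(a, n) = 5

dist = torch.nn.PairwiseDistance()  # same default as distance_function=None
expected = torch.clamp(dist(a, p) - dist(a, n) + 1.0, min=0.0)
actual = F.triplet_margin_with_distance_loss(a, p, n, reduction='none')
assert torch.allclose(actual, expected)  # both are tensor([0.]) here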
11 changes: 8 additions & 3 deletions torch/nn/functional.pyi.in
@@ -22,9 +22,9 @@ GRID_SAMPLE_PADDING_MODES = Dict[str, int]
# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence`
# type. There is no way to express the expected lengths of these lists in the current Python typing system.
#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
def fractional_max_pool2d_with_indices(input: Tensor, kernel_size: _size, output_size: Optional[_size] = ...,
output_ratio: Optional[_ratio_any_t] = ..., return_indices: bool = ...,
@@ -319,6 +319,11 @@ def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, marg
reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...


def triplet_margin_with_distance_loss(anchor: Tensor, positive: Tensor, negative: Tensor, *,
distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
margin: float = ..., swap: bool = ..., reduction: str = ...) -> Tensor: ...


def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ...,
out: Optional[Tensor] = ...) -> Tensor: ...

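Note that the `*` in the stub (and in the Python signature above) makes every argument after the three tensors keyword-only. A quick illustration (the tensor names and the cosine_distance helper are assumed from the tests earlier in this diff):

# OK: distance_function passed by keyword
F.triplet_margin_with_distance_loss(anchor, positive, negative,
                                    distance_function=cosine_distance)
# TypeError at runtime (and a static type error against this stub): positional
F.triplet_margin_with_distance_loss(anchor, positive, negative, cosine_distance)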
6 changes: 3 additions & 3 deletions torch/nn/modules/__init__.py
@@ -8,8 +8,8 @@
Hardsigmoid, Hardswish, SiLU
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
SmoothL1Loss, SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, PoissonNLLLoss
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -54,5 +54,5 @@
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss'
]
124 changes: 122 additions & 2 deletions torch/nn/modules/loss.py
@@ -1,11 +1,12 @@
import warnings

from .distance import PairwiseDistance
from .module import Module
from .. import functional as F
from .. import _reduction as _Reduction

from torch import Tensor
from typing import Optional
from typing import Callable, Optional


class _Loss(Module):
@@ -1191,6 +1192,9 @@ class TripletMarginLoss(_Loss):
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the
triplet margin loss for input tensors using a custom distance function.
Args:
margin (float, optional): Default: :math:`1`.
p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
@@ -1215,7 +1219,8 @@ class TripletMarginLoss(_Loss):
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = torch.randn(100, 128, requires_grad=True)
@@ -1246,6 +1251,121 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
eps=self.eps, swap=self.swap, reduction=self.reduction)


class TripletMarginWithDistanceLoss(_Loss):
r"""Creates a criterion that measures the triplet loss given input
tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
positive, and negative examples, respectively); and a nonnegative,
real-valued function ("distance function") used to compute the relationship
between the anchor and positive examples ("positive distance") and the
anchor and negative examples ("negative distance").
The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
can be described as:
.. math::
\ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
quantifying the closeness of two tensors, referred to as :attr:`distance_function`;
and :math:`margin` is a non-negative margin between the positive and negative
distances that is required for the loss to be 0. The input tensors have :math:`N`
elements each and can be of any shape that the distance function can handle.
If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then:
.. math::
\ell(x, y) =
\begin{cases}
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
\end{cases}
See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
loss for input tensors using the :math:`l_p` distance as the distance function.
Note: does not support JIT scripting.
Args:
distance_function (callable, optional): A nonnegative, real-valued function that
quantifies the closeness of two tensors. If not specified,
`nn.PairwiseDistance` will be used. Default: ``None``
margin (float, optional): A non-negative margin representing the minimum difference
between the positive and negative distances required for the loss to be 0. Larger
margins penalize cases where the negative examples are not distant enough from the
anchors, relative to the positives. Default: :math:`1`.
swap (bool, optional): Whether to use the distance swap described in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
negative example than the anchor is, swaps the positive example and the anchor in
the loss computation. Default: ``False``.
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
Shape:
- Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
as supported by the distance function.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.
Examples::
>>> # Initialize embeddings
>>> embedding = nn.Embedding(1000, 128)
>>> anchor_ids = torch.randint(0, 1000, (1,), requires_grad=True)
>>> positive_ids = torch.randint(0, 1000, (1,), requires_grad=True)
>>> negative_ids = torch.randint(0, 1000, (1,), requires_grad=True)
>>> anchor = embedding(anchor_ids)
>>> positive = embedding(positive_ids)
>>> negative = embedding(negative_ids)
>>>
>>> # Built-in Distance Function
>>> triplet_loss = \
>>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Custom Distance Function
>>> def l_infinity(x1, x2):
>>> return torch.max(torch.abs(x1 - x2), dim=1).values
>>>
>>> triplet_loss = \
>>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Custom Distance Function (Lambda)
>>> triplet_loss = \
>>> nn.TripletMarginWithDistanceLoss(
>>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
Reference:
V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
http://www.bmva.org/bmvc/2016/papers/paper119/index.html
"""
__constants__ = ['margin', 'swap', 'reduction']
margin: float
swap: bool

def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
margin: float = 1.0, swap: bool = False, reduction: str = 'mean'):
super(TripletMarginWithDistanceLoss, self).__init__(size_average=None, reduce=None, reduction=reduction)
self.distance_function = distance_function if distance_function is not None else PairwiseDistance()
self.margin = margin
self.swap = swap

def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
return F.triplet_margin_with_distance_loss(anchor, positive, negative,
distance_function=self.distance_function,
margin=self.margin, swap=self.swap, reduction=self.reduction)


class CTCLoss(_Loss):
r"""The Connectionist Temporal Classification loss.
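One practical consequence of the "does not support JIT scripting" caveat above: `torch.jit.script` rejects the module because it closes over an arbitrary Python callable, but `torch.jit.trace` works, since tracing records only the tensor operations the distance function performs. A sketch along the lines of the tests earlier in this diff:

import torch
import torch.nn as nn

loss = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: (x - y).norm(dim=1))
a, p, n = (torch.randn(5, 10) for _ in range(3))

# torch.jit.script(loss) would raise: Callable arguments cannot be scripted.
traced = torch.jit.trace(loss, (a, p, n))  # the lambda's ops are baked into the graph
out = traced(a, p, n)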
3 changes: 3 additions & 0 deletions torch/overrides.py
@@ -621,6 +621,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
swap=False, size_average=None, reduce=None, reduction='mean': -1),
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
distance_function=None, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nonzero: lambda input, as_tuple=False: -1,
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
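The entry above registers the new functional with the `__torch_function__` override machinery; the lambda records its public signature for `torch.overrides.get_testing_overrides`. A hypothetical tensor-like wrapper sketches the dispatch path that `handle_torch_function` in functional.py takes (the `MyTensorLike` class and its unwrapping logic are illustrative only, not part of the diff):

import torch
import torch.nn.functional as F

class MyTensorLike:
    # Hypothetical wrapper type that opts into __torch_function__ dispatch.
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda t: t.data if isinstance(t, MyTensorLike) else t
        # Unwrap our wrapper, then delegate to the real implementation.
        return func(*[unwrap(x) for x in args],
                    **{k: unwrap(v) for k, v in kwargs.items()})

a, p, n = (MyTensorLike(torch.randn(5, 10)) for _ in range(3))
out = F.triplet_margin_with_distance_loss(a, p, n)  # routed through MyTensorLike.__torch_function__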
