[pytorch] Add triplet margin loss with custom distance
Summary: As discussed [here](#43342),
this adds a Python-only implementation of the triplet margin loss that takes a
custom distance function. Whether this belongs in PyTorch Core is still under
discussion.
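
For a quick sense of the new API, a minimal usage sketch (the shapes, margin, and the
cosine-distance choice here are illustrative, not from this patch):

import torch
import torch.nn as nn
import torch.nn.functional as F

anchor = torch.randn(8, 32, requires_grad=True)
positive = torch.randn(8, 32, requires_grad=True)
negative = torch.randn(8, 32, requires_grad=True)

# Module form, with a custom distance function
loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y),
    margin=0.5)
loss_fn(anchor, positive, negative).backward()

# Functional form; distance_function defaults to the l_2 pairwise distance
loss = F.triplet_margin_with_distance_loss(anchor, positive, negative, margin=0.5)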

* Consolidate tests, clarify functional limitations

* Documentation updates

* Remove stray imports

* Fix CI

Test Plan: python test/run_tests.py


ghstack-source-id: 7052170dfc4c796c16ab4e5dada4cc8e7eb9dba7
Pull Request resolved: #43680
ethch18 committed Sep 18, 2020
1 parent 07b7e44 commit 81aad24
Showing 8 changed files with 258 additions and 11 deletions.
7 changes: 5 additions & 2 deletions docs/source/nn.functional.rst
@@ -483,6 +483,11 @@ Loss functions

.. autofunction:: triplet_margin_loss

:hidden:`triplet_margin_with_distance_loss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: triplet_margin_with_distance_loss

Vision functions
----------------

@@ -533,5 +538,3 @@ DataParallel functions (multi-GPU, distributed)
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.parallel.data_parallel


3 changes: 2 additions & 1 deletion docs/source/nn.rst
@@ -10,7 +10,7 @@ These are the basic building block for graphs
:depth: 2
:local:
:backlinks: top


.. currentmodule:: torch.nn

@@ -269,6 +269,7 @@ Loss Functions
nn.CosineEmbeddingLoss
nn.MultiMarginLoss
nn.TripletMarginLoss
nn.TripletMarginWithDistanceLoss

Vision Layers
----------------
80 changes: 80 additions & 0 deletions test/test_nn.py
@@ -9853,6 +9853,7 @@ def v(fn):
v(lambda: F.multilabel_margin_loss(input, zeros, reduction=reduction))

v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
v(lambda: F.triplet_margin_with_distance_loss(input, input, input, reduction=reduction))
v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))

@@ -12174,6 +12175,85 @@ def test_threshold_inplace_overlap(self, device):
F.threshold(x, 0.5, 0.5, inplace=True)
F.threshold_(x, 0.5, 0.5)

@onlyOnCPUAndCUDA
def test_triplet_margin_with_distance_loss_default_parity(self, device):
# Test for `nn.TripletMarginWithDistanceLoss` and
# `F.triplet_margin_with_distance_loss`. Checks
# for parity against the respective non-distance-agnostic
# implementations of triplet margin loss (`nn.TripletMarginLoss`
# and `F.triplet_margin_loss`) under *default args*.

for extra_args in \
itertools.product((0.5, 1, 1.5), (True, False), ('none', 'mean', 'sum')):
kwargs = {'margin': extra_args[0], 'swap': extra_args[1], 'reduction': extra_args[2]}

anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# Test forward, functional
expected = F.triplet_margin_loss(anchor, positive, negative, **kwargs)
actual = F.triplet_margin_with_distance_loss(anchor, positive, negative, **kwargs)
self.assertEqual(actual, expected, rtol=1e-6, atol=1e-6)

# Test forward, module
loss_ref = nn.TripletMarginLoss(**kwargs)
loss_op = nn.TripletMarginWithDistanceLoss(**kwargs)
self.assertEqual(loss_op(anchor, positive, negative),
loss_ref(anchor, positive, negative),
rtol=1e-6, atol=1e-6)

# Test backward
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
a, p, n, **kwargs), (anchor, positive, negative)))
self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
(anchor, positive, negative)))

@onlyOnCPUAndCUDA
def test_triplet_margin_with_distance_loss(self, device):
# Test for parity between `nn.TripletMarginWithDistanceLoss` and
# `F.triplet_margin_with_distance_loss`.

pairwise_distance = nn.PairwiseDistance()

def cosine_distance(x, y):
return 1.0 - F.cosine_similarity(x, y)

distance_functions = (pairwise_distance, cosine_distance,
lambda x, y: 1.0 - F.cosine_similarity(x, y))

reductions = ('mean', 'none', 'sum')
margins = (1.0, 1.5, 0.5)
swaps = (True, False)

for distance_fn, reduction, margin, swap \
in itertools.product(distance_functions, reductions, margins, swaps):
anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# Test backward
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
a, p, n, distance_function=distance_fn, reduction=reduction, margin=margin, swap=swap),
(anchor, positive, negative)))
loss_op = nn.TripletMarginWithDistanceLoss(distance_function=distance_fn,
reduction=reduction, margin=margin, swap=swap)
self.assertTrue(gradcheck(lambda a, p, n: loss_op(
a, p, n), (anchor, positive, negative)))
traced_loss_op = torch.jit.trace(loss_op, (anchor, positive, negative))
self.assertTrue(gradcheck(lambda a, p, n: traced_loss_op(
a, p, n), (anchor, positive, negative)))

# Test forward parity
functional = F.triplet_margin_with_distance_loss(anchor, positive, negative,
distance_function=distance_fn,
reduction=reduction, margin=margin, swap=swap)
modular = loss_op(anchor, positive, negative)
traced = traced_loss_op(anchor, positive, negative)
self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)


class TestModuleGlobalHooks(TestCase):

def tearDown(self):
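
As an aside on the test strategy: gradcheck compares analytical gradients against
finite-difference estimates, and is most reliable with double-precision inputs. A
minimal standalone version of the backward check (the shapes are illustrative):

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

anchor = torch.randn(5, 10, dtype=torch.double, requires_grad=True)
positive = torch.randn(5, 10, dtype=torch.double, requires_grad=True)
negative = torch.randn(5, 10, dtype=torch.double, requires_grad=True)

# Passes if the analytical Jacobian matches the numerical one
assert gradcheck(
    lambda a, p, n: F.triplet_margin_with_distance_loss(a, p, n, margin=1.0),
    (anchor, positive, negative))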
36 changes: 36 additions & 0 deletions torch/nn/functional.py
@@ -3728,6 +3728,42 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum)


def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_function=None,
margin=1.0, swap=False, reduction="mean"):
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], float, bool, str) -> Tensor
r"""
See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: "
"functions requiring Callables cannot be scripted.")

tens_ops = (anchor, positive, negative)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
return handle_torch_function(
triplet_margin_with_distance_loss, tens_ops, anchor, positive, negative,
distance_function=distance_function, margin=margin, swap=swap, reduction=reduction)

distance_function = distance_function if distance_function is not None else pairwise_distance

positive_dist = distance_function(anchor, positive)
negative_dist = distance_function(anchor, negative)

if swap:
swap_dist = distance_function(positive, negative)
negative_dist = torch.min(negative_dist, swap_dist)

output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)

reduction_enum = _Reduction.get_enum(reduction)
if reduction_enum == 1:
return output.mean()
elif reduction_enum == 2:
return output.sum()
else:
return output


def normalize(input, p=2, dim=1, eps=1e-12, out=None):
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
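
Under default arguments the new functional reduces to the existing triplet margin
loss, since both use the l_2 pairwise distance. A quick parity check (a sketch;
shapes and tolerances are illustrative):

import torch
import torch.nn.functional as F

a, p, n = (torch.randn(5, 10) for _ in range(3))

expected = F.triplet_margin_loss(a, p, n, margin=1.0, swap=True)
actual = F.triplet_margin_with_distance_loss(a, p, n, margin=1.0, swap=True)

# Both default to the l_2 pairwise distance, so the results should agree
# up to floating-point tolerance
assert torch.allclose(actual, expected, atol=1e-6)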
11 changes: 8 additions & 3 deletions torch/nn/functional.pyi.in
@@ -22,9 +22,9 @@ GRID_SAMPLE_PADDING_MODES = Dict[str, int]
# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence`
# type. There is no way to express the expected lengths of these lists in the current Python typing system.
#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
def fractional_max_pool2d_with_indices(input: Tensor, kernel_size: _size, output_size: Optional[_size] = ...,
output_ratio: Optional[_ratio_any_t] = ..., return_indices: bool = ...,
@@ -319,6 +319,11 @@ def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, marg
reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...


def triplet_margin_with_distance_loss(anchor: Tensor, positive: Tensor, negative: Tensor, *,
distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
margin: float = ..., swap: bool = ..., reduction: str = ...) -> Tensor: ...


def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ...,
out: Optional[Tensor] = ...) -> Tensor: ...

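
Note the bare * in both the stub and the implementation: everything after the three
tensor arguments is keyword-only. A quick illustration (a sketch):

import torch
import torch.nn.functional as F

a, p, n = (torch.randn(2, 4) for _ in range(3))

# OK: options passed by keyword
F.triplet_margin_with_distance_loss(a, p, n, margin=2.0)

# TypeError: margin cannot be passed positionally
# F.triplet_margin_with_distance_loss(a, p, n, None, 2.0)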
6 changes: 3 additions & 3 deletions torch/nn/modules/__init__.py
@@ -8,8 +8,8 @@
Hardsigmoid, Hardswish, SiLU
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
SmoothL1Loss, SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, PoissonNLLLoss
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -54,5 +54,5 @@
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss'
]
123 changes: 121 additions & 2 deletions torch/nn/modules/loss.py
@@ -1,11 +1,12 @@
import warnings

from .distance import PairwiseDistance
from .module import Module
from .. import functional as F
from .. import _reduction as _Reduction

from torch import Tensor
from typing import Optional
from typing import Callable, Optional


class _Loss(Module):
@@ -1191,6 +1192,9 @@ class TripletMarginLoss(_Loss):
.. math::
    d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the
triplet margin loss for input tensors using a custom distance function.

Args:
margin (float, optional): Default: :math:`1`.
p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
@@ -1215,7 +1219,8 @@ class TripletMarginLoss(_Loss):
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = torch.randn(100, 128, requires_grad=True)
@@ -1246,6 +1251,120 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
eps=self.eps, swap=self.swap, reduction=self.reduction)


class TripletMarginWithDistanceLoss(_Loss):
r"""Creates a criterion that measures the triplet loss given input
tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
positive, and negative examples, respectively), and a nonnegative,
real-valued function ("distance function") used to compute the relationship
between the anchor and positive example ("positive distance") and the
anchor and negative example ("negative distance").

The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
can be described as:

.. math::
    \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
    l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}

where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
quantifying the closeness of two tensors, referred to as the :attr:`distance_function`;
and :math:`margin` is a non-negative margin representing the minimum difference
between the positive and negative distances that is required for the loss to
be 0. The input tensors have :math:`N` elements each and can be of any shape
that the distance function can handle.

If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then:

.. math::
    \ell(x, y) =
    \begin{cases}
        \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
        \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
    \end{cases}

See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
loss for input tensors using the :math:`l_p` distance as the distance function.

Args:
distance_function (callable, optional): A nonnegative, real-valued function that
quantifies the closeness of two tensors. If not specified,
`nn.PairwiseDistance` will be used. Default: ``None``
margin (float, optional): A non-negative margin representing the minimum difference
between the positive and negative distances required for the loss to be 0. Larger
margins penalize cases where the negative examples are not distant enough from the
anchors, relative to the positives. Default: :math:`1`.
swap (bool, optional): Whether to use the distance swap described in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
negative example than the anchor is, swaps the positive example and the anchor in
the loss computation. Default: ``False``.
reduction (string, optional): Specifies the (optional) reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

Shape:
- Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
as supported by the distance function.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.

Examples::
>>> # Initialize embeddings
>>> embedding = nn.Embedding(1000, 128)
>>> anchor_ids = torch.randint(0, 1000, (1,))
>>> positive_ids = torch.randint(0, 1000, (1,))
>>> negative_ids = torch.randint(0, 1000, (1,))
>>> anchor = embedding(anchor_ids)
>>> positive = embedding(positive_ids)
>>> negative = embedding(negative_ids)
>>>
>>> # Built-in Distance Function
>>> triplet_loss = \
>>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Custom Distance Function
>>> def l_infinity(x1, x2):
>>> return torch.max(torch.abs(x1 - x2), dim=1).values
>>>
>>> triplet_loss = \
>>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
>>>
>>> # Custom Distance Function (Lambda)
>>> triplet_loss = \
>>> nn.TripletMarginWithDistanceLoss(
>>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()

Reference:
V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
http://www.bmva.org/bmvc/2016/papers/paper119/index.html
"""
__constants__ = ['margin', 'swap', 'reduction']
margin: float
swap: bool

def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
margin: float = 1.0, swap: bool = False, reduction: str = 'mean'):
super(TripletMarginWithDistanceLoss, self).__init__(size_average=None, reduce=None, reduction=reduction)
self.distance_function = distance_function if distance_function is not None else PairwiseDistance()
self.margin = margin
self.swap = swap

def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
return F.triplet_margin_with_distance_loss(anchor, positive, negative,
distance_function=self.distance_function,
margin=self.margin, swap=self.swap, reduction=self.reduction)


class CTCLoss(_Loss):
r"""The Connectionist Temporal Classification loss.
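
To make the swap option concrete: with swap=True, the loss uses min(d(a, n), d(p, n))
as the negative distance, tightening the loss when the positive sits closer to the
negative than the anchor does. A hand-worked sketch (the embeddings are chosen for
easy arithmetic):

import torch
import torch.nn as nn

anchor = torch.tensor([[0.0]])
positive = torch.tensor([[3.0]])
negative = torch.tensor([[4.0]])

loss_noswap = nn.TripletMarginWithDistanceLoss(margin=1.0, swap=False)
loss_swap = nn.TripletMarginWithDistanceLoss(margin=1.0, swap=True)

# Without swap: max(d(a,p) - d(a,n) + margin, 0) = max(3 - 4 + 1, 0) = 0
print(loss_noswap(anchor, positive, negative))  # ~0.0

# With swap: the negative distance becomes min(4, d(p,n) = 1),
# so the loss is max(3 - 1 + 1, 0) = 3
print(loss_swap(anchor, positive, negative))    # ~3.0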
3 changes: 3 additions & 0 deletions torch/overrides.py
@@ -622,6 +622,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
swap=False, size_average=None, reduce=None, reduction='mean': -1),
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
distance_function=None, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nonzero: lambda input, as_tuple=False: -1,
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
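
The torch.overrides entry declares the new functional as overridable via the
__torch_function__ protocol (the actual dispatch happens inside the functional, shown
above), so tensor-like types can intercept calls to it. A minimal sketch of what that
enables (the LoggingTensor wrapper here is hypothetical, not part of this patch):

import torch
import torch.nn.functional as F

class LoggingTensor(object):
    # Hypothetical tensor-like wrapper that logs intercepted calls
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("intercepted:", func.__name__)
        # Unwrap any wrappers and delegate to the real implementation
        unwrapped = [a.data if isinstance(a, LoggingTensor) else a for a in args]
        return func(*unwrapped, **kwargs)

a = LoggingTensor(torch.randn(5, 10))
p, n = torch.randn(5, 10), torch.randn(5, 10)

# Dispatches through LoggingTensor.__torch_function__
F.triplet_margin_with_distance_loss(a, p, n)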
