[pytorch] Add triplet margin loss with custom distance
Summary: As discussed [here](#43342), this adds a Python-only implementation of
the triplet margin loss that takes a custom distance function. Whether this
belongs in PyTorch Core is still under discussion.
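
A minimal usage sketch of the new functional API (argument names taken from the diff below; the custom `l1_distance` is purely illustrative):

```python
import torch
import torch.nn.functional as F

anchor = torch.randn(8, 128, requires_grad=True)
positive = torch.randn(8, 128, requires_grad=True)
negative = torch.randn(8, 128, requires_grad=True)

# With no distance_function, this falls back to pairwise_distance and
# matches F.triplet_margin_loss.
loss = F.triplet_margin_loss_with_distance(anchor, positive, negative)

# Any callable mapping (Tensor, Tensor) -> a Tensor of per-pair distances works.
def l1_distance(x, y):
    return (x - y).abs().sum(dim=1)

loss = F.triplet_margin_loss_with_distance(
    anchor, positive, negative, distance_function=l1_distance, margin=0.5)
loss.backward()
```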

Test Plan: python test/run_tests.py


ghstack-source-id: 14fbd6d444517d04ad3dd4f6b5e040411481905e
Pull Request resolved: #43680
ethch18 committed Aug 27, 2020
1 parent 2b70f82 commit 4750845
Showing 10 changed files with 338 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/amp.rst
@@ -170,6 +170,7 @@ Ops that can autocast to ``float32``
``renorm``,
``tan``,
``triplet_margin_loss``,
``triplet_margin_loss_with_distance``

Ops that promote to the widest input type
"""""""""""""""""""""""""""""""""""""""""
7 changes: 5 additions & 2 deletions docs/source/nn.functional.rst
@@ -483,6 +483,11 @@ Loss functions

.. autofunction:: triplet_margin_loss

:hidden:`triplet_margin_loss_with_distance`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: triplet_margin_loss_with_distance

Vision functions
----------------

@@ -533,5 +538,3 @@ DataParallel functions (multi-GPU, distributed)
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.parallel.data_parallel


3 changes: 2 additions & 1 deletion docs/source/nn.rst
@@ -10,7 +10,7 @@ These are the basic building block for graphs
:depth: 2
:local:
:backlinks: top


.. currentmodule:: torch.nn

@@ -269,6 +269,7 @@ Loss Functions
nn.CosineEmbeddingLoss
nn.MultiMarginLoss
nn.TripletMarginLoss
nn.TripletMarginLossWithDistance

Vision Layers
----------------
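The newly listed `nn.TripletMarginLossWithDistance` module mirrors the functional API; a minimal sketch, with constructor arguments as exercised by the tests below (shapes are illustrative):

```python
import torch
from torch import nn

loss_fn = nn.TripletMarginLossWithDistance(
    distance_function=nn.PairwiseDistance(), margin=1.0, swap=False, reduction='mean')

anchor, positive, negative = (torch.randn(16, 64, requires_grad=True) for _ in range(3))
loss = loss_fn(anchor, positive, negative)
loss.backward()
```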
172 changes: 172 additions & 0 deletions test/test_nn.py
@@ -6513,6 +6513,178 @@ def test_triplet_margin_loss_swap_no_reduce(self):
self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))

def test_triplet_margin_loss_with_distance_parity(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative),
F.triplet_margin_loss(anchor, positive, negative))

def test_triplet_margin_loss_with_distance(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative, distance_function=distance_fn),
F.triplet_margin_loss(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_similarity(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)

def distance_fn(x, y):
return 1.0 - F.pairwise_distance(x, y)

self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn, is_similarity_function=True), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=distance_fn,
is_similarity_function=True),
F.triplet_margin_loss(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_swap(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn, swap=True), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative, distance_function=distance_fn, swap=True),
F.triplet_margin_loss(anchor, positive, negative, swap=True))

def test_triplet_margin_loss_with_distance_similarity_swap(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)

def distance_fn(x, y):
return 1.0 - F.pairwise_distance(x, y)

self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn, is_similarity_function=True, swap=True), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=distance_fn,
is_similarity_function=True, swap=True),
F.triplet_margin_loss(anchor, positive, negative, swap=True))

def test_triplet_margin_loss_with_distance_no_reduce(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn, reduction='none'), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=distance_fn,
reduction='none'),
F.triplet_margin_loss(anchor, positive, negative, reduction='none'))

def test_triplet_margin_loss_with_distance_margin(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn, margin=1.5), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=distance_fn, margin=1.5),
F.triplet_margin_loss(anchor, positive, negative, margin=1.5))

def test_triplet_margin_loss_with_distance_module_parity(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
loss_base = nn.TripletMarginLoss()
loss_test = nn.TripletMarginLossWithDistance()
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_module(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
loss_base = nn.TripletMarginLoss()
loss_test = nn.TripletMarginLossWithDistance(distance_function=distance_fn)
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_module_similarity(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)

def distance_fn(x, y):
return 1.0 - F.pairwise_distance(x, y)

loss_base = nn.TripletMarginLoss()
loss_test = nn.TripletMarginLossWithDistance(distance_function=distance_fn, is_similarity_function=True)
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_module_swap(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
loss_base = nn.TripletMarginLoss(swap=True)
loss_test = nn.TripletMarginLossWithDistance(distance_function=distance_fn, swap=True)
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_module_similarity_swap(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)

def distance_fn(x, y):
return 1.0 - F.pairwise_distance(x, y)

loss_base = nn.TripletMarginLoss(swap=True)
loss_test = nn.TripletMarginLossWithDistance(
distance_function=distance_fn, is_similarity_function=True, swap=True)
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_module_no_reduce(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
loss_base = nn.TripletMarginLoss(reduction='none')
loss_test = nn.TripletMarginLossWithDistance(distance_function=distance_fn, reduction='none')
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_triplet_margin_loss_with_distance_module_margin(self):
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10, requires_grad=True)
negative = torch.randn(5, 10, requires_grad=True)
distance_fn = nn.PairwiseDistance()
loss_base = nn.TripletMarginLoss(margin=1.5)
loss_test = nn.TripletMarginLossWithDistance(distance_function=distance_fn, margin=1.5)
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

def test_pointwise_loss_target_grad_none_reduction(self):
i = torch.randn(5, 10)
t = torch.randn(5, 10, requires_grad=True)
36 changes: 36 additions & 0 deletions torch/nn/functional.py
@@ -3737,6 +3737,42 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
swap, reduction_enum)


def triplet_margin_loss_with_distance(anchor, positive, negative, distance_function=None, is_similarity_function=False,
margin=1.0, swap=False, reduction="mean"):
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], bool, float, bool, str) -> Tensor
r"""
See :class:`~torch.nn.TripletMarginLossWithDistance` for details
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_loss_with_distance does not support JIT. "
"Please use nn.TripletMarginLossWithDistance instead.")
if distance_function is None:
distance_function = pairwise_distance

positive_dist = distance_function(anchor, positive)
negative_dist = distance_function(anchor, negative)

if swap:
swap_dist = distance_function(positive, negative)
if is_similarity_function:
negative_dist = torch.max(negative_dist, swap_dist)
else:
negative_dist = torch.min(negative_dist, swap_dist)

if is_similarity_function:
output = torch.clamp(margin - positive_dist + negative_dist, min=0.0)
else:
output = torch.clamp(margin + positive_dist - negative_dist, min=0.0)
reduction_enum = _Reduction.get_enum(reduction)

if reduction_enum == 1:
return output.mean()
elif reduction_enum == 2:
return output.sum()
else:
return output


def normalize(input, p=2, dim=1, eps=1e-12, out=None):
# type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
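Since the implementation above flips the hinge to `margin - d(a, p) + d(a, n)` when `is_similarity_function=True` (and `swap` then takes the max rather than the min of the two negative terms), a similarity function slots in directly. A sketch using cosine similarity — this particular choice is illustrative, not part of the diff:

```python
import torch
import torch.nn.functional as F

anchor, positive, negative = (torch.randn(4, 32, requires_grad=True) for _ in range(3))

loss = F.triplet_margin_loss_with_distance(
    anchor, positive, negative,
    distance_function=F.cosine_similarity,  # higher value = more similar
    is_similarity_function=True,            # hinge: margin - sim(a, p) + sim(a, n)
    swap=True)                              # negative term: max(sim(a, n), sim(p, n))
loss.backward()
```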
12 changes: 9 additions & 3 deletions torch/nn/functional.pyi.in
@@ -22,9 +22,9 @@ GRID_SAMPLE_PADDING_MODES = Dict[str, int]
# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence`
# type. There is no way to express the expected lengths of these lists in the current Python typing system.
#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
def fractional_max_pool2d_with_indices(input: Tensor, kernel_size: _size, output_size: Optional[_size] = ...,
output_ratio: Optional[_ratio_any_t] = ..., return_indices: bool = ...,
@@ -311,6 +311,12 @@ def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, marg
reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...


def triplet_margin_loss_with_distance(anchor: Tensor, positive: Tensor, negative: Tensor,
                                      distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
                                      is_similarity_function: bool = ..., margin: float = ...,
                                      swap: bool = ..., reduction: str = ...) -> Tensor: ...


def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ...,
out: Optional[Tensor] = ...) -> Tensor: ...

6 changes: 3 additions & 3 deletions torch/nn/modules/__init__.py
@@ -8,8 +8,8 @@
Hardsigmoid, Hardswish, SiLU
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
SmoothL1Loss, SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, PoissonNLLLoss
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginLossWithDistance, PoissonNLLLoss
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
@@ -54,5 +54,5 @@
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginLossWithDistance'
]
