Update on "[pytorch] Add triplet margin loss with custom distance"
Summary: As discussed [here](#43342), this adds a Python-only implementation
of the triplet margin loss that takes a custom distance function. Still
discussing whether this is necessary to add to PyTorch Core.

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23363898](https://our.internmc.facebook.com/intern/diff/D23363898)

[ghstack-poisoned]
ethch18 committed Sep 3, 2020
1 parent 9fb04f9 commit 2746a04
Showing 3 changed files with 85 additions and 95 deletions.
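
Before the diff, a minimal usage sketch of the API this commit adds. Names and signatures are taken from the diff below; the tensors and the `l1_distance` helper are illustrative, not part of the change.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    anchor = torch.randn(5, 10, requires_grad=True)
    positive = torch.randn(5, 10, requires_grad=True)
    negative = torch.randn(5, 10, requires_grad=True)

    # Functional form; with no distance_function it defaults to pairwise
    # (l_2) distance and should match F.triplet_margin_loss.
    loss = F.triplet_margin_loss_with_distance(anchor, positive, negative)

    # Module form with a custom distance callable (hypothetical helper).
    def l1_distance(x, y):
        return (x - y).abs().sum(dim=1)

    criterion = nn.TripletMarginLossWithDistance(distance_function=l1_distance)
    loss = criterion(anchor, positive, negative)
    loss.backward()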
97 changes: 54 additions & 43 deletions test/test_nn.py
@@ -12074,38 +12074,47 @@ def test_triplet_margin_loss_with_distance_default_parity(self, device):
# implementations of triplet margin loss (`nn.TripletMarginLoss`
# and `F.triplet_margin_loss`) under *default args*.

anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# functional grad and parity check
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n), (anchor, positive, negative)))
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative),
F.triplet_margin_loss(anchor, positive, negative))

# module grad and parity check
loss_base = nn.TripletMarginLoss()
loss_test = nn.TripletMarginLossWithDistance()
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
a, p, n), (anchor, positive, negative)))
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))
for extra_args in \
itertools.product((0.5, 1, 1.5), (True, False), ('none', 'mean', 'sum')):
kwargs = {'margin': extra_args[0], 'swap': extra_args[1], 'reduction': extra_args[2]}

anchor = torch.randn(5, 10, device=device, requires_grad=True)
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# Test forward, functional
expected = F.triplet_margin_loss(anchor, positive, negative, **kwargs)
actual = F.triplet_margin_loss_with_distance(anchor, positive, negative, **kwargs)
self.assertEqual(actual, expected, rtol=1e-6, atol=1e-6)

# Test forward, module
loss_ref = nn.TripletMarginLoss(**kwargs)
loss_op = nn.TripletMarginLossWithDistance(**kwargs)
self.assertEqual(loss_op(anchor, positive, negative),
loss_ref(anchor, positive, negative),
rtol=1e-6, atol=1e-6)

# Test backward
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, **kwargs), (anchor, positive, negative)))
self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
(anchor, positive, negative)))

def test_triplet_margin_loss_with_distance(self, device):
# Test for `nn.TripletMarginLossWithDistance` and
# `F.triplet_margin_loss_with_distance`. Checks
# for parity against the respective non-distance-agnostic
# implementations of triplet margin loss (`nn.TripletMarginLoss`
# and `F.triplet_margin_loss`).
# Test for parity between `nn.TripletMarginLossWithDistance` and
# `F.triplet_margin_loss_with_distance`.

def pairwise_similarity(x, y):
return 1.0 - F.pairwise_distance(x, y)
pairwise_distance = nn.PairwiseDistance()
distance_functions = ((pairwise_similarity, True), (pairwise_distance, False))

reductions = ('mean', 'none')
margins = (1.0, 1.5)
def cosine_distance(x, y):
return 1.0 - F.cosine_similarity(x, y)
cosine_similarity = nn.CosineSimilarity()
distance_functions = ((pairwise_similarity, True), (pairwise_distance, False),
(cosine_similarity, True), (cosine_distance, False))

reductions = ('mean', 'none', 'sum')
margins = (1.0, 1.5, 0.5)
swaps = (True, False)

for (distance_fn, is_similarity_fn), reduction, margin, swap \
@@ -12114,27 +12123,29 @@ def pairwise_similarity(x, y):
positive = torch.randn(5, 10, device=device, requires_grad=True)
negative = torch.randn(5, 10, device=device, requires_grad=True)

# functional: standard gradient check
# Test backward
self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_loss_with_distance(
a, p, n, distance_function=distance_fn), (anchor, positive, negative)))
# functional: parity check
self.assertEqual(F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=distance_fn,
is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap),
F.triplet_margin_loss(anchor, positive, negative,
reduction=reduction, margin=margin, swap=swap))

loss_base = nn.TripletMarginLoss(reduction=reduction, margin=margin, swap=swap)
loss_test = nn.TripletMarginLossWithDistance(distance_function=distance_fn,
a, p, n, distance_function=distance_fn, is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap),
(anchor, positive, negative)))
loss_op = nn.TripletMarginLossWithDistance(distance_function=distance_fn,
is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap)
# module: standard gradient check
self.assertTrue(gradcheck(lambda a, p, n: loss_test(
self.assertTrue(gradcheck(lambda a, p, n: loss_op(
a, p, n), (anchor, positive, negative)))
traced_loss_op = torch.jit.trace(loss_op, (anchor, positive, negative))
self.assertTrue(gradcheck(lambda a, p, n: traced_loss_op(
a, p, n), (anchor, positive, negative)))
# module: parity check
self.assertEqual(loss_test(anchor, positive, negative),
loss_base(anchor, positive, negative))

# Test forward parity
functional = F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=distance_fn,
is_similarity_function=is_similarity_fn,
reduction=reduction, margin=margin, swap=swap)
modular = loss_op(anchor, positive, negative)
traced = traced_loss_op(anchor, positive, negative)
self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)


class TestModuleGlobalHooks(TestCase):
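The tests above sweep distance functions, reductions, margins, and swap; a condensed sketch of the same pattern, restating the test's `cosine_distance` helper and tensor shapes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def cosine_distance(x, y):
        # Distance counterpart of cosine similarity: smaller means closer.
        return 1.0 - F.cosine_similarity(x, y)

    anchor = torch.randn(5, 10, requires_grad=True)
    positive = torch.randn(5, 10, requires_grad=True)
    negative = torch.randn(5, 10, requires_grad=True)

    loss_op = nn.TripletMarginLossWithDistance(distance_function=cosine_distance,
                                               is_similarity_function=False,
                                               margin=0.5, swap=True,
                                               reduction='sum')
    loss_op(anchor, positive, negative).backward()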
8 changes: 4 additions & 4 deletions torch/nn/functional.py
@@ -3733,12 +3733,12 @@ def triplet_margin_loss_with_distance(anchor, positive, negative, distance_funct
margin=1.0, swap=False, reduction="mean"):
# type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], bool, float, bool, str) -> Tensor
r"""
See :class:`~torch.nn.TripletMarginLossWithDistance` for details
See :class:`~torch.nn.TripletMarginLossWithDistance` for details.
Note: does not support JIT scripting.
"""
if torch.jit.is_scripting():
raise NotImplementedError("F.triplet_margin_loss_with_distance does not support JIT scripting: "
"Callables cannot be scripted unless they are properties of "
"a module. Please use nn.TripletMarginLossWithDistance instead.")
"functions requiring Callables cannot be scripted.")

tens_ops = (anchor, positive, negative)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
@@ -3763,8 +3763,8 @@ def triplet_margin_loss_with_distance(anchor, positive, negative, distance_funct
output = torch.clamp(negative_dist - positive_dist + margin, min=0.0)
else:
output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
reduction_enum = _Reduction.get_enum(reduction)

reduction_enum = _Reduction.get_enum(reduction)
if reduction_enum == 1:
return output.mean()
elif reduction_enum == 2:
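The functional body shown above reduces to the following self-contained sketch, reconstructed from the diff for readability (not the shipped implementation):

    import torch

    def triplet_margin_with_distance_sketch(anchor, positive, negative, distance_fn,
                                            is_similarity_fn=False, margin=1.0,
                                            swap=False, reduction='mean'):
        positive_dist = distance_fn(anchor, positive)
        negative_dist = distance_fn(anchor, negative)
        if swap:
            # Distance swap: also measure the negative against the positive,
            # and keep the harder of the two values.
            swap_dist = distance_fn(positive, negative)
            if is_similarity_fn:
                negative_dist = torch.max(negative_dist, swap_dist)
            else:
                negative_dist = torch.min(negative_dist, swap_dist)
        if is_similarity_fn:
            # For similarities, larger means closer, so the sign flips.
            output = torch.clamp(negative_dist - positive_dist + margin, min=0.0)
        else:
            output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
        if reduction == 'mean':
            return output.mean()
        if reduction == 'sum':
            return output.sum()
        return output  # reduction == 'none'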
75 changes: 27 additions & 48 deletions torch/nn/modules/loss.py
@@ -1,7 +1,5 @@
import warnings

import torch

from .distance import PairwiseDistance
from .module import Module
from .. import functional as F
@@ -1221,8 +1219,8 @@ class TripletMarginLoss(_Loss):
Shape:
- Input: :math:`(N, D)` where :math:`D` is the vector dimension.
- Output: If :attr:`reduction` is ``'none'``, then a tensor of shape :math:`(N)`,
or a scalar otherwise.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = torch.randn(100, 128, requires_grad=True)
@@ -1256,8 +1254,10 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
class TripletMarginLossWithDistance(_Loss):
r"""Creates a criterion that measures the triplet loss given input
tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
positive, and negative examples, respectively); and a real-valued function
between them.
positive, and negative examples, respectively); and a nonnegative,
real-valued function ("distance function") used to compute the relationship
between the anchor and positive examples ("positive distance") and the
anchor and negative examples ("negative distance").
The unreduced loss (i.e., with `reduction` set to `'none'`)
can be described as:
@@ -1266,11 +1266,11 @@ class TripletMarginLossWithDistance(_Loss):
\ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
where :math:`N` is the batch size; :math:`d` is a real-valued function quantifying
the separation between two tensors, referred to as `distance_function`;
and :math:`margin` is a non-negative margin enforced between the positive and
negative distances. The input tensors have :math:`N` elements each and can be of
any shape that the distance function can handle.
where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
quantifying the relationship between two tensors, referred to as `distance_function`;
and :math:`margin` is a non-negative margin between the positive and negative
distances that is required for a 0 loss. The input tensors have :math:`N` elements
each and can be of any shape that the distance function can handle.
If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then:
@@ -1285,18 +1285,21 @@ class TripletMarginLossWithDistance(_Loss):
See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
loss for input tensors using the :math:`l_p` distance as the distance function.
Note: does not support JIT scripting.
Args:
distance_function (callable, optional): A distance function between two Tensors which,
if specified, will be used instead of the pairwise distance. If not specified,
distance_function (callable, optional): A nonnegative, real-valued function that
quantifies the relationship between two tensors. If not specified,
`nn.PairwiseDistance` will be used. Default: ``None``
is_similarity_function (bool, optional): Whether `distance_function` represents a
similarity metric, i.e., larger is closer. If True, computes the difference of
similarity metric, i.e., larger values are closer. If True, computes the difference of
distances as :math:`d(a_i, n_i) - d(a_i, p_i)` so that larger loss values occur
when the negative example is more similar to the anchor than the positive example
is. Default: ``False``
margin (float, optional): A non-negative margin enforced between the positive and
negative distances. Larger margins penalize cases where the negative examples
are not distant enough from the anchors, relative to the positives. Default: :math:`1`.
margin (float, optional): A non-negative margin representing the minimum difference
between the positive and negative distances required for a 0 loss. Larger margins
penalize cases where the negative examples are not distant enough from the anchors,
relative to the positives. Default: :math:`1`.
swap (bool, optional): Whether to use the distance swap described in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
@@ -1311,10 +1314,10 @@ class TripletMarginLossWithDistance(_Loss):
Shape:
- Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
as supported by the distance function.
- Output: If :attr:`reduction` is ``'none'``, then a tensor of shape :math:`(N)`,
or a scalar otherwise.
- Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
otherwise.
Example::
Examples::
>>> # Initialize embeddings
>>> embedding = nn.Embedding(1000, 128)
@@ -1361,34 +1364,10 @@ def __init__(self, distance_function: Optional[Callable[[Tensor, Tensor], Tensor
self.swap = swap

def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
if not torch.jit.is_scripting():
return F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=self.distance_function,
is_similarity_function=self.is_similarity_function,
margin=self.margin, swap=self.swap, reduction=self.reduction)
else:
positive_dist = self.distance_function(anchor, positive)
negative_dist = self.distance_function(anchor, negative)

if self.swap:
swap_dist = self.distance_function(positive, negative)
if self.is_similarity_function:
negative_dist = torch.max(negative_dist, swap_dist)
else:
negative_dist = torch.min(negative_dist, swap_dist)

if self.is_similarity_function:
output = torch.clamp(negative_dist - positive_dist + self.margin, min=0.0)
else:
output = torch.clamp(positive_dist - negative_dist + self.margin, min=0.0)
reduction_enum = _Reduction.get_enum(self.reduction)

if reduction_enum == 1:
return output.mean()
elif reduction_enum == 2:
return output.sum()
else:
return output
return F.triplet_margin_loss_with_distance(anchor, positive, negative,
distance_function=self.distance_function,
is_similarity_function=self.is_similarity_function,
margin=self.margin, swap=self.swap, reduction=self.reduction)


class CTCLoss(_Loss):
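A consequence of the simplified `forward` above, sketched under the diff's own "does not support JIT scripting" note: scripting the module should fail because the `Callable` attribute cannot be scripted, while `torch.jit.trace` (which the updated test exercises) records the ops instead.

    import torch
    import torch.nn as nn

    criterion = nn.TripletMarginLossWithDistance()  # defaults to PairwiseDistance
    example = tuple(torch.randn(5, 10) for _ in range(3))

    # torch.jit.script(criterion) is expected to fail per the docstring note;
    # tracing works because torch.jit.is_scripting() is False during the trace.
    traced = torch.jit.trace(criterion, example)
    out = traced(*example)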
