Skip to content

Commit

Permalink
[c++] Distance-agnostic triplet margin loss (#45377)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #45377

This PR adds a C++ implementation of the TripletMarginWithDistanceLoss, for which the Python implementation was introduced in PR #43680.  It's based on PR #44072, but I'm resubmitting this to unlink it from Phabricator.

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D24003973

fbshipit-source-id: 2d9ada7260a6f27425ff2fdbbf623dad0fb79405
  • Loading branch information
Xinyu Li authored and facebook-github-bot committed Sep 30, 2020
1 parent 181afd5 commit c9bb990
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 11 deletions.
50 changes: 50 additions & 0 deletions test/cpp/api/functional.cpp
Expand Up @@ -682,6 +682,56 @@ TEST_F(FunctionalTest, TripletMarginLoss) {
ASSERT_TRUE(output.allclose(expected, 1e-04));
}

TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) {
  // When no distance function is supplied,
  // F::triplet_margin_with_distance_loss falls back to pairwise_distance,
  // so for identical reduction / margin / swap settings its output should
  // match F::triplet_margin_loss.

  std::vector<TripletMarginWithDistanceLossOptions::reduction_t>
      reductions = {torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> margins = {0.5, 1.0, 1.5};
  std::vector<bool> swaps = {true, false};

  for (const auto& reduction : reductions) {
    for (const auto& margin : margins) {
      for (const auto& swap : swaps) {
        // Fresh leaf tensors per configuration so that the gradients
        // checked below come from this iteration's backward pass only.
        auto anchor = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto positive = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto negative = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));

        auto basicOptions = F::TripletMarginLossFuncOptions()
                                .reduction(reduction)
                                .margin(margin)
                                .swap(swap);
        auto distanceOptions =
            F::TripletMarginWithDistanceLossFuncOptions()
                .reduction(reduction)
                .margin(margin)
                .swap(swap);

        // Only the functional API is exercised here; the module wrappers
        // are covered by the ModulesTest counterpart of this test.
        auto basicOutput =
            F::triplet_margin_loss(anchor, positive, negative, basicOptions);
        auto distanceOutput = F::triplet_margin_with_distance_loss(
            anchor, positive, negative, distanceOptions);

        ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));

        // With torch::kNone the output is not a scalar, so sum it before
        // calling backward().
        auto sum = distanceOutput.sum();
        sum.backward();
        ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
        ASSERT_EQ(positive.sizes(), positive.grad().sizes());
        ASSERT_EQ(negative.sizes(), negative.grad().sizes());
      }
    }
  }
}

TEST_F(FunctionalTest, NLLLoss) {
auto input = torch::tensor({{-0.1315, -3.1315, -2.5315},
{-3.7038, -0.1038, -2.6038},
Expand Down
127 changes: 125 additions & 2 deletions test/cpp/api/modules.cpp
Expand Up @@ -2085,6 +2085,115 @@ TEST_F(ModulesTest, TripletMarginLoss) {
ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
}

TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) {
  // A TripletMarginWithDistanceLoss module built without a custom distance
  // function should produce the same values as a TripletMarginLoss module
  // configured with the same reduction, margin and swap settings, both via
  // forward() and via the holder's call operator.
  using Reduction = TripletMarginWithDistanceLossOptions::reduction_t;

  // Draws one random (100 x 128) float tensor that tracks gradients.
  const auto make_input = [] {
    return torch::randn(
        {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
  };

  std::vector<Reduction> reduction_modes = {
      torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> margin_values = {0.5, 1.0, 1.5};
  std::vector<bool> swap_flags = {true, false};

  for (const auto& reduction_mode : reduction_modes) {
    for (const auto& margin_value : margin_values) {
      for (const auto& swap_flag : swap_flags) {
        auto anchor = make_input();
        auto positive = make_input();
        auto negative = make_input();

        TripletMarginLoss reference_module(TripletMarginLossOptions()
                                               .reduction(reduction_mode)
                                               .margin(margin_value)
                                               .swap(swap_flag));
        TripletMarginWithDistanceLoss distance_module(
            TripletMarginWithDistanceLossOptions()
                .reduction(reduction_mode)
                .margin(margin_value)
                .swap(swap_flag));

        // Explicit forward() calls first, then the call-operator path.
        auto reference_forward =
            reference_module->forward(anchor, positive, negative);
        auto distance_forward =
            distance_module->forward(anchor, positive, negative);
        auto reference_call = reference_module(anchor, positive, negative);
        auto distance_call = distance_module(anchor, positive, negative);

        ASSERT_TRUE(distance_forward.allclose(reference_forward, 1e-6, 1e-6));
        ASSERT_TRUE(distance_call.allclose(distance_forward, 1e-6, 1e-6));
        ASSERT_TRUE(distance_call.allclose(reference_call, 1e-6, 1e-6));

        // Reduce to a scalar before backward() so the torch::kNone case
        // is handled as well.
        auto total = distance_forward.sum();
        total.backward();
        ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
        ASSERT_EQ(positive.sizes(), positive.grad().sizes());
        ASSERT_EQ(negative.sizes(), negative.grad().sizes());
      }
    }
  }
}

TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) {
  // The TripletMarginWithDistanceLoss module and the
  // triplet_margin_with_distance_loss functional must agree for every
  // combination of distance function, reduction, margin and swap.
  namespace functional = torch::nn::functional;
  using ModuleOptions = TripletMarginWithDistanceLossOptions;

  // Draws one random (100 x 128) float tensor that tracks gradients.
  const auto make_input = [] {
    return torch::randn(
        {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
  };

  std::vector<ModuleOptions::distance_function_t> distance_functions = {
      // Pairwise distance with its default arguments.
      [](const torch::Tensor& lhs, const torch::Tensor& rhs) {
        return torch::pairwise_distance(lhs, rhs);
      },
      // Cosine distance, i.e. one minus the cosine similarity.
      [](const torch::Tensor& lhs, const torch::Tensor& rhs) {
        return 1.0 - torch::cosine_similarity(lhs, rhs);
      }};
  std::vector<ModuleOptions::reduction_t> reduction_modes = {
      torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> margin_values = {0.5, 1.0, 1.5};
  std::vector<bool> swap_flags = {true, false};

  for (const auto& distance_fn : distance_functions) {
    for (const auto& reduction_mode : reduction_modes) {
      for (const auto& margin_value : margin_values) {
        for (const auto& swap_flag : swap_flags) {
          auto module_options = ModuleOptions()
                                    .distance_function(distance_fn)
                                    .reduction(reduction_mode)
                                    .margin(margin_value)
                                    .swap(swap_flag);
          auto functional_options =
              functional::TripletMarginWithDistanceLossFuncOptions()
                  .distance_function(distance_fn)
                  .reduction(reduction_mode)
                  .margin(margin_value)
                  .swap(swap_flag);

          auto anchor = make_input();
          auto positive = make_input();
          auto negative = make_input();

          TripletMarginWithDistanceLoss loss_module(module_options);

          auto via_forward = loss_module->forward(anchor, positive, negative);
          auto via_call = loss_module(anchor, positive, negative);
          auto via_functional = functional::triplet_margin_with_distance_loss(
              anchor, positive, negative, functional_options);

          ASSERT_TRUE(via_forward.allclose(via_functional, 1e-6, 1e-6));
          ASSERT_TRUE(via_call.allclose(via_functional, 1e-6, 1e-6));
        }
      }
    }
  }
}

TEST_F(ModulesTest, NLLLoss) {
NLLLoss loss;
auto input = torch::tensor({{-0.1315, -3.1315, -2.5315},
Expand Down Expand Up @@ -3529,9 +3638,9 @@ TEST_F(ModulesTest, PrettyPrintIdentity) {
}

TEST_F(ModulesTest, PrettyPrintFlatten) {
  // A default-constructed Flatten is expected to print start_dim=1, end_dim=-1.
  ASSERT_EQ(c10::str(Flatten()),
    "torch::nn::Flatten(start_dim=1, end_dim=-1)");
  // Explicitly configured dims must show up in the printed form.
  ASSERT_EQ(c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))),
    "torch::nn::Flatten(start_dim=2, end_dim=4)");
}

Expand Down Expand Up @@ -4394,6 +4503,20 @@ TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) {
"torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)");
}

TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) {
  // The expected string contains only margin and swap, even though a
  // distance function and a reduction are configured as well.
  auto l2_distance = [](const torch::Tensor& lhs, const torch::Tensor& rhs) {
    return torch::pairwise_distance(lhs, rhs, 2.0, 1e-6);
  };

  auto options = TripletMarginWithDistanceLossOptions()
                     .distance_function(l2_distance)
                     .margin(1.5)
                     .swap(true)
                     .reduction(torch::kMean);

  const auto printed = c10::str(TripletMarginWithDistanceLoss(options));
  ASSERT_EQ(
      printed,
      "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)");
}

TEST_F(ModulesTest, PrettyPrintNLLLoss) {
ASSERT_EQ(
c10::str(NLLLoss()), "torch::nn::NLLLoss()");
Expand Down
79 changes: 79 additions & 0 deletions torch/csrc/api/include/torch/nn/functional/loss.h
Expand Up @@ -527,6 +527,85 @@ inline Tensor triplet_margin_loss(

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Shared implementation behind F::triplet_margin_with_distance_loss.
//
// Computes clamp_min(d(anchor, positive) - d(anchor, negative) + margin, 0)
// and then applies `reduction`.
//
//   distance_function  callable used as `d`; when empty, pairwise_distance
//                      with its default arguments is used instead.
//   margin             added to the distance difference before clamping.
//   swap               when true, the negative distance becomes
//                      min(d(anchor, negative), d(positive, negative)).
//   reduction          kNone returns the clamped loss unchanged, kMean its
//                      mean, kSum its sum; any other alternative trips an
//                      internal assert.
inline Tensor triplet_margin_with_distance_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    c10::optional<TripletMarginWithDistanceLossFuncOptions::distance_function_t> distance_function,
    double margin,
    bool swap,
    TripletMarginWithDistanceLossFuncOptions::reduction_t reduction) {
  // Single dispatch point for the distance computation. The user-supplied
  // callable is invoked through a reference, so the std::function is not
  // copied on every call.
  const auto distance = [&distance_function](
                            const Tensor& lhs, const Tensor& rhs) -> Tensor {
    if (distance_function.has_value()) {
      return distance_function.value()(lhs, rhs);
    }
    return pairwise_distance(lhs, rhs);
  };

  Tensor dist_pos = distance(anchor, positive);
  Tensor dist_neg = distance(anchor, negative);

  if (swap) {
    // Use whichever of the two distances to the negative is smaller.
    Tensor dist_swap = distance(positive, negative);
    dist_neg = torch::min(dist_neg, dist_swap);
  }

  auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0);

  Tensor ret;
  if (c10::get_if<enumtype::kNone>(&reduction)) {
    ret = loss;
  } else if (c10::get_if<enumtype::kMean>(&reduction)) {
    ret = loss.mean();
  } else if (c10::get_if<enumtype::kSum>(&reduction)) {
    ret = loss.sum();
  } else {
    // Not one of the three handled alternatives: fail loudly.
    TORCH_INTERNAL_ASSERT(
        false,
        enumtype::get_enum_name(reduction),
        " is not valid");
  }
  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional form of the triplet margin loss with a configurable distance
/// function. Unpacks `options` and forwards to
/// `detail::triplet_margin_with_distance_loss`.
///
/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::triplet_margin_with_distance_loss(anchor, positive, negative, F::TripletMarginWithDistanceLossFuncOptions().margin(1.0));
/// ```
inline Tensor triplet_margin_with_distance_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    const TripletMarginWithDistanceLossFuncOptions& options = {}) {
  // Name each option before forwarding so the call below reads like the
  // detail function's parameter list.
  const auto& distance_function = options.distance_function();
  const auto& margin = options.margin();
  const auto& swap = options.swap();
  const auto& reduction = options.reduction();

  return detail::triplet_margin_with_distance_loss(
      anchor, positive, negative, distance_function, margin, swap, reduction);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor ctc_loss(const Tensor& log_probs,
Expand Down
54 changes: 49 additions & 5 deletions torch/csrc/api/include/torch/nn/modules/loss.h
Expand Up @@ -309,7 +309,7 @@ struct TORCH_API SmoothL1LossImpl : public Cloneable<SmoothL1LossImpl> {
TORCH_MODULE(SmoothL1Loss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a criterion that optimizes a multi-class multi-classification
/// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
/// and output :math:`y` (which is a 2D `Tensor` of target class indices).
Expand Down Expand Up @@ -421,9 +421,9 @@ TORCH_MODULE(MultiLabelSoftMarginLoss);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a criterion that measures the triplet loss given input
/// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater
/// than :math:`0`. This is used for measuring a relative similarity between
/// samples. A triplet is composed of `a`, `p` and `n` (i.e., `anchor`,
/// `positive examples` and `negative examples` respectively). The
/// shapes of all input tensors should be :math:`(N, D)`.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginLoss to learn
Expand Down Expand Up @@ -461,6 +461,50 @@ struct TORCH_API TripletMarginLossImpl : public Cloneable<TripletMarginLossImpl>
/// module storage semantics.
TORCH_MODULE(TripletMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a criterion that measures the triplet loss given input
/// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
/// positive, and negative examples, respectively); and a nonnegative, real-valued function
/// ("distance function") used to compute the relationships between the anchor
/// and positive example ("positive distance") and the anchor and negative
/// example ("negative distance").
/// See https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginWithDistanceLoss to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// TripletMarginWithDistanceLoss model(TripletMarginWithDistanceLossOptions().margin(3).swap(false));
/// ```
struct TORCH_API TripletMarginWithDistanceLossImpl : public Cloneable<TripletMarginWithDistanceLossImpl> {
/// Constructs the module from `options_`; default-constructed options are
/// used when no argument is given. `explicit` prevents implicit conversion
/// from an options object to a module.
explicit TripletMarginWithDistanceLossImpl(
TripletMarginWithDistanceLossOptions options_ = {});

/// Overrides the base class `reset()` hook. The definition is not part of
/// this header.
void reset() override;

/// Pretty prints the `TripletMarginWithDistanceLoss` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;

/// Computes the loss for one batch of `anchor`, `positive` and `negative`
/// tensors.
/// NOTE(review): the definition lives outside this header; presumably it
/// forwards to `F::triplet_margin_with_distance_loss` using `options` --
/// confirm against the implementation file.
Tensor forward(
const Tensor& anchor,
const Tensor& positive,
const Tensor& negative);

/// The options with which this `Module` was constructed.
TripletMarginWithDistanceLossOptions options;
};

/// A `ModuleHolder` subclass for `TripletMarginWithDistanceLossImpl`.
/// See the documentation for `TripletMarginWithDistanceLossImpl` class to learn what methods it
/// provides, and examples of how to use `TripletMarginWithDistanceLoss` with
/// `torch::nn::TripletMarginWithDistanceLossOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(TripletMarginWithDistanceLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// The Connectionist Temporal Classification loss.
Expand Down Expand Up @@ -626,9 +670,9 @@ TORCH_MODULE(NLLLoss);
struct TORCH_API CrossEntropyLossImpl : public Cloneable<CrossEntropyLossImpl> {
explicit CrossEntropyLossImpl(
const CrossEntropyLossOptions& options_ = {});

void reset() override;

/// Pretty prints the `CrossEntropyLoss` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;

Expand Down

0 comments on commit c9bb990

Please sign in to comment.