Skip to content

Commit

Permalink
[cpp] Add distance-agnostic triplet loss to core
Browse files Browse the repository at this point in the history
Summary: This is the C++ follow-up to [this issue](#43342). The implementation
here parallels the Python one, but we don't use native functions because
Callables aren't supported.

Test Plan: Unit test with test_api

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
  • Loading branch information
ethch18 committed Sep 2, 2020
1 parent 9fb04f9 commit b66374b
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 9 deletions.
92 changes: 92 additions & 0 deletions test/cpp/api/functional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,98 @@ TEST_F(FunctionalTest, TripletMarginLoss) {
ASSERT_TRUE(output.allclose(expected, 1e-04));
}

TEST_F(FunctionalTest, TripletMarginLossWithDistanceDefaultParity) {
  /// With both option structs left at their defaults (no custom distance
  /// function supplied), the distance-agnostic functional is expected to
  /// produce the same loss as the plain triplet margin loss.
  const auto referenceOptions = F::TripletMarginLossFuncOptions();
  const auto candidateOptions = F::TripletMarginLossWithDistanceFuncOptions();

  // Every input is an independent random float tensor that tracks gradients.
  const auto makeInput = [] {
    return torch::randn(
        {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
  };
  auto anchor = makeInput();
  auto positive = makeInput();
  auto negative = makeInput();

  const auto referenceLoss =
      F::triplet_margin_loss(anchor, positive, negative, referenceOptions);
  auto candidateLoss = F::triplet_margin_loss_with_distance(
      anchor, positive, negative, candidateOptions);

  ASSERT_TRUE(candidateLoss.allclose(referenceLoss));

  // Backpropagating through the new loss must populate a gradient of the
  // same shape as each input.
  candidateLoss.backward();
  ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
  ASSERT_EQ(positive.sizes(), positive.grad().sizes());
  ASSERT_EQ(negative.sizes(), negative.grad().sizes());
}

TEST_F(FunctionalTest, TripletMarginLossWithDistance) {
  /// Check that TripletMarginLoss and TripletMarginLossWithDistance
  /// behave analogously irrespective of flags.
  ///
  /// Every combination of {distance, similarity} function, reduction
  /// (sum, mean, none), margin and swap is run through both functionals and
  /// the resulting losses are compared.
  auto defaultBasicOptions = F::TripletMarginLossFuncOptions();

  // Distance callable built from the basic loss's default `p` and `eps`.
  auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
    return torch::pairwise_distance(
        x, y, defaultBasicOptions.p(), defaultBasicOptions.eps());
  };
  // Same quantity expressed as a similarity (1 - distance); it is paired
  // with is_similarity_function(true) below.
  auto pairwise_similarity = [&](const torch::Tensor& x,
                                 const torch::Tensor& y) {
    return 1.0 -
        torch::pairwise_distance(
               x, y, defaultBasicOptions.p(), defaultBasicOptions.eps());
  };
  // Each entry is (callable, value to pass to is_similarity_function).
  std::vector<std::tuple<
      F::TripletMarginLossWithDistanceFuncOptions::distance_function_t,
      bool>>
      distance_functions = {std::make_tuple(pairwise_distance, false),
                            std::make_tuple(pairwise_similarity, true)};

  // Fixed: the list previously contained torch::kSum twice, so the
  // unreduced (torch::kNone) path was never exercised.
  std::vector<F::TripletMarginLossWithDistanceFuncOptions::reduction_t>
      reductions = {torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> margins = {1.0, 1.5};
  std::vector<bool> swaps = {true, false};

  for (auto& funcPair : distance_functions) {
    for (auto& reduction : reductions) {
      for (auto& margin : margins) {
        for (const auto& swap : swaps) {
          auto basicOptions = F::TripletMarginLossFuncOptions()
                                  .reduction(reduction)
                                  .margin(margin)
                                  .swap(swap);
          auto distanceOptions =
              F::TripletMarginLossWithDistanceFuncOptions()
                  .distance_function(std::get<0>(funcPair))
                  .is_similarity_function(std::get<1>(funcPair))
                  .reduction(reduction)
                  .margin(margin)
                  .swap(swap);

          auto anchor = torch::randn(
              {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
          auto positive = torch::randn(
              {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
          auto negative = torch::randn(
              {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));

          auto basicOutput =
              F::triplet_margin_loss(anchor, positive, negative, basicOptions);
          auto distanceOutput = F::triplet_margin_loss_with_distance(
              anchor, positive, negative, distanceOptions);

          ASSERT_TRUE(distanceOutput.allclose(basicOutput));

          // handle for torch::kNone reduction: the output is then not a
          // scalar, so backpropagate through its sum. Fixed: the sum was
          // previously computed but backward() was called on the raw output.
          auto sum = distanceOutput.sum();
          sum.backward();
          ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
          ASSERT_EQ(positive.sizes(), positive.grad().sizes());
          ASSERT_EQ(negative.sizes(), negative.grad().sizes());
        }
      }
    }
  }
}

TEST_F(FunctionalTest, NLLLoss) {
auto input = torch::tensor({{-0.1315, -3.1315, -2.5315},
{-3.7038, -0.1038, -2.6038},
Expand Down
125 changes: 123 additions & 2 deletions test/cpp/api/modules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2085,6 +2085,100 @@ TEST_F(ModulesTest, TripletMarginLoss) {
ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
}

TEST_F(ModulesTest, TripletMarginLossWithDistanceDefaultParity) {
  /// With both option structs left at their defaults (no custom distance
  /// function supplied), the TripletMarginLossWithDistance module is expected
  /// to produce the same loss as the TripletMarginLoss module.
  const auto referenceOptions = TripletMarginLossOptions();
  const auto candidateOptions = TripletMarginLossWithDistanceOptions();

  // Every input is an independent random float tensor that tracks gradients.
  const auto makeInput = [] {
    return torch::randn(
        {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
  };
  auto anchor = makeInput();
  auto positive = makeInput();
  auto negative = makeInput();

  TripletMarginLoss referenceModule(referenceOptions);
  TripletMarginLossWithDistance candidateModule(candidateOptions);

  const auto referenceLoss =
      referenceModule->forward(anchor, positive, negative);
  auto candidateLoss = candidateModule->forward(anchor, positive, negative);

  ASSERT_TRUE(candidateLoss.allclose(referenceLoss));

  // Backpropagating through the new module must populate a gradient of the
  // same shape as each input.
  candidateLoss.backward();
  ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
  ASSERT_EQ(positive.sizes(), positive.grad().sizes());
  ASSERT_EQ(negative.sizes(), negative.grad().sizes());
}

TEST_F(ModulesTest, TripletMarginLossWithDistance) {
  /// Check that TripletMarginLoss and TripletMarginLossWithDistance
  /// behave analogously irrespective of flags.
  ///
  /// Every combination of {distance, similarity} function, reduction
  /// (sum, mean, none), margin and swap is run through both modules and
  /// the resulting losses are compared.
  auto defaultBasicOptions = TripletMarginLossOptions();

  // Distance callable built from the basic loss's default `p` and `eps`.
  auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
    return torch::pairwise_distance(
        x, y, defaultBasicOptions.p(), defaultBasicOptions.eps());
  };
  // Same quantity expressed as a similarity (1 - distance); it is paired
  // with is_similarity_function(true) below.
  auto pairwise_similarity = [&](const torch::Tensor& x,
                                 const torch::Tensor& y) {
    return 1.0 -
        torch::pairwise_distance(
               x, y, defaultBasicOptions.p(), defaultBasicOptions.eps());
  };
  // Each entry is (callable, value to pass to is_similarity_function).
  std::vector<std::tuple<
      TripletMarginLossWithDistanceOptions::distance_function_t,
      bool>>
      distance_functions = {std::make_tuple(pairwise_distance, false),
                            std::make_tuple(pairwise_similarity, true)};

  // Fixed: the list previously contained torch::kSum twice, so the
  // unreduced (torch::kNone) path was never exercised.
  std::vector<TripletMarginLossWithDistanceOptions::reduction_t>
      reductions = {torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> margins = {1.0, 1.5};
  std::vector<bool> swaps = {true, false};

  for (auto& funcPair : distance_functions) {
    for (auto& reduction : reductions) {
      for (auto& margin : margins) {
        for (const auto& swap : swaps) {
          auto basicOptions = TripletMarginLossOptions()
                                  .reduction(reduction)
                                  .margin(margin)
                                  .swap(swap);
          auto distanceOptions =
              TripletMarginLossWithDistanceOptions()
                  .distance_function(std::get<0>(funcPair))
                  .is_similarity_function(std::get<1>(funcPair))
                  .reduction(reduction)
                  .margin(margin)
                  .swap(swap);

          auto anchor = torch::randn(
              {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
          auto positive = torch::randn(
              {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
          auto negative = torch::randn(
              {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));

          TripletMarginLoss basicLoss(basicOptions);
          TripletMarginLossWithDistance distanceLoss(distanceOptions);

          auto basicOutput = basicLoss->forward(anchor, positive, negative);
          auto distanceOutput =
              distanceLoss->forward(anchor, positive, negative);

          ASSERT_TRUE(distanceOutput.allclose(basicOutput));

          // handle for torch::kNone reduction: the output is then not a
          // scalar, so backpropagate through its sum. Fixed: the sum was
          // previously computed but backward() was called on the raw output.
          auto sum = distanceOutput.sum();
          sum.backward();
          ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
          ASSERT_EQ(positive.sizes(), positive.grad().sizes());
          ASSERT_EQ(negative.sizes(), negative.grad().sizes());
        }
      }
    }
  }
}

TEST_F(ModulesTest, NLLLoss) {
NLLLoss loss;
auto input = torch::tensor({{-0.1315, -3.1315, -2.5315},
Expand Down Expand Up @@ -3529,9 +3623,9 @@ TEST_F(ModulesTest, PrettyPrintIdentity) {
}

TEST_F(ModulesTest, PrettyPrintFlatten) {
// Checks the string form of Flatten, first with default options and then
// with start_dim=2 / end_dim=4.
// NOTE(review): each ASSERT_EQ opener appears twice below; this looks like
// the removed and added sides of a whitespace-only diff hunk rendered
// together — confirm against the real file before relying on this text.
ASSERT_EQ(c10::str(Flatten()),
ASSERT_EQ(c10::str(Flatten()),
"torch::nn::Flatten(start_dim=1, end_dim=-1)");
ASSERT_EQ(c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))),
ASSERT_EQ(c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))),
"torch::nn::Flatten(start_dim=2, end_dim=4)");
}

Expand Down Expand Up @@ -4394,6 +4488,33 @@ TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) {
"torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)");
}

TEST_F(ModulesTest, PrettyPrintTripletMarginLossWithDistance) {
  // Custom distance callable; the expected string below contains only
  // margin, swap and is_similarity_function.
  const auto euclidean = [&](const torch::Tensor& x, const torch::Tensor& y) {
    return torch::pairwise_distance(x, y, 2.0, 1e-6);
  };

  auto distanceOptions = TripletMarginLossWithDistanceOptions();
  distanceOptions.distance_function(euclidean);
  distanceOptions.margin(1.5);
  distanceOptions.swap(true);
  distanceOptions.reduction(torch::kMean);

  const auto printed = c10::str(TripletMarginLossWithDistance(distanceOptions));
  ASSERT_EQ(
      printed,
      "torch::nn::TripletMarginLossWithDistance(margin=1.5, swap=true, is_similarity_function=false)");
}

TEST_F(ModulesTest, PrettyPrintTripletMarginLossWithDistanceSimilarity) {
  // Similarity-style callable (1 - distance). margin and swap are left
  // unset here; the expected string shows margin=1 and swap=false.
  const auto similarity = [&](const torch::Tensor& x, const torch::Tensor& y) {
    return 1.0 - torch::pairwise_distance(x, y, 2.0, 1e-6);
  };

  auto distanceOptions = TripletMarginLossWithDistanceOptions();
  distanceOptions.distance_function(similarity);
  distanceOptions.reduction(torch::kMean);
  distanceOptions.is_similarity_function(true);

  const auto printed = c10::str(TripletMarginLossWithDistance(distanceOptions));
  ASSERT_EQ(
      printed,
      "torch::nn::TripletMarginLossWithDistance(margin=1, swap=false, is_similarity_function=true)");
}

TEST_F(ModulesTest, PrettyPrintNLLLoss) {
ASSERT_EQ(
c10::str(NLLLoss()), "torch::nn::NLLLoss()");
Expand Down
89 changes: 89 additions & 0 deletions torch/csrc/api/include/torch/nn/functional/loss.h
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,95 @@ inline Tensor triplet_margin_loss(

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Computes a triplet margin loss using either a caller-supplied
/// distance/similarity function or `pairwise_distance` when none is given.
///
/// \param anchor, positive, negative  The three inputs that are compared.
/// \param distance_function  Optional callable applied to input pairs.
/// \param is_similarity_function  When true, larger values from the callable
///     are treated as "closer": the swap step keeps the element-wise maximum
///     and the loss uses `neg - pos` instead of `pos - neg`.
/// \param margin  Added to the difference before clamping at zero.
/// \param swap  When true, the positive-negative measure is also computed and
///     combined with the anchor-negative one (min for distances, max for
///     similarities).
/// \param reduction  kNone returns the clamped loss as is, kMean its mean,
///     kSum its sum.
inline Tensor triplet_margin_loss_with_distance(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    c10::optional<TripletMarginLossWithDistanceFuncOptions::distance_function_t> distance_function,
    bool is_similarity_function,
    double margin,
    bool swap,
    TripletMarginLossWithDistanceFuncOptions::reduction_t reduction) {
  // Single entry point for measuring a pair of inputs, so the
  // custom-vs-default decision lives in one place.
  const auto measure = [&distance_function](
                           const Tensor& lhs, const Tensor& rhs) -> Tensor {
    if (!distance_function.has_value()) {
      return pairwise_distance(lhs, rhs);
    }
    return distance_function.value()(lhs, rhs);
  };

  const Tensor dist_pos = measure(anchor, positive);
  Tensor dist_neg = measure(anchor, negative);

  if (swap) {
    const Tensor dist_swap = measure(positive, negative);
    dist_neg = is_similarity_function ? torch::max(dist_neg, dist_swap)
                                      : torch::min(dist_neg, dist_swap);
  }

  const Tensor loss = is_similarity_function
      ? torch::clamp_min(dist_neg - dist_pos + margin, 0)
      : torch::clamp_min(dist_pos - dist_neg + margin, 0);

  if (c10::get_if<enumtype::kNone>(&reduction)) {
    return loss;
  }
  if (c10::get_if<enumtype::kMean>(&reduction)) {
    return loss.mean();
  }
  if (c10::get_if<enumtype::kSum>(&reduction)) {
    return loss.sum();
  }

  // No known reduction matched: raise an internal assertion failure.
  TORCH_INTERNAL_ASSERT(
      false,
      enumtype::get_enum_name(reduction),
      " is not valid");
  // Placeholder return after the always-failing assertion above.
  return anchor;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_loss_with_distance
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::TripletMarginLossWithDistanceFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// This overload reads `distance_function`, `is_similarity_function`,
/// `margin`, `swap` and `reduction` from `options` and forwards them, together
/// with the three input tensors, to
/// `detail::triplet_margin_loss_with_distance`. When `options` is omitted, a
/// default-constructed options object is used.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::triplet_margin_loss_with_distance(anchor, positive, negative, F::TripletMarginLossWithDistanceFuncOptions().margin(1.0));
/// ```
inline Tensor triplet_margin_loss_with_distance(
const Tensor& anchor,
const Tensor& positive,
const Tensor& negative,
const TripletMarginLossWithDistanceFuncOptions& options = {}) {
// Pure forwarding: all computation happens in the detail implementation.
return detail::triplet_margin_loss_with_distance(
anchor,
positive,
negative,
options.distance_function(),
options.is_similarity_function(),
options.margin(),
options.swap(),
options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor ctc_loss(const Tensor& log_probs,
Expand Down

0 comments on commit b66374b

Please sign in to comment.