Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[c++] Distance-agnostic triplet margin loss #45377

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 50 additions & 0 deletions test/cpp/api/functional.cpp
Expand Up @@ -682,6 +682,56 @@ TEST_F(FunctionalTest, TripletMarginLoss) {
ASSERT_TRUE(output.allclose(expected, 1e-04));
}

TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) {
  // When no custom distance function is supplied,
  // F::triplet_margin_with_distance_loss falls back to pairwise_distance, so
  // its output is expected to match F::triplet_margin_loss (left at its
  // default distance settings) for every reduction / margin / swap combination.

  const std::vector<TripletMarginWithDistanceLossOptions::reduction_t>
      reductions = {torch::kSum, torch::kMean, torch::kNone};
  const std::vector<float> margins = {0.5, 1.0, 1.5};
  const std::vector<bool> swaps = {true, false};

  for (const auto& reduction : reductions) {
    for (const auto& margin : margins) {
      for (const auto& swap : swaps) {
        // Fresh inputs per combination so gradients never accumulate across
        // iterations.
        auto anchor = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto positive = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto negative = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));

        auto basicOptions = F::TripletMarginLossFuncOptions()
                                .reduction(reduction)
                                .margin(margin)
                                .swap(swap);
        auto distanceOptions = F::TripletMarginWithDistanceLossFuncOptions()
                                   .reduction(reduction)
                                   .margin(margin)
                                   .swap(swap);

        auto basicOutput =
            F::triplet_margin_loss(anchor, positive, negative, basicOptions);
        auto distanceOutput = F::triplet_margin_with_distance_loss(
            anchor, positive, negative, distanceOptions);

        ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));

        // Reduce to a scalar before calling backward(): with torch::kNone the
        // loss is returned unreduced.
        auto sum = distanceOutput.sum();
        sum.backward();
        ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
        ASSERT_EQ(positive.sizes(), positive.grad().sizes());
        ASSERT_EQ(negative.sizes(), negative.grad().sizes());
      }
    }
  }
}

TEST_F(FunctionalTest, NLLLoss) {
auto input = torch::tensor({{-0.1315, -3.1315, -2.5315},
{-3.7038, -0.1038, -2.6038},
Expand Down
127 changes: 125 additions & 2 deletions test/cpp/api/modules.cpp
Expand Up @@ -2085,6 +2085,115 @@ TEST_F(ModulesTest, TripletMarginLoss) {
ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
}

TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) {
  // With no custom distance function configured, the
  // TripletMarginWithDistanceLoss module is expected to produce the same
  // output as TripletMarginLoss (left at its default distance settings) for
  // every reduction / margin / swap combination, both through forward() and
  // through the holder's call operator.

  std::vector<TripletMarginWithDistanceLossOptions::reduction_t>
      reductions = {torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> margins = {0.5, 1.0, 1.5};
  std::vector<bool> swaps = {true, false};

  for (auto& reduction : reductions) {
    for (auto& margin : margins) {
      for (const auto& swap : swaps) {
        // Fresh inputs per combination so gradients never accumulate across
        // iterations.
        auto anchor = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto positive = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
        auto negative = torch::randn(
            {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));

        auto basicOptions = TripletMarginLossOptions()
                                .reduction(reduction)
                                .margin(margin)
                                .swap(swap);
        auto distanceOptions = TripletMarginWithDistanceLossOptions()
                                   .reduction(reduction)
                                   .margin(margin)
                                   .swap(swap);
        TripletMarginLoss basicLoss(basicOptions);
        TripletMarginWithDistanceLoss distanceLoss(distanceOptions);

        auto basicOutput = basicLoss->forward(anchor, positive, negative);
        auto distanceOutput = distanceLoss->forward(anchor, positive, negative);
        auto basicOperatorOutput = basicLoss(anchor, positive, negative);
        auto distanceOperatorOutput = distanceLoss(anchor, positive, negative);

        ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));
        ASSERT_TRUE(
            distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6));
        ASSERT_TRUE(
            distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6));

        // Reduce to a scalar before calling backward() so the same check
        // also covers torch::kNone.
        auto sum = distanceOutput.sum();
        sum.backward();
        ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
        ASSERT_EQ(positive.sizes(), positive.grad().sizes());
        ASSERT_EQ(negative.sizes(), negative.grad().sizes());
      }
    }
  }
}

TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) {
  // The TripletMarginWithDistanceLoss module (via forward() and via the call
  // operator) must agree with the functional form when both are configured
  // with the same distance function, reduction, margin and swap flag.
  namespace F = torch::nn::functional;
  using Options = TripletMarginWithDistanceLossOptions;

  const auto l2_distance = [](const torch::Tensor& lhs,
                              const torch::Tensor& rhs) {
    return torch::pairwise_distance(lhs, rhs);
  };
  const auto cosine_dissimilarity = [](const torch::Tensor& lhs,
                                       const torch::Tensor& rhs) {
    return 1.0 - torch::cosine_similarity(lhs, rhs);
  };

  std::vector<Options::distance_function_t> candidate_distances = {
      l2_distance, cosine_dissimilarity};
  std::vector<Options::reduction_t> candidate_reductions = {
      torch::kSum, torch::kMean, torch::kNone};
  std::vector<float> candidate_margins = {0.5, 1.0, 1.5};
  std::vector<bool> candidate_swaps = {true, false};

  for (auto& distance_fn : candidate_distances) {
    for (auto& reduction_kind : candidate_reductions) {
      for (auto& margin_value : candidate_margins) {
        for (const auto& swap_flag : candidate_swaps) {
          // Same settings, once as module options and once as functional
          // options.
          auto module_opts = Options()
                                 .distance_function(distance_fn)
                                 .reduction(reduction_kind)
                                 .margin(margin_value)
                                 .swap(swap_flag);
          auto functional_opts = F::TripletMarginWithDistanceLossFuncOptions()
                                     .distance_function(distance_fn)
                                     .reduction(reduction_kind)
                                     .margin(margin_value)
                                     .swap(swap_flag);

          const auto input_opts =
              torch::dtype(torch::kFloat).requires_grad(true);
          auto anchor = torch::randn({100, 128}, input_opts);
          auto positive = torch::randn({100, 128}, input_opts);
          auto negative = torch::randn({100, 128}, input_opts);

          TripletMarginWithDistanceLoss criterion(module_opts);

          auto from_forward = criterion->forward(anchor, positive, negative);
          auto from_call = criterion(anchor, positive, negative);
          auto from_functional = F::triplet_margin_with_distance_loss(
              anchor, positive, negative, functional_opts);

          ASSERT_TRUE(from_forward.allclose(from_functional, 1e-6, 1e-6));
          ASSERT_TRUE(from_call.allclose(from_functional, 1e-6, 1e-6));
        }
      }
    }
  }
}

TEST_F(ModulesTest, NLLLoss) {
NLLLoss loss;
auto input = torch::tensor({{-0.1315, -3.1315, -2.5315},
Expand Down Expand Up @@ -3529,9 +3638,9 @@ TEST_F(ModulesTest, PrettyPrintIdentity) {
}

TEST_F(ModulesTest, PrettyPrintFlatten) {
  // The printed form of Flatten must report its start_dim / end_dim options,
  // both for a default-constructed module and for explicitly set options.
  ASSERT_EQ(c10::str(Flatten()),
            "torch::nn::Flatten(start_dim=1, end_dim=-1)");
  ASSERT_EQ(c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))),
            "torch::nn::Flatten(start_dim=2, end_dim=4)");
}

Expand Down Expand Up @@ -4394,6 +4503,20 @@ TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) {
"torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)");
}

TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) {
  // The expected string lists only margin and swap, even though a distance
  // function and a reduction are configured as well.
  const auto l2_distance = [](const torch::Tensor& lhs,
                              const torch::Tensor& rhs) {
    return torch::pairwise_distance(lhs, rhs, 2.0, 1e-6);
  };
  auto printed_options = TripletMarginWithDistanceLossOptions()
                             .distance_function(l2_distance)
                             .margin(1.5)
                             .swap(true)
                             .reduction(torch::kMean);

  const auto printed = c10::str(TripletMarginWithDistanceLoss(printed_options));
  ASSERT_EQ(
      printed,
      "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)");
}

TEST_F(ModulesTest, PrettyPrintNLLLoss) {
ASSERT_EQ(
c10::str(NLLLoss()), "torch::nn::NLLLoss()");
Expand Down
79 changes: 79 additions & 0 deletions torch/csrc/api/include/torch/nn/functional/loss.h
Expand Up @@ -527,6 +527,85 @@ inline Tensor triplet_margin_loss(

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Implementation of the distance-agnostic triplet margin loss.
///
/// Computes `clamp_min(d(anchor, positive) - d(anchor, negative) + margin, 0)`
/// and applies `reduction` to the result.
///
/// \param anchor, positive, negative  Input tensors, passed unchanged to the
///        distance function.
/// \param distance_function  Optional callable used as `d`; when empty,
///        `pairwise_distance` is used instead.
/// \param margin  Value added to the distance difference before clamping.
/// \param swap  When true, the negative distance is replaced by the
///        element-wise minimum of `d(anchor, negative)` and
///        `d(positive, negative)`.
/// \param reduction  `kNone` returns the unreduced loss, `kMean` its mean and
///        `kSum` its sum; any other value trips an internal assertion.
inline Tensor triplet_margin_with_distance_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    c10::optional<TripletMarginWithDistanceLossFuncOptions::distance_function_t> distance_function,
    double margin,
    bool swap,
    TripletMarginWithDistanceLossFuncOptions::reduction_t reduction) {
  // Single dispatch point for "custom distance or pairwise_distance". The
  // callable is used through a reference, so it is not copied on every call.
  const auto compute_distance =
      [&distance_function](const Tensor& lhs, const Tensor& rhs) -> Tensor {
    if (distance_function.has_value()) {
      return distance_function.value()(lhs, rhs);
    }
    return pairwise_distance(lhs, rhs);
  };

  const Tensor dist_pos = compute_distance(anchor, positive);
  Tensor dist_neg = compute_distance(anchor, negative);

  if (swap) {
    // Distance swap: use whichever of anchor-negative and positive-negative
    // is smaller as the negative distance.
    dist_neg = torch::min(dist_neg, compute_distance(positive, negative));
  }

  const Tensor loss = torch::clamp_min(dist_pos - dist_neg + margin, 0);

  Tensor ret;
  if (c10::get_if<enumtype::kNone>(&reduction)) {
    ret = loss;
  } else if (c10::get_if<enumtype::kMean>(&reduction)) {
    ret = loss.mean();
  } else if (c10::get_if<enumtype::kSum>(&reduction)) {
    ret = loss.sum();
  } else {
    // The condition is literally false, so this assertion always fires for an
    // unrecognised reduction.
    TORCH_INTERNAL_ASSERT(
        false,
        enumtype::get_enum_name(reduction),
        " is not valid");
  }
  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Triplet margin loss with a configurable distance function. Unpacks
/// `options` (distance function, margin, swap flag, reduction) and forwards
/// them, together with the three input tensors, to the `detail`
/// implementation.
///
/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::triplet_margin_with_distance_loss(anchor, positive, negative, F::TripletMarginWithDistanceLossFuncOptions().margin(1.0));
/// ```
inline Tensor triplet_margin_with_distance_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    const TripletMarginWithDistanceLossFuncOptions& options = {}) {
  const auto& chosen_distance = options.distance_function();
  const auto& chosen_margin = options.margin();
  const auto& chosen_swap = options.swap();
  const auto& chosen_reduction = options.reduction();

  return detail::triplet_margin_with_distance_loss(
      anchor,
      positive,
      negative,
      chosen_distance,
      chosen_margin,
      chosen_swap,
      chosen_reduction);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor ctc_loss(const Tensor& log_probs,
Expand Down
54 changes: 49 additions & 5 deletions torch/csrc/api/include/torch/nn/modules/loss.h
Expand Up @@ -309,7 +309,7 @@ struct TORCH_API SmoothL1LossImpl : public Cloneable<SmoothL1LossImpl> {
TORCH_MODULE(SmoothL1Loss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a criterion that optimizes a multi-class multi-classification
/// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
/// and output :math:`y` (which is a 2D `Tensor` of target class indices).
Expand Down Expand Up @@ -421,9 +421,9 @@ TORCH_MODULE(MultiLabelSoftMarginLoss);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a criterion that measures the triplet loss given input
/// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater
/// than :math:`0`. This is used for measuring the relative similarity between
/// samples. A triplet is composed of `a`, `p` and `n` (i.e., `anchor`,
/// `positive examples` and `negative examples` respectively). The
/// shapes of all input tensors should be :math:`(N, D)`.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginLoss to learn
Expand Down Expand Up @@ -461,6 +461,50 @@ struct TORCH_API TripletMarginLossImpl : public Cloneable<TripletMarginLossImpl>
/// module storage semantics.
TORCH_MODULE(TripletMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a criterion that measures the triplet loss given input
/// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
/// positive, and negative examples, respectively); and a nonnegative, real-valued function
/// ("distance function") used to compute the relationships between the anchor
/// and positive example ("positive distance") and the anchor and negative
/// example ("negative distance").
/// See https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginWithDistanceLoss to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// TripletMarginWithDistanceLoss model(TripletMarginWithDistanceLossOptions().margin(3).swap(false));
/// ```
struct TORCH_API TripletMarginWithDistanceLossImpl : public Cloneable<TripletMarginWithDistanceLossImpl> {
/// Constructs the module from `options_`. The argument may be omitted, in
/// which case a default-constructed options object is used.
explicit TripletMarginWithDistanceLossImpl(
TripletMarginWithDistanceLossOptions options_ = {});

/// Overrides the base class `reset()` hook.
/// NOTE(review): the definition is not part of this header; what (if
/// anything) it re-initializes should be confirmed against the .cpp file.
void reset() override;

/// Pretty prints the `TripletMarginWithDistanceLoss` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;

/// Evaluates the criterion for the given `anchor`, `positive` and `negative`
/// tensors and returns the resulting loss tensor.
/// NOTE(review): presumably delegates to the functional implementation using
/// `options` — the definition is not visible here; confirm.
Tensor forward(
const Tensor& anchor,
const Tensor& positive,
const Tensor& negative);

/// The options with which this `Module` was constructed.
TripletMarginWithDistanceLossOptions options;
};

/// A `ModuleHolder` subclass for `TripletMarginWithDistanceLossImpl`.
/// See the documentation for `TripletMarginWithDistanceLossImpl` class to learn what methods it
/// provides, and examples of how to use `TripletMarginWithDistanceLoss` with
/// `torch::nn::TripletMarginWithDistanceLossOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(TripletMarginWithDistanceLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// The Connectionist Temporal Classification loss.
Expand Down Expand Up @@ -626,9 +670,9 @@ TORCH_MODULE(NLLLoss);
struct TORCH_API CrossEntropyLossImpl : public Cloneable<CrossEntropyLossImpl> {
explicit CrossEntropyLossImpl(
const CrossEntropyLossOptions& options_ = {});

void reset() override;

/// Pretty prints the `CrossEntropyLoss` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;

Expand Down