-
Notifications
You must be signed in to change notification settings - Fork 24.9k
implement TripletMarginLoss as a native function #5680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -959,16 +967,17 @@ class TripletMarginLoss(Module): | |||
http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf | |||
""" | |||
|
|||
def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False): | |||
super(TripletMarginLoss, self).__init__() | |||
def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False, size_average=True, reduce=True): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/test_nn.py
Outdated
constructor_args=(torch.rand(10),), | ||
input_fn=lambda: torch.randn(5, 10), | ||
target_fn=lambda: torch.rand(5, 10).mul(2).floor(), | ||
reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() / |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Distance.cpp
Outdated
Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { | ||
auto diff = abs(x1 - x2); | ||
auto out = pow(diff + eps, p).sum(1); | ||
return pow(out, 1 / p); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@pytorchbot retest this please |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code looks good to me. I had a few minor comments and a concern that we're subtly changing the behavior of pairwise_distance
margin (float, optional): Default: `1`. | ||
p (int, optional): The norm degree for pairwise distance. Default: `2`. | ||
swap (float, optional): The distance swap is described in detail in the paper | ||
`Learning shallow convolutional feature descriptors with triplet losses` by |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@@ -959,16 +967,17 @@ class TripletMarginLoss(Module): | |||
http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf | |||
""" | |||
|
|||
def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False): | |||
super(TripletMarginLoss, self).__init__() | |||
def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False, size_average=True, reduce=True): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@@ -271,6 +271,9 @@ | |||
- func: ones_like(Tensor self, *, Type dtype) -> Tensor | |||
variants: function | |||
|
|||
- func: pairwise_distance(Tensor x1, Tensor x2, double p, double eps) -> Tensor |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@@ -360,6 +363,9 @@ | |||
- func: transpose_(Tensor self, int64_t dim0, int64_t dim1) -> Tensor | |||
variants: method | |||
|
|||
- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, double margin, double p, double eps, bool swap, bool size_average, bool reduce) -> Tensor |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Distance.cpp
Outdated
|
||
Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { | ||
auto diff = abs(x1 - x2 + eps); | ||
return norm(diff, p, 1); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Distance.cpp
Outdated
namespace at { namespace native { | ||
|
||
Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { | ||
auto diff = abs(x1 - x2 + eps); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
assert x1.size() == x2.size(), "Input sizes must be equal." | ||
assert x1.dim() == 2, "Input must be a 2D matrix." | ||
diff = torch.abs(x1 - x2) | ||
out = torch.pow(diff + eps, p).sum(dim=1, keepdim=True) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/functional.py
Outdated
@@ -1960,11 +1960,7 @@ def pairwise_distance(x1, x2, p=2, eps=1e-6): | |||
>>> output = F.pairwise_distance(input1, input2, p=2) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Minor nit below
aten/src/ATen/native/Distance.cpp
Outdated
@@ -5,7 +5,6 @@ | |||
namespace at { namespace native { | |||
|
|||
Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yay! Approve approve approve
Benchmarks run on inputs of size (5,500) x 5000 iterations.
old:
forward [1.0334448860958219, 1.0698160571046174, 1.054812144022435]
backward [1.6976923400070518, 1.6186234278138727, 1.617640192154795]
double backward [3.2021269120741636, 3.165953448973596, 3.24088541790843]
new:
forward [0.8313469190616161, 0.8255189431365579, 0.8157028129789978]
backward [1.5887232730165124, 1.6056769208516926, 1.5178304109722376]
double backward [3.154403690015897, 3.1905114939436316, 3.2450901730917394]