Proposal: Generic Triplet-Margin Loss #43342

Closed
ethch18 opened this issue Aug 20, 2020 · 7 comments
Labels: enhancement, module: loss, module: nn, needs research, triaged


@ethch18

ethch18 commented Aug 20, 2020

🚀 Feature

We'd like to add a distance-agnostic version of TripletMarginLoss; the current version only supports l_p norms. We implemented this as part of a recent investigation into embedding learning and think it could be helpful to the PyTorch user base at large.

Motivation

  1. Extensibility: decoupling the distance function from the loss computation gives users more flexibility. PyTorch currently has CosineEmbeddingLoss, but it serves a somewhat different purpose (it scores individual pairs against a similar/dissimilar label) and doesn't work for users who want a triplet-margin loss with cosine distance.

  2. Existing use cases: several papers have proposed triplet loss functions with cosine distance (1, 2) or have generally used cosine-based metrics (1, 2). PyTorch-BigGraph also does something similar with its ranking loss.

  3. Other frameworks, such as TensorFlow, decouple the distance function from the loss and even allow custom distance metrics.
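To make the contrast in point 1 concrete, here is a small sketch, assuming a working PyTorch install; cosine_triplet is a hypothetical helper written for illustration, not a PyTorch function. CosineEmbeddingLoss scores each pair in isolation against a +1/-1 label, while a triplet loss with cosine distance compares the anchor's distance to a positive against its distance to a negative, under a margin.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x1 = torch.tensor([[1.0, 0.0]])
x2 = torch.tensor([[0.0, 1.0]])  # orthogonal to x1, so cos(x1, x2) = 0

# CosineEmbeddingLoss: one pair at a time, with an explicit +1 / -1 label.
pair_loss = nn.CosineEmbeddingLoss(margin=0.0)
similar = pair_loss(x1, x2, torch.tensor([1.0]))      # 1 - cos = 1
dissimilar = pair_loss(x1, x2, torch.tensor([-1.0]))  # max(0, cos - margin) = 0


# Hypothetical helper: triplet-margin loss with cosine distance (1 - cosine similarity).
def cosine_triplet(anchor, positive, negative, margin=0.2):
    d_ap = 1.0 - F.cosine_similarity(anchor, positive)
    d_an = 1.0 - F.cosine_similarity(anchor, negative)
    # Penalize only when the negative is not at least `margin` farther than the positive.
    return torch.clamp(margin + d_ap - d_an, min=0.0).mean()
```

The triplet form has no label argument; what matters is the relative ordering of the two distances from the same anchor, which the pairwise loss cannot express.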

Pitch

We have a working implementation of this (see screenshots) and propose to use it directly, with adjustments to fit the existing PyTorch code structure. We'll tentatively call these torch.nn.GenericTripletMarginLoss and F.generic_triplet_margin_loss. In the proposed implementation, GenericTripletMarginLoss will be a stateful wrapper for F.generic_triplet_margin_loss, and the loss computation will occur in the latter.

The module signature will resemble that of the existing TripletMarginLoss module, with an analogous signature for the functional form. The removed parameters (p and eps) are specific to l_p norm computation; size_average and reduce are removed because they are deprecated. There are two new parameters: distance_function and is_similarity_function. distance_function is a function/module $$f: \mathbb{R}^{n \times d} \times \mathbb{R}^{n \times d} \rightarrow \mathbb{R}^n$$ that gives an index-wise distance between two tensors of n embeddings of dimension d. This can be, for example, torch.nn.PairwiseDistance or torch.nn.CosineSimilarity. is_similarity_function is a boolean that denotes whether distance_function is a distance or a similarity function; it indicates whether to flip the signs in the distance computation.
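A minimal sketch of the proposed functional form and its stateful wrapper, assuming PyTorch. The names generic_triplet_margin_loss and GenericTripletMarginLoss are the tentative ones from this proposal, not an existing PyTorch API. The sketch also assumes that margin, swap, and reduction carry over unchanged from TripletMarginLoss, and that "flipping the signs" means using the negated similarity as the distance.

```python
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

DistanceFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def generic_triplet_margin_loss(
    anchor: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    distance_function: DistanceFn,
    is_similarity_function: bool = False,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> torch.Tensor:
    """Hypothetical functional form from the proposal (not a PyTorch API)."""

    def dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        d = distance_function(x, y)  # one score per row, shape (n,)
        # A similarity grows as inputs get closer, so negate it to act as a distance.
        return -d if is_similarity_function else d

    d_ap = dist(anchor, positive)
    d_an = dist(anchor, negative)
    if swap:
        # Distance swap: keep whichever negative distance is smaller (the harder one).
        d_an = torch.min(d_an, dist(positive, negative))
    losses = torch.clamp(margin + d_ap - d_an, min=0.0)
    if reduction == "mean":
        return losses.mean()
    if reduction == "sum":
        return losses.sum()
    if reduction == "none":
        return losses
    raise ValueError(f"unsupported reduction: {reduction!r}")


class GenericTripletMarginLoss(nn.Module):
    """Hypothetical stateful wrapper: stores options, defers to the functional form."""

    def __init__(
        self,
        distance_function: DistanceFn,
        is_similarity_function: bool = False,
        margin: float = 1.0,
        swap: bool = False,
        reduction: str = "mean",
    ) -> None:
        super().__init__()
        # If distance_function is an nn.Module, this registers it as a submodule.
        self.distance_function = distance_function
        self.is_similarity_function = is_similarity_function
        self.margin = margin
        self.swap = swap
        self.reduction = reduction

    def forward(self, anchor, positive, negative):
        return generic_triplet_margin_loss(
            anchor, positive, negative, self.distance_function,
            self.is_similarity_function, self.margin, self.swap, self.reduction,
        )


# Usage: cosine similarity, turned into a distance via is_similarity_function=True.
cosine_triplet_loss = GenericTripletMarginLoss(
    nn.CosineSimilarity(dim=1), is_similarity_function=True
)
```

With distance_function=F.pairwise_distance and is_similarity_function=False, this sketch should reproduce F.triplet_margin_loss with p=2, since that loss is also built on pairwise distances.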

Alternatives

Beyond the straightforward alternative of leaving this to the user to implement every time, we haven't come up with any other alternatives.

One open question is whether the Python implementation is performant enough, or if this would need to be added in ATen as well.

Additional context

[Screenshot: working implementation of the proposed loss]

cc @albanD @mruberry

@gchanan added the labels module: nn, enhancement, needs research, triaged, and module: loss on Aug 20, 2020
@mruberry
Collaborator

This is an interesting proposal and letting users specify their own distance functions makes sense.

I see two challenges with supporting this in PyTorch Core:

  • PyTorch typically supports both a Python and C++ API, and I think this would only work in Python?
  • Would this work with the JIT? I don't think scripting or tracing supports passing callables.

@ethch18
Author

ethch18 commented Aug 25, 2020

@mruberry thanks for your response!

Re: C++, would it be possible to pass in the distance function as a lambda? I'm not too familiar with ATen, but the distance function should just be another function. Or were there other concerns about C++?

Re: JIT, are there ways to work around the callable issue?

@mruberry
Collaborator

mruberry commented Aug 25, 2020

> Re: C++, would it be possible to pass in the distance function as a lambda? I'm not too familiar with ATen, but the distance function should just be another function. Or were there other concerns about C++?

Yes, you could have two implementations: one for Python and one for C++.

> Re: JIT, are there ways to work around the callable issue?

Not that I know of.

Edit/Update: Passing a Python function to a module's init method and then calling that function in forward works fine with scripting inference.

@ethch18
Author

ethch18 commented Aug 25, 2020

@mruberry I see that in this Q/A, it's possible to get around the JIT callable issue by wrapping it as a module. Since the implementation I proposed takes in an nn.Module and keeps it as an attribute, I think this should work fine as-is?
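A sketch of the module-wrapping approach, assuming PyTorch with TorchScript; TripletWithDistanceModule is a hypothetical name used only for illustration. Because the distance is stored as a submodule instead of being passed as a callable argument, torch.jit.script can compile it recursively along with the parent module.

```python
import torch
import torch.nn as nn


class TripletWithDistanceModule(nn.Module):
    """Hypothetical example: the distance is held as a submodule, not a callable argument."""

    def __init__(self, distance: nn.Module, margin: float = 1.0):
        super().__init__()
        self.distance = distance  # compiled together with this module by torch.jit.script
        self.margin = margin

    def forward(
        self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor
    ) -> torch.Tensor:
        d_ap = self.distance(anchor, positive)
        d_an = self.distance(anchor, negative)
        return torch.clamp(d_ap - d_an + self.margin, min=0.0).mean()


eager = TripletWithDistanceModule(nn.PairwiseDistance(p=2.0))
scripted = torch.jit.script(eager)  # no callable is passed at call time, so scripting works
```

Swapping in a different scriptable distance module only changes the constructor argument; the scripted forward stays the same.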

@mruberry
Collaborator

> @mruberry I see that in this Q/A, it's possible to get around the JIT callable issue by wrapping it as a module. Since the implementation I proposed takes in an nn.Module and keeps it as an attribute, I think this should work fine as-is?

Interesting! Good point. My mistake; I should have thought of wrapping it in a module, or noticed that's how your callable was typed.

@mruberry
Collaborator

Update from offline conversation with @ethch18:

This seems like a cool improvement on an existing module, but the work to implement it in core might be significant. If it's pursued, it's recommended to start by prototyping the C++ changes required in a BC-preserving way.

@ethch18 ethch18 linked a pull request Aug 26, 2020 that will close this issue
ethch18 added a commit that referenced this issue Aug 27, 2020
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
@ethch18 ethch18 linked a pull request Aug 27, 2020 that will close this issue
ethch18 added a commit that referenced this issue Aug 27, 2020
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23363898](https://our.internmc.facebook.com/intern/diff/D23363898)

[ghstack-poisoned]
ethch18 added a commit that referenced this issue Aug 27, 2020
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: b62a832256009bd023fa9c8e4afdb657655c82af
Pull Request resolved: #43680
ethch18 added a commit that referenced this issue Sep 1, 2020
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

* Consolidate tests, clarify functional limitations

* Documentation updates

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f15c9a1778fa243f86861cfbc11c3e789f5629b1
Pull Request resolved: #43680
ethch18 added a commit that referenced this issue Sep 1, 2020
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

* Consolidate tests, clarify functional limitations

* Documentation updates

* Remove stray imports

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 59680921c9b703c1172929e805ba38074e3f69f2
Pull Request resolved: #43680
ethch18 added a commit that referenced this issue Sep 2, 2020
Summary: Following up on the C++ side of [this
issue](#43342).  The implementation
here is parallel to that of the Python one, but we don't use native functions
because Callables aren't supported.

Test Plan: Unit test with test_api

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
ethch18 added a commit that referenced this issue Sep 2, 2020
Summary: Following up on the C++ side of [this
issue](#43342).  The implementation
here is parallel to that of the Python one, but we don't use native functions
because Callables aren't supported.

Test Plan: Unit test with test_api

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 821a1600a9f961dceb112d75857691a290219ecc
Pull Request resolved: #44072
@ethch18 ethch18 linked a pull request Sep 2, 2020 that will close this issue
ethch18 added a commit that referenced this issue Sep 3, 2020
Summary: As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

* Consolidate tests, clarify functional limitations

* Documentation updates

* Remove stray imports

* Fix CI

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: bcbf393978f422143b523c0a4916a97bfd8f5e18
Pull Request resolved: #43680
facebook-github-bot pushed a commit that referenced this issue Sep 22, 2020
Summary:
Pull Request resolved: #43680

As discussed [here](#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

Test Plan:
python test/run_tests.py

Imported from OSS

Reviewed By: albanD

Differential Revision: D23363898

fbshipit-source-id: 1cafc05abecdbe7812b41deaa1e50ea11239d0cb
loadbxh pushed a commit to loadbxh/Torch that referenced this issue Sep 23, 2020
Summary: As discussed [here](pytorch/pytorch#43342),
adding in a Python-only implementation of the triplet-margin loss that takes a
custom distance function.  Still discussing whether this is necessary to add to
PyTorch Core.

* Consolidate tests, clarify functional limitations

* Documentation updates

* Remove stray imports

* Fix CI

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 672055f026a5627c7883c43625ca85bc05f5af5e
Pull Request resolved: pytorch/pytorch#43680
@mruberry
Collaborator

mruberry commented Dec 2, 2020

We now have a "generic" triplet margin loss! Closing this issue.
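For reference, to our knowledge the loss that came out of PR #43680 ships in PyTorch 1.7 and later as torch.nn.TripletMarginWithDistanceLoss (functional form: torch.nn.functional.triplet_margin_with_distance_loss). As far as we know it has no is_similarity_function flag; a similarity is turned into a distance inside the callable instead, as in the 1 - cosine similarity lambda in this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
anchor, positive, negative = torch.randn(3, 8, 16).unbind(0)

# Cosine distance (1 - cosine similarity) supplied as a plain callable.
cosine_loss = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y),
    margin=0.5,
)
loss = cosine_loss(anchor, positive, negative)

# With no distance_function, the default is pairwise (L2) distance,
# so this should agree with the original TripletMarginLoss.
default_loss = nn.TripletMarginWithDistanceLoss()(anchor, positive, negative)
reference = nn.TripletMarginLoss()(anchor, positive, negative)
```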

@mruberry mruberry closed this as completed Dec 2, 2020
3 participants