
Bugs in the DualPotentials class #182

Closed
lucaeyring opened this issue Nov 24, 2022 · 6 comments
@lucaeyring (Contributor)

I believe there are two bugs in the DualPotentials class:

  1. In the distance function, tgt and src are swapped: g should be applied to src, not to tgt.
  2. It is not possible to move a DualPotentials object to a different device, because its inputs f and g are Callables, which are not valid JAX types.

To Reproduce
Steps to reproduce the behavior of 2:

import jax
from ott.problems.linear import potentials

f = lambda x: x
g = lambda x: x
test_potential = potentials.DualPotentials(f, g, corr=True)
jax.device_put(test_potential, jax.devices("cpu")[0])  # fails: f and g are not valid JAX types

Expected behavior
One should be able to move a DualPotentials object to a different device.

A possible fix for 2 would be to use the TrainStates as input instead.
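
Another possible direction is to register the container as a pytree and keep the callables as static auxiliary data, so that jax.device_put only ever sees array leaves. Below is a minimal sketch with a hypothetical PotentialPair class; it is not the actual DualPotentials implementation, nor necessarily how #183 addresses this.

import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical container (not the real DualPotentials class): the callables
# live in static aux_data, and only array-valued fields become pytree leaves,
# so jax.device_put never tries to transfer a function.
@tree_util.register_pytree_node_class
class PotentialPair:
    def __init__(self, f, g, params=None):
        self.f = f
        self.g = g
        self.params = params  # e.g. arrays the potentials depend on

    def tree_flatten(self):
        return (self.params,), (self.f, self.g)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        f, g = aux_data
        (params,) = children
        return cls(f, g, params=params)

pair = PotentialPair(lambda x: x, lambda x: x, params=jnp.zeros(3))
pair_cpu = jax.device_put(pair, jax.devices("cpu")[0])  # works: only params is transferred

With f and g kept static, device transfers (and jit tracing) only touch the array state, similar to how TrainState-style objects are usually handled.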

@michalk8 (Collaborator) commented Nov 24, 2022

  1. Do you mean when corr=False, when corr=True, or in both cases? I will take another look at the Makkuva et al. paper.
  2. Thanks, nice catch! This will be fixed in "Fix jax.device_put for potentials" #183.

michalk8 self-assigned this Nov 24, 2022
michalk8 added the bug label Nov 24, 2022
@michalk8 (Collaborator)

You're right, it seems it's swapped in the paper (and also not consistent with NeuralDual's way of computing the W2 distance). This will be addressed in a different PR than #183.

@marcocuturi (Contributor)

@michalk8, shouldn't we fix this?

@bamos (Contributor) commented Dec 9, 2022

> @michalk8, shouldn't we fix this?

@marcocuturi I have a fix that swaps the order of the neural dual potentials to be consistent with Makkuva et al. in a larger PR I'll send soon, but I've still been having some issues stabilizing it and trimming it down to a minimal set of changes to start a larger discussion. My updates unfortunately turned out to be harder to clean up and debug than I expected, but I'll send an initial version in a few days.

@marcocuturi (Contributor)

@michalk8, no rush at all, let's chat live about this then!

@michalk8 (Collaborator)

Closed via #219.
