Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jax.device_put for potentials #183

Merged
merged 4 commits into from
Nov 25, 2022

Conversation

michalk8
Copy link
Collaborator

@michalk8 michalk8 commented Nov 24, 2022

Fixes jax.device_put on DualPotentials.

partially addresses #182

@michalk8 michalk8 self-assigned this Nov 24, 2022
@codecov-commenter
Copy link

codecov-commenter commented Nov 24, 2022

Codecov Report

Merging #183 (9b28c55) into main (156ef2b) will decrease coverage by 0.04%.
The diff coverage is 64.70%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #183      +/-   ##
==========================================
- Coverage   89.28%   89.23%   -0.05%     
==========================================
  Files          51       51              
  Lines        5187     5192       +5     
  Branches      527      529       +2     
==========================================
+ Hits         4631     4633       +2     
- Misses        430      431       +1     
- Partials      126      128       +2     
Impacted Files Coverage Δ
ott/problems/linear/potentials.py 84.76% <64.70%> (-2.24%) ⬇️

Copy link
Contributor

@marcocuturi marcocuturi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's great Michal!

There was a change in the last PR that I overlooked on passing a and b by default. Because f and g here correspond exactly to an existing a and b (even if that pair was defined by default initially, it somewhat becomes unique once it is set), it would be a mistake IMHO to allow the user to pass on an f and a g without requiring them to pass the corresponding a and b. This could easily lead to a silent error where the f,g were computed for a given pair of weights but the dual potentials evaluated with the default uniform measures. Therefore I would prefer sticking to the way this was handled before (see also this comment: #167 (comment))

@michalk8
Copy link
Collaborator Author

That's great Michal!

There was a change in the last PR that I overlooked on passing a and b by default. Because f and g here correspond exactly to an existing a and b (even if that pair was defined by default initially, it somewhat becomes unique once it is set), it would be a mistake IMHO to allow the user to pass on an f and a g without requiring them to pass the corresponding a and b. This could easily lead to a silent error where the f,g were computed for a given pair of weights but the dual potentials evaluated with the default uniform measures. Therefore I would prefer sticking to the way this was handled before (see also this comment: #167 (comment))

I'd then consider passing LinearProblem instead of geom/a/b to further avoid any user-level mistakes.

@michalk8 michalk8 merged commit be94b74 into ott-jax:main Nov 25, 2022
@michalk8 michalk8 deleted the fix/dual-potentials-device branch November 25, 2022 12:29
michalk8 added a commit that referenced this pull request Jun 27, 2024
* Fix `jax.device_put` for potentials

* Make weights not optional

* Use `LinearProblem` in `EntropicPotentials`

* Remove extra docs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants