-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix jax.device_put
for potentials
#183
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #183 +/- ##
==========================================
- Coverage 89.28% 89.23% -0.05%
==========================================
Files 51 51
Lines 5187 5192 +5
Branches 527 529 +2
==========================================
+ Hits 4631 4633 +2
- Misses 430 431 +1
- Partials 126 128 +2
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's great Michal!
There was a change in the last PR that I overlooked on passing a
and b
by default. Because f
and g
here correspond exactly to an existing a
and b
(even if that pair was defined by default initially, it somewhat becomes unique once it is set), it would be a mistake IMHO to allow the user to pass on an f
and a g
without requiring them to pass the corresponding a
and b
. This could easily lead to a silent error where the f,g
were computed for a given pair of weights but the dual potentials evaluated with the default uniform measures. Therefore I would prefer sticking to the way this was handled before (see also this comment: #167 (comment))
I'd then consider passing |
* Fix `jax.device_put` for potentials * Make weights not optional * Use `LinearProblem` in `EntropicPotentials` * Remove extra docs
Fixes
jax.device_put
onDualPotentials
.partially addresses #182