Increase numerical precision for Jax (optional?) #17

Closed
Marius1311 opened this issue Sep 13, 2021 · 6 comments

@Marius1311
Collaborator

Using float64 instead of JAX's default float32 makes a difference in our example (MK_2021-09-07_fgw_comparison_gt). Could we make it optional, e.g. when running on the GPU?

@michalk8
Collaborator

Ok, I'll add a `dtype` argument to `__init__` and ensure that the passed arrays/geometries have that dtype.
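A minimal sketch of this plan, using a hypothetical `Solver` class (the project's real class, its name and its arguments may differ): the constructor stores the requested `dtype`, and inputs are cast with `jnp.asarray`. Note that requesting `jnp.float64` while `jax_enable_x64` is off does not yield float64: JAX truncates to float32 and emits a `UserWarning`, which is the core difficulty discussed below.

```python
import jax.numpy as jnp
import numpy as np


class Solver:
    """Hypothetical solver that keeps all inputs in a single, user-chosen dtype."""

    def __init__(self, dtype=jnp.float32):
        self.dtype = dtype

    def prepare(self, x):
        # Cast the input to the solver's dtype. With jax_enable_x64 off,
        # a float64 request is truncated to float32 and JAX warns about it.
        return jnp.asarray(x, dtype=self.dtype)


solver = Solver(dtype=jnp.float32)
y = solver.prepare(np.ones(3, dtype=np.float64))
print(y.dtype)  # float32
```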

@Marius1311
Collaborator Author

amazing, thanks!

@michalk8
Collaborator

I've tried 2 implementations:

  1. The user enables x64 at the top of their script:

     ```python
     import jax

     # Enable 64-bit floats globally. Equivalent to the older
     # `from jax.config import config; config.update(...)` form, which is deprecated in newer JAX versions.
     jax.config.update("jax_enable_x64", True)
     ```

     and we then check the inputs and convert them to f64. The problem is that all allocations made after the update (especially ones outside our control, e.g. in sinkhorn) are in f64, which causes a size mismatch.
  2. Add an option `x64: bool = False` to `fit`. This worked better, but it's a stateful operation: I have to write to the config and then NOT revert it for it to work. If x64 was previously disallowed and only temporarily allowed for the computation, we again get sizing errors (e.g. expected 1600 bytes, got 3200).
     The solution would be either to convert back to x32 or to remain stateful, and I dislike both (not to mention that I don't know what the consequences of modifying the config like this are).
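For reference, a minimal sketch of where a factor-of-2 size mismatch like "expected 1600 bytes, got 3200" comes from: JAX's default float dtype depends on the `jax_enable_x64` flag, so the same allocation doubles in size once the flag is flipped. The array length of 400 is an assumption chosen to reproduce the reported numbers, and flipping the flag mid-program is done here purely for illustration (it is exactly the stateful pattern described above).

```python
import jax
import jax.numpy as jnp


def default_float_nbytes(n: int) -> int:
    """Bytes used by an n-element array of JAX's current default float dtype."""
    return jnp.zeros(n).nbytes


jax.config.update("jax_enable_x64", False)
x32_bytes = default_float_nbytes(400)  # float32: 4 bytes per element
jax.config.update("jax_enable_x64", True)
x64_bytes = default_float_nbytes(400)  # float64: 8 bytes per element
jax.config.update("jax_enable_x64", False)  # restore JAX's default
print(x32_bytes, x64_bytes)  # 1600 3200
```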

My take: if x64 is desired, enable it the way JAX does it, i.e. once at the very beginning.
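A minimal sketch of the "enable it at the beginning" approach, assuming a standalone script: set the flag before any JAX arrays are created, after which new float arrays default to float64. Setting the environment variable `JAX_ENABLE_X64=True` before starting Python should have the same effect.

```python
import jax
import jax.numpy as jnp

# Must run before any arrays are created; later float allocations default to float64.
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64
```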

@Marius1311
Collaborator Author

Mhm, that doesn't sound very promising, let's further discuss this today. How can I make sure I use float64 for my internal checks? I would like to run some checks.

@michalk8
Collaborator

> Mhm, that doesn't sound very promising, let's further discuss this today. How can I make sure I use float64 for my internal checks? I would like to run some checks.

If you didn't enable x64 via the config update above, it should always be float32; otherwise, it's float64.
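One way to turn this into a hard check for internal comparisons, sketched with two hypothetical helpers (`default_float_dtype` and `require_float64` are not part of JAX or of this project): inspect the dtype JAX assigns to a freshly created array and fail loudly if it isn't float64.

```python
import jax.numpy as jnp


def default_float_dtype():
    """Float dtype JAX currently uses for new arrays (float32 unless x64 is enabled)."""
    return jnp.zeros(()).dtype


def require_float64(*arrays):
    """Raise if x64 is disabled or if any given array is not float64."""
    if default_float_dtype() != jnp.float64:
        raise RuntimeError(
            "x64 is disabled; call jax.config.update('jax_enable_x64', True) at startup"
        )
    for a in arrays:
        dtype = jnp.asarray(a).dtype
        if dtype != jnp.float64:
            raise TypeError(f"expected float64, got {dtype}")
```

Calling `require_float64(...)` at the start of a check then guarantees that the comparison really runs in double precision.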

@Marius1311
Collaborator Author

Assuming this is done.
