
Incorrect gradients for 3D Helmholtz simulations using checkpointing #150

Closed
astanziola opened this issue Feb 5, 2023 · 2 comments
Labels
bug Something isn't working


@astanziola
Member

In some edge cases, Helmholtz simulations in 3D give incorrect gradients when checkpoint=True is used in helmholtz_solver with FourierSeries fields.

This is likely a JAX issue, since the results were correct up to jax 0.3.20.
It has been raised upstream: google/jax#14302

For the moment, if memory requirements allow, setting checkpoint=False should get around the problem.

It may make sense to emit a warning when checkpoint=True and to change its default value to False, especially if the fix on the JAX side is going to take a while.
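A minimal sketch of what the proposed warning and new default could look like, assuming the flag works by wrapping the linear operator in jax.checkpoint before it is passed to an iterative solver. The function solve and its arguments are hypothetical and are not the j-Wave helmholtz_solver API; only jax.checkpoint and jax.scipy.sparse.linalg.gmres are real JAX calls, and a toy diagonal operator stands in for the Helmholtz operator.

```python
import warnings

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def solve(apply_operator, rhs, checkpoint=False):
    """Hypothetical solver entry point (not the j-Wave API).

    checkpoint now defaults to False; opting in emits a warning.
    """
    if checkpoint:
        warnings.warn(
            "checkpoint=True may give incorrect gradients with JAX releases "
            "after 0.3.20 (see google/jax#14302); use checkpoint=False if "
            "memory allows.",
            UserWarning,
            stacklevel=2,
        )
        # Assumption: checkpointing is applied to the operator, so its
        # intermediates are recomputed on the backward pass instead of stored.
        apply_operator = jax.checkpoint(apply_operator)
    x, _ = gmres(apply_operator, rhs)
    return x


# Usage with a toy diagonal operator (hypothetical stand-in for the physics).
d = jnp.linspace(1.0, 2.0, 64)
rhs = jnp.ones(64)
x = solve(lambda v: d * v, rhs)  # default checkpoint=False: no warning
```

Making False the default keeps the safe path silent, while users who need the memory savings still get it by opting in explicitly and are told about the known issue.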

@astanziola astanziola added the bug Something isn't working label Feb 5, 2023
@astanziola
Member Author

Worth checking if this is still an issue with lineax.

@astanziola
Member Author

This seems to be fixed now. I've added some regression tests to catch it if it happens again.
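For illustration, a gradient regression test of this kind can compare the gradient obtained with checkpointing against the one obtained without it, and both against a direct solve. This is a sketch, not the project's actual test: the toy Richardson solver, the tridiagonal coupling K, the function names and the tolerances are all assumptions. Only jax.checkpoint, jax.lax.fori_loop, jax.grad and jnp.linalg.solve are real JAX APIs.

```python
import jax
import jax.numpy as jnp

N = 8
# Hypothetical fixed coupling term: a symmetric tridiagonal stencil.
K = -0.5 * (jnp.eye(N, k=1) + jnp.eye(N, k=-1))


def iterative_solve(c, b, checkpoint, n_iter=200, step=0.4):
    """Toy stand-in for an iterative solver: Richardson iteration on (diag(c) + K) u = b."""
    A = jnp.diag(c) + K

    def body(_, u):
        return u + step * (b - A @ u)

    if checkpoint:
        # Rematerialize the loop body on the backward pass instead of storing it.
        body = jax.checkpoint(body)
    return jax.lax.fori_loop(0, n_iter, body, jnp.zeros_like(b))


def loss(c, b, checkpoint):
    return jnp.sum(iterative_solve(c, b, checkpoint) ** 2)


def reference_loss(c, b):
    # Direct dense solve, used only as ground truth for the gradient.
    return jnp.sum(jnp.linalg.solve(jnp.diag(c) + K, b) ** 2)


def test_checkpoint_gradients_match():
    c = jnp.linspace(2.0, 3.0, N)
    b = jnp.linspace(0.5, 1.0, N)
    g_ref = jax.grad(reference_loss)(c, b)
    g_plain = jax.grad(lambda c_: loss(c_, b, checkpoint=False))(c)
    g_ckpt = jax.grad(lambda c_: loss(c_, b, checkpoint=True))(c)
    # Both variants must agree with the reference and with each other.
    assert jnp.allclose(g_plain, g_ref, rtol=1e-4, atol=1e-6)
    assert jnp.allclose(g_ckpt, g_plain, rtol=1e-5, atol=1e-7)
```

With c in [2, 3] the operator's eigenvalues lie in [1, 4], so a step of 0.4 makes the iteration contract by at least 0.6 per step and 200 iterations are effectively converged. Checkpointing should only change memory use, never the gradient values, which is exactly what the second assertion guards. The test can be run with pytest or by calling test_checkpoint_gradients_match() directly.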
