
NaNs when batching simulations in tutorial example. #168

Closed
talfanoutlook opened this issue May 2, 2023 · 14 comments
Assignees
Labels
bug Something isn't working

Comments

@talfanoutlook

Describe the bug
Hi guys, I am trying to reproduce the Full Wave Inversion tutorial example. When I batch the single simulation across sensor indices with jax.vmap, I get NaNs, even though running those sensor simulations individually is stable.

I'm running on Colab with a GPU.

To Reproduce
Colab to reproduce:

https://colab.research.google.com/drive/1U3ttIMkn4lZnlu4SMNf4Y0IcuB7Sv9MH?usp=sharing

Expected behavior
Batching vs. looping over single sensor simulations should produce the same result.
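A check along these lines can compare the two code paths directly. It is a minimal sketch under assumptions: `toy_simulation` is a hypothetical stand-in for the tutorial's per-sensor simulation (it is not jwave's API), and `compare_batched_vs_looped` is a hypothetical helper that runs the same function once through `jax.vmap` and once in a Python loop, then reports NaNs and agreement.

```python
import jax
import jax.numpy as jnp


def toy_simulation(src_idx, n_grid=64, n_steps=50):
    """Hypothetical stand-in for a single-sensor simulation (not jwave's API).

    Places a unit impulse at `src_idx` and runs a few explicit
    diffusion-like update steps on a periodic 1D grid.
    """
    u = jnp.zeros(n_grid).at[src_idx].set(1.0)

    def step(u, _):
        lap = jnp.roll(u, 1) - 2.0 * u + jnp.roll(u, -1)
        return u + 0.1 * lap, None

    u, _ = jax.lax.scan(step, u, None, length=n_steps)
    return u


def compare_batched_vs_looped(fn, indices, atol=1e-5):
    """Run `fn` once with jax.vmap over `indices` and once in a Python loop."""
    batched = jax.jit(jax.vmap(fn))(indices)
    single = jax.jit(fn)
    looped = jnp.stack([single(i) for i in indices])
    has_nan = bool(jnp.isnan(batched).any())
    matches = bool(jnp.allclose(batched, looped, atol=atol))
    return has_nan, matches


indices = jnp.arange(0, 64, 8)
has_nan, matches = compare_batched_vs_looped(toy_simulation, indices)
print(f"NaNs in batched output: {has_nan}; batched matches looped: {matches}")
```

Swapping the real per-sensor simulation in for `toy_simulation` would give the comparison described in the expected behavior above; the toy version only demonstrates the shape of the check.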

Desktop (please complete the following information):

  • Running on Google Colab with GPU.
talfanoutlook added the labels bug (Something isn't working) and help wanted (Extra attention is needed) on May 2, 2023
@astanziola
Member

Thanks for pointing it out @talfanoutlook .
This repo needs a bit of love that I haven't been able to give it recently, so there's a good chance the problem is outdated jax dependencies. I will look into it!

astanziola self-assigned this on May 3, 2023
@talfanoutlook
Author

Great, let me know. I'll try to take a look myself as well.

@astanziola
Member

I have a feeling that the library behind jwave (jaxdf) needs to be upgraded to the most recent jax versions, so maybe give me a bit to check that out before getting lost in the code :)

@talfanoutlook
Author

talfanoutlook commented May 10, 2023 via email

@astanziola
Member

Hi @talfanoutlook, I might need another couple of days on this, sorry.
I'm first fixing a few issues in jaxdf that arose with the new jax version, and then I'll jump on it.
If you find anything in the meantime, feel free to post an update here or open a PR!

@astanziola
Member

astanziola commented May 23, 2023

@talfanoutlook Just to let you know that I haven't forgotten about this, but the problem is weirder than I expected.

What is happening is that if I run the exact same (jitted!) function twice, in this case single_source_simulation, with the same arguments, I sometimes get two different results. It seems that something stored in memory is affecting the results, which is definitely unexpected from a functional framework like jax, so it's not immediately obvious to me what the issue is.
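One way to test for this symptom is to call a jitted function several times with identical arguments and compare the outputs bit for bit. The sketch below assumes a hypothetical helper `check_repeatability` and a toy function `toy_fn`; neither comes from jwave, and `single_source_simulation` would take the place of `toy_fn` in practice.

```python
import jax
import jax.numpy as jnp
import numpy as np


def check_repeatability(fn, *args, n_runs=5):
    """Hypothetical helper: call jit(fn) repeatedly with the same arguments
    and report whether every run is bitwise identical to the first."""
    jitted = jax.jit(fn)
    reference = np.asarray(jitted(*args))  # np.asarray waits for the result
    for _ in range(n_runs - 1):
        out = np.asarray(jitted(*args))
        # equal_nan=True so NaNs in the same positions count as matching
        if not np.array_equal(reference, out, equal_nan=True):
            return False
    return True


def toy_fn(x):
    # Toy computation standing in for the real simulation
    return jnp.cumsum(jnp.sin(x) * 2.0)


x = jnp.linspace(0.0, 1.0, 1000)
print("repeatable:", check_repeatability(toy_fn, x))
```

A bitwise mismatch on GPU is not by itself proof of a bug, since some GPU kernels (for example scatter-adds implemented with atomics) are not guaranteed to be bitwise reproducible; NaNs appearing in only some runs, as described above, are a much stronger signal.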

@astanziola
Member

astanziola commented May 23, 2023

Interestingly enough, if one disables GPUs and runs the code on CPU, the issue does not appear.
That explains why the tests that run against k-Wave do not fail.
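To compare backends without changing the runtime, the same jitted computation can be pinned to the CPU with `jax.default_device` and compared with the default backend (the GPU on a Colab GPU runtime). This is a sketch with a hypothetical `toy_fn` in place of the real simulation; alternatively, setting the environment variable `JAX_PLATFORMS=cpu` before importing jax disables the GPU entirely.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_fn(x):
    # Hypothetical stand-in for the real simulation
    return jnp.cumsum(jnp.sin(x) * 2.0)


x = np.linspace(0.0, 1.0, 1000, dtype=np.float32)

# Run on whatever backend JAX selected by default
default_out = np.asarray(jax.jit(toy_fn)(x))

# Re-run the same computation with the CPU as the default device
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    cpu_out = np.asarray(jax.jit(toy_fn)(x))

diff = float(np.max(np.abs(default_out - cpu_out)))
print("default backend:", jax.default_backend())
print("max |default - cpu|:", diff)
```

Small floating-point differences between GPU and CPU are expected; NaNs or large discrepancies on one backend only point to a backend-specific problem like the one reported here.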

I'll try to set up a minimal reproducible example (MRE) and raise it with the jax team.

@talfanoutlook
Copy link
Author

talfanoutlook commented May 23, 2023 via email

@astanziola
Copy link
Member

astanziola commented May 25, 2023

Made an MRE and raised it with the jax team: jax-ml/jax#16129
For reference, the pure jax MRE is here https://github.com/ucl-bug/jwave/blob/gpu-error-mre/error.py

astanziola removed the help wanted (Extra attention is needed) label on May 25, 2023
@talfanoutlook
Author

talfanoutlook commented May 26, 2023 via email

@talfanoutlook
Author

talfanoutlook commented May 26, 2023 via email

@astanziola
Copy link
Member

astanziola commented May 26, 2023 via email

@astanziola
Member

It seems this is fixed in jax 0.4.11 (and I have absolutely no idea why that's the case). Do you mind giving it a try with the latest jax version, @talfanoutlook?

For some reason, I do get an OOM error on my RTX 4000, so you may want to try a beefier GPU. The fix for the OOM error is trivial: pass the simulation parameters to the solver directly instead of having them calculated internally.
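The general JAX pattern behind "pass the parameters in directly" can be sketched as follows: compute the source-independent setup once, outside the batched function, and share it across the batch with `in_axes=None`. This is an illustration under assumptions only; `compute_params` and `simulate` are hypothetical names and do not reproduce jwave's actual solver signature, and the sketch makes no claim about how much memory it saves in the real tutorial.

```python
import jax
import jax.numpy as jnp


def compute_params(sound_speed):
    """Hypothetical setup step that depends only on the medium, not the source.
    Here it builds a toy spectral smoothing filter."""
    k = jnp.fft.fftfreq(sound_speed.shape[0])
    return {"filter": jnp.exp(-(k ** 2) * 10.0)}


def simulate(src_idx, params, n_steps=20):
    """Hypothetical solver that receives precomputed parameters explicitly."""
    u = jnp.zeros(params["filter"].shape[0]).at[src_idx].set(1.0)

    def step(u, _):
        u = jnp.real(jnp.fft.ifft(jnp.fft.fft(u) * params["filter"]))
        return u, None

    u, _ = jax.lax.scan(step, u, None, length=n_steps)
    return u


sound_speed = jnp.full((128,), 1500.0)
params = compute_params(sound_speed)  # computed once, outside vmap

# in_axes=(0, None): batch over source indices, share params across the batch
batched = jax.jit(jax.vmap(simulate, in_axes=(0, None)))
out = batched(jnp.arange(0, 128, 16), params)
print(out.shape)  # (8, 128)
```

Passing `params` as an explicit argument with `in_axes=None` keeps a single unbatched copy of the shared data and makes it clear to the reader which inputs vary per source.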

I will update the example notebook accordingly soon, while I move the time-stepping method to diffrax :)

@astanziola
Copy link
Member

I will close it for now, as in my tests this seems to be consistently fixed. Feel free to reopen it or open a new one if the problem still arises!

2 participants