NaNs when batching simulations in tutorial example. #168
Comments
Thanks for pointing it out, @talfanoutlook.
Great, let me know. I'll try to take a look myself too.
I have a feeling that the library behind jwave (jaxdf) needs to be upgraded to the most recent jax versions, so maybe give me a bit to check that out before getting lost in the code :)
Hey Antonio, any suggestions on where to get started? Happy to dig in if you don’t have the time right now.
Talfan
Hi @talfanoutlook, I might need another couple of days on this, sorry.
@talfanoutlook Just to let you know that I haven't forgotten about this, but the problem is weirder than I expected. What is happening is that if I run the same exact (jitted!) function twice, in this case …
Interestingly enough, if one disables GPUs and runs the code on CPU, the issue does not appear. That explains why the tests that run against k-Wave do not fail. I'll try to set up an MRE and raise it with the jax team.
That’s great Antonio, thanks for looking into this.
Excited to start using jwave :)
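Antonio's observation (the same jitted function giving different results between runs, and only on GPU) suggests a quick repeatability check to run before writing an MRE. The sketch below is a hypothetical helper, not part of jwave or JAX: `check_repeatability` calls a function several times on the same inputs and reports whether every output is finite and whether all runs agree bit for bit. With JAX, `fn` would be the jitted simulation; running the script once as is and once with the environment variable `JAX_PLATFORMS=cpu` (set before `jax` is first imported) gives the GPU vs CPU comparison.

```python
import numpy as np


def check_repeatability(fn, *args, n_runs=2):
    """Call fn(*args) n_runs times and compare the outputs.

    Hypothetical debugging helper: fn can be any function returning an
    array-like, e.g. a jax.jit-compiled simulation.
    """
    if n_runs < 2:
        raise ValueError("need at least two runs to compare")
    # np.asarray copies device arrays back to host memory
    outputs = [np.asarray(fn(*args)) for _ in range(n_runs)]
    first = outputs[0]
    return {
        # True only if no run produced NaN or inf
        "all_finite": all(bool(np.isfinite(o).all()) for o in outputs),
        # Bitwise agreement between runs (NaNs in the same place count as equal)
        "identical": all(
            np.array_equal(first, o, equal_nan=True) for o in outputs[1:]
        ),
    }
```

A deterministic, stable function should report `{"all_finite": True, "identical": True}`; the bug described here would show up as `identical` being `False` on GPU.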
Made an MRE and raised it with the jax team: jax-ml/jax#16129
That's awesome Antonio, I'll follow the link!
Very weird bug indeed. Is it specifically in the sinc function as far as you can tell?
Something related to it, but it would be quite weird for it to be just the sinc.
A quick way to avoid this problem for now is to get rid of the k-space correction (by overriding the `params` of `simulate_wave_propagation` and manually setting the k-space operator to all ones). This should not matter too much for simulations on the order of tens of wavelengths (see the paragraphs around eqs. (2.15) and (2.16) of the k-Wave manual: www.k-wave.org/manual/k-wave_user_manual_1.1.pdf).
I will add an example about how to disable it in the next release.
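For reference, the k-Wave manual writes the k-space correction as a sinc of the wavenumber, κ(k) = sinc(c_ref |k| Δt / 2) with sinc(x) = sin(x)/x, where c_ref is a reference sound speed and Δt the time step. The sketch below evaluates that operator on a periodic grid with NumPy and shows the "all ones" replacement described above. It is illustrative only: both function names are hypothetical, and whether jwave builds its internal operator in exactly this form, and under which entry of `params` it is stored, is an assumption not confirmed in this thread.

```python
import numpy as np


def kspace_correction(shape, dx, c_ref, dt):
    """k-space correction kappa = sinc(c_ref * |k| * dt / 2) on a periodic grid.

    Hypothetical helper using the unnormalised sinc, sin(x) / x, as in the
    k-Wave manual.
    """
    # Angular wavenumbers along each axis, in FFT ordering
    axes = [2 * np.pi * np.fft.fftfreq(n, d=dx) for n in shape]
    grids = np.meshgrid(*axes, indexing="ij")
    k_mag = np.sqrt(sum(k ** 2 for k in grids))
    # np.sinc is the normalised sinc, sin(pi x) / (pi x), so divide by pi
    return np.sinc(c_ref * k_mag * dt / (2 * np.pi))


def disabled_kspace_correction(shape):
    """Workaround from the thread: replace the operator with all ones."""
    return np.ones(shape)
```

κ equals 1 at k = 0 and only drops towards the highest wavenumbers: with Δt = 0.3 Δx / c_ref on a 2D grid, its minimum (at the diagonal Nyquist corner) is about 0.93, which is why setting it to ones mostly affects the shortest resolved wavelengths.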
Seems like this is solved in jax 0.4.11 (and I have absolutely no idea why that's the case). Do you mind giving it a try with the latest jax version, @talfanoutlook? For some reason, I do get an OOM error on my RTX 4000, so you may want to try with a beefier GPU. The fix for the OOM error is trivial: pass the simulation parameters to the solver directly instead of having them calculated internally. I will update the example notebook accordingly soon, while I move the time-stepping method to …
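The OOM fix above follows a general pattern: compute the solver parameters once, outside the batched call, and pass them in, rather than having each simulation compute them internally. The thread does not say why the internal computation exhausts memory, so the sketch below only illustrates the pattern, using a hypothetical toy solver in NumPy (`make_params`, `simulate` and `simulate_batch` are illustrative names, not jwave functions). In JAX the loop would become `jax.vmap(simulate, in_axes=(None, 0))(params, sources)`, where `None` means `params` is passed unchanged to every call instead of being sliced along a batch axis.

```python
import numpy as np


def make_params(shape, dx):
    """Hypothetical precomputation step: wavenumber magnitudes for the grid."""
    axes = [2 * np.pi * np.fft.fftfreq(n, d=dx) for n in shape]
    grids = np.meshgrid(*axes, indexing="ij")
    return {"k_mag": np.sqrt(sum(k ** 2 for k in grids)), "dx": dx}


def simulate(params, source):
    """Hypothetical toy 'solver': one spectral smoothing step using params."""
    filt = np.exp(-(params["k_mag"] * params["dx"]) ** 2)
    return np.real(np.fft.ifftn(filt * np.fft.fftn(source)))


def simulate_batch(params, sources):
    """Run simulate for each source, reusing the same precomputed params.

    JAX equivalent (assumption, not run here):
        jax.vmap(simulate, in_axes=(None, 0))(params, sources)
    """
    return np.stack([simulate(params, s) for s in sources])
```

The parameters are built once with `params = make_params(shape, dx)` and then shared by every call, so the per-simulation work is only the time stepping itself.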
I will close it for now, as in my tests this seems to be consistently fixed. Feel free to reopen it or open a new one if the problem still arises!
Describe the bug
Hi guys, I am trying to reproduce the Full Wave Inversion tutorial example. When batching the single simulation across sensor indices with `jax.vmap`, I am getting NaNs, even though running those sensor simulations individually is stable. I'm running on Colab with a GPU.
To Reproduce
Colab to reproduce:
https://colab.research.google.com/drive/1U3ttIMkn4lZnlu4SMNf4Y0IcuB7Sv9MH?usp=sharing
Expected behavior
Batching vs. looping over single sensor simulations should produce the same result.
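The expected behavior amounts to a simple check: run the single-sensor simulation in a loop, run the batched version once, and compare. A minimal sketch that assumes nothing about jwave's API: `compare_batched_vs_looped` is a hypothetical helper that stacks the looped outputs, calls the batched function once, and reports whether either result contains NaNs and whether the two agree within a tolerance. In the tutorial setting, `single_fn` would be the per-sensor simulation and `batched_fn` would be `jax.vmap(single_fn)`; the default tolerances are illustrative and float32 GPU runs may need looser ones.

```python
import numpy as np


def compare_batched_vs_looped(single_fn, batched_fn, inputs, rtol=1e-5, atol=1e-8):
    """Check that a batched function matches a Python loop over single calls.

    Hypothetical helper. With JAX, batched_fn would typically be
    jax.vmap(single_fn); here any callable mapping a stacked batch of
    inputs to a stacked batch of outputs works.
    """
    looped = np.stack([np.asarray(single_fn(x)) for x in inputs])
    batched = np.asarray(batched_fn(np.asarray(inputs)))
    return {
        "looped_has_nan": bool(np.isnan(looped).any()),
        "batched_has_nan": bool(np.isnan(batched).any()),
        # Shapes must agree before values are compared; NaNs never match
        "match": looped.shape == batched.shape
        and bool(np.allclose(looped, batched, rtol=rtol, atol=atol)),
    }
```

For the behavior reported in this issue, the helper would return `looped_has_nan: False`, `batched_has_nan: True` and `match: False`.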