-
Notifications
You must be signed in to change notification settings - Fork 244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Memory reduction fixes for MCMC sampler #1802
Conversation
Is there any way to get the output / insight into the failing test? For the latest update I've run the tests both locally and on a kaggle T4x2 session and they passed. The prolda test sometimes has issues with getting the dataset but I've had no problems with the stochastic_volatility test. |
I accidentally updated this with a merge instead of rebasing - let me know if that's an issue. Anyhow it's back up to date |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @andrewdipper. This PR might have a huge impact on the numpyro performance so we need to review the change carefully. Let me take another pass this weekend.
For sure, makes sense |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @andrewdipper, impressive work!
Sorry for the slow review - took me a while to fully understand the changes.
Happy to merge! 💯 |
Here are some proposed changes for reducing GPU memory footprint for MCMC sampling. I was hitting OOM earlier than expected so tracked some of it down:
cached_by
change was to allow saving multiple functions withinfori_collect
. Otherwise they'd collide and use the wrong function.Below are some rough numbers for peak memory usage / runtime with and without the changes for two models (split by //) just to give an initial view.
baseline / no progress bar: 3592MB / 63sec
baseline / with progress bar: 5126MB/ 109sec // 14412MB / 515sec
new / no progress bar: 2052MB / 63sec
new / with progress bar: 2052MB / 70sec // 5182MB / 320sec
Let me know if you think any of the changes would be useful / any modifications are needed