Memory reduction fixes for MCMC sampler #1802

Merged 17 commits on Jun 24, 2024


andrewdipper
Contributor

Here are some proposed changes to reduce the GPU memory footprint of MCMC sampling. I was hitting OOM earlier than expected, so I tracked down some of the causes:

  • The collection was updated inside jit blocks without donate_argnums (at least when using a progress bar), which added a copy for each update. With donate_argnums, sampling time with a progress bar is also much closer to sampling time without one.
  • Unraveling the flattened collection array at the end requires a copy. I modified the collection to keep the pytree structure during sampling to avoid the duplication.
  • Deferring materialization of _states_flat. Since JAX doesn't use views for reshaping, materializing it duplicates all the samples, even though _states and _states_flat might have the same memory layout. I made a rough change so that _states_flat is not materialized until requested, but I'm not sure that's an ideal solution.
  • The cached_by change allows saving multiple functions within fori_collect. Otherwise they would collide and the wrong function would be used.
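
To make the first two points concrete, here is a minimal sketch of the idea, not the actual NumPyro code. The helper `update_collection` and the sample structure are hypothetical names chosen for illustration. It keeps the collection as a pytree with one preallocated buffer per leaf and donates the collection argument to the jitted update so the backend may reuse its buffers.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper (not NumPyro's implementation): write one sample into
# a preallocated collection. The collection is a pytree with the same
# structure as a sample, plus a leading axis of length num_samples.
@partial(jax.jit, donate_argnums=(0,))
def update_collection(collection, sample, idx):
    # Donating argument 0 allows XLA to reuse the input buffers for the
    # output where the backend supports it, rather than allocating a copy.
    return jax.tree_util.tree_map(
        lambda buf, x: buf.at[idx].set(x), collection, sample
    )


num_samples = 4
sample_template = {"mu": jnp.zeros(()), "theta": jnp.zeros((3,))}

# One buffer per leaf, so no flatten/unravel step is needed at the end.
collection = jax.tree_util.tree_map(
    lambda x: jnp.zeros((num_samples,) + x.shape, x.dtype), sample_template
)

for i in range(num_samples):
    sample = {"mu": jnp.asarray(float(i)), "theta": jnp.full((3,), float(i))}
    # Rebind the name: a donated input must not be used after the call.
    collection = update_collection(collection, sample, i)
```

On backends that cannot use a donated buffer, JAX may emit a warning and fall back to a normal copy, so the memory saving depends on the platform.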

Below are some rough peak memory usage / runtime numbers with and without the changes, for two models (separated by //), just to give an initial view.

baseline / no progress bar: 3592MB / 63sec
baseline / with progress bar: 5126MB / 109sec // 14412MB / 515sec
new / no progress bar: 2052MB / 63sec
new / with progress bar: 2052MB / 70sec // 5182MB / 320sec
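
For a quick sense of scale, the relative savings in the progress-bar runs can be worked out from the figures above. This is only arithmetic on the reported numbers; the labels "model 1" and "model 2" are placeholders for the two models on either side of the //.

```python
# (peak memory in MB, runtime in seconds) with a progress bar, copied from
# the figures above. "model 1" / "model 2" are placeholder labels.
baseline = {"model 1": (5126, 109), "model 2": (14412, 515)}
new = {"model 1": (2052, 70), "model 2": (5182, 320)}


def reduction(before, after):
    """Percentage reduction going from `before` to `after`."""
    return 100.0 * (before - after) / before


results = {}
for model in baseline:
    mem_before, time_before = baseline[model]
    mem_after, time_after = new[model]
    results[model] = (
        round(reduction(mem_before, mem_after), 1),
        round(reduction(time_before, time_after), 1),
    )
    print(f"{model}: memory -{results[model][0]}%, runtime -{results[model][1]}%")
# model 1: memory -60.0%, runtime -35.8%
# model 2: memory -64.0%, runtime -37.9%
```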

Let me know if you think any of these changes would be useful or if any modifications are needed.

@andrewdipper
Contributor Author

Is there any way to get the output / insight into the failing test?
FAILED test/test_examples.py::test_cpu[stochastic_volatility.py --num-samples 100 --num-warmup 100] - subprocess.CalledProcessError: Command '['/opt/hostedtoolcache/Python/3.9.19/x64/bin/python', '/home/runner/work/numpyro/numpyro/examples/stochastic_volatility.py', '--num-samples', '100', '--num-warmup', '100']' returned non-zero exit status 1.

For the latest update I've run the tests both locally and in a Kaggle T4x2 session, and they passed. The prodlda test sometimes has trouble fetching the dataset, but I've had no problems with the stochastic_volatility test.

@andrewdipper changed the title from "[WIP] memory reduction fixes for MCMC sampler" to "Memory reduction fixes for MCMC sampler" on May 28, 2024
@andrewdipper
Contributor Author

I accidentally updated this with a merge instead of a rebase. Let me know if that's an issue. Either way, it's back up to date.

Member

@fehiepsi fehiepsi left a comment


Thanks @andrewdipper. This PR might have a huge impact on NumPyro's performance, so we need to review the changes carefully. Let me take another pass this weekend.

Review comments on numpyro/util.py: 3 threads (all outdated, resolved)
@andrewdipper
Contributor Author

For sure, makes sense

Member

@fehiepsi fehiepsi left a comment


Thanks @andrewdipper, impressive work!

Sorry for the slow review - took me a while to fully understand the changes.

Review comments on numpyro/util.py: 5 threads (4 outdated, all resolved)
@fehiepsi
Member

Happy to merge! 💯

@fehiepsi merged commit 616a811 into pyro-ppl:master on Jun 24, 2024
4 checks passed
@andrewdipper deleted the memfix branch on June 24, 2024 15:41