Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for pickling an MCMC object with HMCGibbs (and MixedHMC) samplers and parallel chains #1746

Merged
merged 2 commits into from
Feb 26, 2024

Conversation

msaintja
Copy link
Contributor

This PR addresses issue #1742.

When using MixedHMC, and sampling with option chain_method="parallel" with more than one chain, using dill/pickle to save the MCMC object resulted in a ConcretizationTypeError, as some attributes may contain JAX tracers.
From limited testing, these would be the _support_sizes_flat attribute of the MCMC object, as well as some attributes of _prototype_trace[...]["fn"].

Setting those to None in the state copy passed to dill would allow to circumvent the JAX error.
I'm not sure how critical it is to maintain those attributes in a pickled save, however a simple dump-load-run seems to work fine with the example code presented in #1742:

with open("chains_test.pkl", 'wb') as f:
    dill.dump(mcmc, f)
with open("chains_test.pkl", 'rb') as f:
    mcmc = dill.load(f)
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(key, data)

@msaintja msaintja changed the title Fix for pickling an MCMC object with MixedHMC sampler and parallel chains Fix for pickling an MCMC object with HMCGibbs (and MixedHMC) samplers and parallel chains Feb 25, 2024
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick fix and contributing to numpyro, @msaintja!!

@fehiepsi fehiepsi merged commit a967f69 into pyro-ppl:master Feb 26, 2024
4 checks passed
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request Feb 27, 2024
… and parallel chains (pyro-ppl#1746)

* Remove jax tracers from MixedHMC.__getstate__()

* Remove jax tracers from HMCGibbs serialization
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
… and parallel chains (pyro-ppl#1746)

* Remove jax tracers from MixedHMC.__getstate__()

* Remove jax tracers from HMCGibbs serialization
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants