Skip to content

Commit

Permalink
Allow pickled mcmc object to run post warmup phase (#1558)
Browse files Browse the repository at this point in the history
* allow to pickle mcmc objective

* black
  • Loading branch information
fehiepsi committed Mar 20, 2023
1 parent aa0cb24 commit f66ba4f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 4 additions & 2 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,14 +405,16 @@ def _get_cached_init_state(self, rng_key, args, kwargs):

def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
rng_key, init_state, init_params = init
if init_state is None:
init_state = self.sampler.init(
# Check if _sample_fn is None, then we need to initialize the sampler.
if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
new_init_state = self.sampler.init(
rng_key,
self.num_warmup,
init_params,
model_args=args,
model_kwargs=kwargs,
)
init_state = new_init_state if init_state is None else init_state
sample_fn, postprocess_fn = self._get_cached_fns()
diagnostics = (
lambda x: self.sampler.get_diagnostics_str(x[0])
Expand Down
8 changes: 8 additions & 0 deletions test/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,11 @@ def test_pickle_singleton_constraint():
roundtripped_gt_cstr = pickle.loads(pickle.dumps(gt_cstr))
assert type(roundtripped_gt_cstr) is type(gt_cstr)
assert gt_cstr.lower_bound == roundtripped_gt_cstr.lower_bound


def test_mcmc_pickle_post_warmup():
mcmc = MCMC(NUTS(normal_model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0))
pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
pickled_mcmc.post_warmup_state = pickled_mcmc.last_state
pickled_mcmc.run(random.PRNGKey(1))

0 comments on commit f66ba4f

Please sign in to comment.