Skip to content

Commit

Permalink
Fix for pickling an MCMC object with HMCGibbs (and MixedHMC) samplers…
Browse files Browse the repository at this point in the history
… and parallel chains (#1746)

* Remove jax tracers from MixedHMC.__getstate__()

* Remove jax tracers from HMCGibbs serialization
  • Loading branch information
msaintja committed Feb 26, 2024
1 parent c4ca3d8 commit a967f69
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
5 changes: 5 additions & 0 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def potential_fn(z_gibbs, z_hmc):

return HMCGibbsState(z, hmc_state, rng_key)

def __getstate__(self):
state = self.__dict__.copy()
state["_prototype_trace"] = None
return state


def _discrete_gibbs_proposal_body_fn(
z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
Expand Down
2 changes: 2 additions & 0 deletions numpyro/infer/mixed_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,4 +307,6 @@ def body_fn(i, vals):
def __getstate__(self):
state = self.__dict__.copy()
state["_wa_update"] = None
state["_prototype_trace"] = None
state["_support_sizes_flat"] = None
return state

0 comments on commit a967f69

Please sign in to comment.