Skip to content

Commit

Permalink
make HMCState pickable (#161)
Browse files Browse the repository at this point in the history
* make HMCState pickable

* remove laxtuple import
  • Loading branch information
fehiepsi authored and neerajprad committed May 20, 2019
1 parent 1274b6c commit 6a5feb3
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions numpyro/mcmc.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
import math
import os
from collections import namedtuple

import tqdm

import jax.numpy as np
from jax import jit, partial, random
from jax.flatten_util import ravel_pytree
from jax.random import PRNGKey
from jax.tree_util import register_pytree_node

import numpyro.distributions as dist
from numpyro.hmc_util import IntegratorState, build_tree, find_reasonable_step_size, velocity_verlet, warmup_adapter
from numpyro.util import cond, fori_loop, laxtuple
from numpyro.util import cond, fori_loop

HMCState = laxtuple('HMCState', ['z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
'step_size', 'inverse_mass_matrix', 'rng'])
HMCState = namedtuple('HMCState', ['z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
'step_size', 'inverse_mass_matrix', 'rng'])


register_pytree_node(
HMCState,
lambda xs: (tuple(xs), None),
lambda _, xs: HMCState(*xs)
)


HMCState.update = HMCState._replace


def _get_num_steps(step_size, trajectory_length):
Expand Down

0 comments on commit 6a5feb3

Please sign in to comment.