Skip to content

Commit

Permalink
Pass in split RNGs to multiple chains (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored and fehiepsi committed Jun 14, 2019
1 parent 31ea972 commit 7f0605f
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def run_inference(model, args, rng, X, Y, D_H):
init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y, D_H)
start = time.time()
samples = mcmc(args.num_warmup, args.num_samples, init_params, num_chains=args.num_chains,
sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn, progbar=None)
sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn)
print('\nMCMC elapsed time:', time.time() - start)
return samples

Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def run_inference(model, args, rng, X, Y):
init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y)
start = time.time()
samples = mcmc(args.num_warmup, args.num_samples, init_params, num_chains=args.num_chains,
sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn, progbar=None)
sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn)
print('\nMCMC elapsed time:', time.time() - start)
return samples

Expand Down
1 change: 1 addition & 0 deletions numpyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def mcmc(num_warmup, num_samples, init_params, num_chains=1, sampler='hmc',
samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
else:
def single_chain_mcmc(rng, init_params):
sampler_kwargs['rng'] = rng
hmc_state = init_kernel(init_params, num_warmup, run_warmup=False, **sampler_kwargs)
samples = fori_collect(num_warmup, num_warmup + num_samples, sample_kernel, hmc_state,
transform=lambda x: constrain_fn(x.z),
Expand Down

0 comments on commit 7f0605f

Please sign in to comment.