Skip to content

Commit

Permalink
Respect init params if provided to mcmc.run (#1547)
Browse files Browse the repository at this point in the history
* respect init params if provided

* fix lint
  • Loading branch information
fehiepsi committed Mar 8, 2023
1 parent 9423ebc commit 94121ac
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
9 changes: 8 additions & 1 deletion numpyro/infer/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,12 @@ def __init__(

def _init_state(self, rng_key, model_args, model_kwargs, init_params):
if self._model is not None:
init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
(
new_init_params,
potential_fn,
postprocess_fn,
model_trace,
) = initialize_model(
rng_key,
self._model,
dynamic_args=True,
Expand All @@ -658,6 +663,8 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params):
model_kwargs=model_kwargs,
forward_mode_differentiation=self._forward_mode_differentiation,
)
if init_params is None:
init_params = new_init_params
if self._init_fn is None:
self._init_fn, self._sample_fn = hmc(
potential_fn_gen=potential_fn,
Expand Down
8 changes: 6 additions & 2 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,9 @@ def warmup(
:param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
to `False`.
:param init_params: Initial parameters to begin sampling. The type must be consistent
with the input type to `potential_fn`.
with the input type to `potential_fn` provided to the kernel. If the kernel is
instantiated by a numpyro model, the initial parameters here correspond to latent
values in unconstrained space.
:param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
method. These are typically the keyword arguments needed by the `model`.
"""
Expand Down Expand Up @@ -546,7 +548,9 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
`"adapt_state.step_size"` can be used to collect step sizes at each step.
:type extra_fields: tuple or list of str
:param init_params: Initial parameters to begin sampling. The type must be consistent
with the input type to `potential_fn`.
with the input type to `potential_fn` provided to the kernel. If the kernel is
instantiated by a numpyro model, the initial parameters here correspond to latent
values in unconstrained space.
:param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
method. These are typically the keyword arguments needed by the `model`.
Expand Down

0 comments on commit 94121ac

Please sign in to comment.