Skip to content

Commit

Permalink
Change parameters initialization to Uniform(-2, 2) (#162)
Browse files Browse the repository at this point in the history
* Change parameters initialization to Uniform(-2, 2)

* address comments

* fix test

* remove unused import

* revert

* increase number of samples

* fix variable name
  • Loading branch information
neerajprad authored and fehiepsi committed May 21, 2019
1 parent 6a5feb3 commit ae9ef67
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def summary(samples, prob=0.89):
header_format = '{:>20} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}'
columns = ['', 'mean', 'sd', '{:.1f}%'.format(50 * (1 - prob)),
'{:.1f}%'.format(50 * (1 + prob)), 'n_eff', 'Rhat']
print(header_format.format(*columns))
print('\n', header_format.format(*columns))

# FIXME: maybe allow a `digits` arg to set how many floatting points are needed?
row_format = '{:>20} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f}'
Expand Down
19 changes: 17 additions & 2 deletions numpyro/hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def constrain_fn(inv_transforms, params, invert=False):
for k, v in params.items()}


def initialize_model(rng, model, *model_args, **model_kwargs):
def initialize_model(rng, model, *model_args, init_strategy='uniform', **model_kwargs):
"""
Given a model with Pyro primitives, returns a function which, given
unconstrained parameters, evaluates the potential energy (negative
Expand All @@ -579,6 +579,11 @@ def initialize_model(rng, model, *model_args, **model_kwargs):
sample from the prior.
:param model: Python callable containing Pyro primitives.
:param `*model_args`: args provided to the model.
:param str init_strategy: initialization strategy - `uniform`
initializes the unconstrained parameters by drawing from
a `Uniform(-2, 2)` distribution (as used by Stan), whereas
`prior` initializes the parameters by sampling from the prior
for each of the sample sites.
:param `**model_kwargs`: kwargs provided to the model.
:return: tuple of (`init_params`, `potential_fn`, `constrain_fn`)
`init_params` are values from the prior used to initiate MCMC.
Expand All @@ -590,6 +595,16 @@ def initialize_model(rng, model, *model_args, **model_kwargs):
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
sample_sites = {k: v for k, v in model_trace.items() if v['type'] == 'sample' and not v['is_observed']}
inv_transforms = {k: biject_to(v['fn'].support) for k, v in sample_sites.items()}
init_params = constrain_fn(inv_transforms, {k: v['value'] for k, v in sample_sites.items()}, invert=True)
prior_params = constrain_fn(inv_transforms,
{k: v['value'] for k, v in sample_sites.items()}, invert=True)
if init_strategy == 'uniform':
init_params = {}
for k, v in prior_params.items():
rng, = random.split(rng, 1)
init_params[k] = random.uniform(rng, shape=np.shape(v), minval=-2, maxval=2)
elif init_strategy == 'prior':
init_params = prior_params
else:
raise ValueError('initialize={} is not a valid initialization strategy.'.format(init_strategy))
return init_params, potential_energy(model, model_args, model_kwargs, inv_transforms), \
jax.partial(constrain_fn, inv_transforms)
6 changes: 3 additions & 3 deletions test/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def model(data):

def test_change_point():
# Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
warmup_steps, num_samples = 300, 1000
warmup_steps, num_samples = 500, 3000

def model(data):
alpha = 1 / np.mean(data)
Expand All @@ -126,7 +126,8 @@ def model(data):
12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
5, 14, 13, 22,
])
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, count_data)
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, count_data,
init_strategy='prior')
init_kernel, sample_kernel = hmc(potential_fn)
hmc_state = init_kernel(init_params, num_warmup=warmup_steps)
hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
Expand All @@ -135,7 +136,6 @@ def model(data):
tau_values, counts = onp.unique(tau_posterior, return_counts=True)
mode_ind = np.argmax(counts)
mode = tau_values[mode_ind]
assert max(tau_values) == 44
assert mode == 44


Expand Down

0 comments on commit ae9ef67

Please sign in to comment.