Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,14 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
automatically (defaults to None).
init : str {'advi', 'advi_map', 'map', 'nuts', None}
Initialization method to use.
* advi : Run ADVI to estimate posterior mean and diagonal covariance matrix.
* advi : Run ADVI to estimate starting points and diagonal covariance
matrix. If njobs > 1 it will sample starting points from the estimated
posterior, otherwise it will use the estimated posterior mean.
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map : Use the MAP as starting point.
* nuts : Run NUTS and estimate posterior mean and covariance matrix.
* nuts : Run NUTS to estimate starting points and covariance matrix. If
njobs > 1 it will sample starting points from the estimated posterior,
otherwise it will use the estimated posterior mean.
* None : Do not initialize.
n_init : int
Number of iterations of initializer
Expand Down Expand Up @@ -142,11 +146,12 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
MultiTrace object with access to sampling values
"""
model = modelcontext(model)


if step is None and init is not None and pm.model.all_continuous(model.vars):
# By default, use NUTS sampler
pm._log.info('Auto-assigning NUTS sampler...')
start_, step = init_nuts(init=init, n_init=n_init, model=model)
start_, step = init_nuts(init=init, njobs=njobs, n_init=n_init, model=model, random_seed=random_seed)
if start is None:
start = start_
else:
Expand Down Expand Up @@ -393,7 +398,8 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_see
return {k: np.asarray(v) for k, v in ppc.items()}


def init_nuts(init='advi', n_init=500000, model=None, **kwargs):
def init_nuts(init='advi', njobs=1, n_init=500000, model=None,
random_seed=-1, **kwargs):
"""Initialize and sample from posterior of a continuous model.

This is a convenience function. NUTS convergence and sampling speed is extremely
Expand All @@ -409,6 +415,8 @@ def init_nuts(init='advi', n_init=500000, model=None, **kwargs):
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map : Use the MAP as starting point.
* nuts : Run NUTS and estimate posterior mean and covariance matrix.
njobs : int
Number of parallel jobs to start.
n_init : int
Number of iterations of initializer
If 'advi', number of iterations, if 'metropolis', number of draws.
Expand All @@ -430,23 +438,31 @@ def init_nuts(init='advi', n_init=500000, model=None, **kwargs):

pm._log.info('Initializing NUTS using {}...'.format(init))

random_seed = int(np.atleast_1d(random_seed)[0])

if init == 'advi':
v_params = pm.variational.advi(n=n_init)
start = pm.variational.sample_vp(v_params, 1, progressbar=False, hide_transformed=False)[0]
v_params = pm.variational.advi(n=n_init, random_seed=random_seed)
start = pm.variational.sample_vp(v_params, njobs, progressbar=False,
hide_transformed=False,
random_seed=random_seed)
if njobs == 1:
start = start[0]
cov = np.power(model.dict_to_array(v_params.stds), 2)
elif init == 'advi_map':
start = pm.find_MAP()
v_params = pm.variational.advi(n=n_init, start=start)
v_params = pm.variational.advi(n=n_init, start=start,
random_seed=random_seed)
cov = np.power(model.dict_to_array(v_params.stds), 2)
elif init == 'map':
start = pm.find_MAP()
cov = pm.find_hessian(point=start)

elif init == 'nuts':
init_trace = pm.sample(step=pm.NUTS(), draws=n_init)
cov = pm.trace_cov(init_trace[n_init//2:])

start = {varname: np.mean(init_trace[varname]) for varname in init_trace.varnames}
init_trace = pm.sample(step=pm.NUTS(), draws=n_init,
random_seed=random_seed)[n_init // 2:]
cov = np.atleast_1d(pm.trace_cov(init_trace))
start = np.random.choice(init_trace, njobs)
if njobs == 1:
start = start[0]
else:
raise NotImplemented('Initializer {} is not supported.'.format(init))

Expand Down