diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 6c80b49472..85d8e95c93 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -142,6 +142,9 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None, MultiTrace object with access to sampling values """ model = modelcontext(model) + + if init is not None: + init = init.lower() if step is None and init is not None and pm.model.all_continuous(model.vars): # By default, use NUTS sampler