Hi!
First: Thanks for maintaining, keep up the good work :)
Description of your problem
Today I updated from v3.5 to v3.7 and our code broke :(
with model:
    trace = pm.sample(
        draws=draws,
        tune=1000,
        progressbar=False,
        nuts_kwargs=dict(target_accept=0.9)
    )
Looking at your release notes, I found that:
nuts_kwargs and step_kwargs have been deprecated in favor of using the standard kwargs to pass optional step method arguments.
So I changed our code to:
with model:
    trace = pm.sample(
        draws=draws,
        tune=1000,
        progressbar=False,
        target_accept=0.9,
    )
Which led to this Exception:
/opt/conda/lib/python3.6/site-packages/pymc3/sampling.py:406: in sample
step = assign_step_methods(model, step, step_kwargs=kwargs)
/opt/conda/lib/python3.6/site-packages/pymc3/sampling.py:155: in assign_step_methods
return instantiate_steppers(model, steps, selected_steps, step_kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
model = <pymc3.model.Model object at 0x7f0047b2a9b0>
steps = [<pymc3.step_methods.hmc.nuts.NUTS object at 0x7f00335fe3c8>, <pymc3.step_methods.metropolis.BinaryGibbsMetropolis object at 0x7f001c9fada0>]
selected_steps = defaultdict(<class 'list'>, {<class 'pymc3.step_methods.hmc.nuts.NUTS'>: [pi_outlier_logodds__, a_eta, b_eta, tau_eta_...lon, eta, epsilon, r_normal, sd, nu_log__], <class 'pymc3.step_methods.metropolis.BinaryGibbsMetropolis'>: [lambda_t]})
step_kwargs = {'target_accept': 0.9}
    def instantiate_steppers(model, steps, selected_steps, step_kwargs=None):
        if step_kwargs is None:
            step_kwargs = {}
        used_keys = set()
        for step_class, vars in selected_steps.items():
            if len(vars) == 0:
                continue
            args = step_kwargs.get(step_class.name, {})
            used_keys.add(step_class.name)
            step = step_class(vars=vars, **args)
            steps.append(step)
        unused_args = set(step_kwargs).difference(used_keys)
        if unused_args:
>           raise ValueError('Unused step method arguments: %s' % unused_args)
E           ValueError: Unused step method arguments: {'target_accept'}
/opt/conda/lib/python3.6/site-packages/pymc3/sampling.py:81: ValueError
Additional Info
I'm no pymc3 expert; to be honest, I can't even describe what our code does :)
What I see is that, in the for loop of instantiate_steppers, only the entries of step_kwargs whose key matches pymc3.step_methods.hmc.nuts.NUTS#name (nuts) get extracted. I only passed {"target_accept": 0.9}, so the key target_accept is never consumed and the ValueError about unused arguments is raised.
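To make that reading concrete, here is a minimal sketch that mimics the loop from the traceback without importing pymc3. FakeNUTS and instantiate_steppers_sketch are hypothetical stand-ins written for this illustration (including the 0.8 default), not pymc3 code, so treat it as my understanding of the behavior rather than the library's actual implementation.

```python
class FakeNUTS:
    """Hypothetical stand-in for a step method class; only `name` matters here."""
    name = "nuts"

    def __init__(self, vars, target_accept=0.8):  # default value is illustrative
        self.vars = vars
        self.target_accept = target_accept


def instantiate_steppers_sketch(selected_steps, step_kwargs=None):
    """Simplified version of the loop shown in the traceback."""
    step_kwargs = step_kwargs or {}
    steps, used_keys = [], set()
    for step_class, vars in selected_steps.items():
        if not vars:
            continue
        # Only kwargs nested under the step method's name are picked up.
        args = step_kwargs.get(step_class.name, {})
        used_keys.add(step_class.name)
        steps.append(step_class(vars=vars, **args))
    unused_args = set(step_kwargs).difference(used_keys)
    if unused_args:
        raise ValueError("Unused step method arguments: %s" % unused_args)
    return steps


selected = {FakeNUTS: ["a_eta", "b_eta"]}

# Nested under the step name: the argument reaches the stepper.
nested = instantiate_steppers_sketch(selected, {"nuts": {"target_accept": 0.9}})

# Flat keyword: no step class is named "target_accept", so it counts as unused.
try:
    instantiate_steppers_sketch(selected, {"target_accept": 0.9})
    flat_error = None
except ValueError as err:
    flat_error = str(err)
```

In this sketch the nested form succeeds and the flat form fails with the same message as in the traceback, which is consistent with what I observe against 3.7.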
When I change the call to pm.sample to the following, our tests pass again:
with model:
    trace = pm.sample(
        draws=draws,
        tune=1000,
        progressbar=False,
        nuts={"target_accept": 0.9},
    )
#3327 introduced the changes that lead to the described behavior. Is this a bug in pymc3, or did we use the library in the wrong way?
(In a test, commit 990d248 added an assertion around a call that is exactly what I thought I should do to fix that error.)
Versions and main components
- PyMC3 Version: 3.7
- Theano Version: Theano==1.0.4
- Python Version: Python 3.6.0 :: Continuum Analytics, Inc.
- Operating system: alpine:3.6
- How did you install PyMC3: conda