130 changes: 130 additions & 0 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
HamiltonianMC,
EllipticalSlice,
DEMetropolis,
DEMetropolisZ,
)
from pymc3.theanof import floatX
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical, Beta, HalfNormal
Expand Down Expand Up @@ -778,6 +779,135 @@ def test_parallelized_chains_are_random(self):
pass


class TestDEMetropolisZ:
def test_tuning_lambda_sequential(self):
with Model() as pmodel:
Normal('n', 0, 2, shape=(3,))
trace = sample(
tune=1000,
draws=500,
step=DEMetropolisZ(tune='lambda', lamb=0.92),
cores=1,
chains=3,
discard_tuned_samples=False
)
for c in range(trace.nchains):
# check that the tuned settings changed and were reset
assert trace.get_sampler_stats('lambda', chains=c)[0] == 0.92
assert trace.get_sampler_stats('lambda', chains=c)[-1] != 0.92
assert set(trace.get_sampler_stats('tune', chains=c)) == {True, False}
pass

def test_tuning_epsilon_parallel(self):
with Model() as pmodel:
Normal('n', 0, 2, shape=(3,))
trace = sample(
tune=1000,
draws=500,
step=DEMetropolisZ(tune='scaling', scaling=0.002),
cores=2,
chains=2,
discard_tuned_samples=False
)
for c in range(trace.nchains):
# check that the tuned settings changed and were reset
assert trace.get_sampler_stats('scaling', chains=c)[0] == 0.002
assert trace.get_sampler_stats('scaling', chains=c)[-1] != 0.002
assert set(trace.get_sampler_stats('tune', chains=c)) == {True, False}
pass

def test_tuning_none(self):
with Model() as pmodel:
Normal('n', 0, 2, shape=(3,))
trace = sample(
tune=1000,
draws=500,
step=DEMetropolisZ(tune=None),
cores=1,
chains=2,
discard_tuned_samples=False
)
for c in range(trace.nchains):
# check that all tunable parameters remained constant
assert len(set(trace.get_sampler_stats('lambda', chains=c))) == 1
assert len(set(trace.get_sampler_stats('scaling', chains=c))) == 1
assert set(trace.get_sampler_stats('tune', chains=c)) == {True, False}
pass

def test_tuning_reset(self):
"""Re-use of the step method instance with cores=1 must not leak tuning information between chains."""
with Model() as pmodel:
D = 3
Normal('n', 0, 2, shape=(D,))
trace = sample(
tune=1000,
draws=500,
step=DEMetropolisZ(tune='scaling', scaling=0.002),
cores=1,
chains=3,
discard_tuned_samples=False
)
for c in range(trace.nchains):
# check that the tuned settings changed and were reset
assert trace.get_sampler_stats('scaling', chains=c)[0] == 0.002
assert trace.get_sampler_stats('scaling', chains=c)[-1] != 0.002
# check that the variance of the first 50 iterations is much lower than the last 100
for d in range(D):
var_start = np.var(trace.get_values('n', chains=c)[:50,d])
var_end = np.var(trace.get_values('n', chains=c)[-100:,d])
assert var_start < 0.1 * var_end
pass

def test_tune_drop_fraction(self):
tune = 300
tune_drop_fraction = 0.85
draws = 200
with Model() as pmodel:
Normal('n', 0, 2, shape=(3,))
step = DEMetropolisZ(tune_drop_fraction=tune_drop_fraction)
trace = sample(
tune=tune,
draws=draws,
step=step,
cores=1,
chains=1,
discard_tuned_samples=False
)
assert len(trace) == tune + draws
assert len(step._history) == (tune - tune * tune_drop_fraction) + draws
pass

@pytest.mark.parametrize('variable,has_grad,outcome', [('n', True, 1),('n', False, 1),('b', True, 0),('b', False, 0)])
def test_competence(self, variable, has_grad, outcome):
with Model() as pmodel:
Normal('n', 0, 2, shape=(3,))
Binomial('b', n=2, p=0.3)
assert DEMetropolisZ.competence(pmodel[variable], has_grad=has_grad) == outcome
pass

@pytest.mark.parametrize('tune_setting', ['foo', True, False])
def test_invalid_tune(self, tune_setting):
with Model() as pmodel:
Normal('n', 0, 2, shape=(3,))
with pytest.raises(ValueError):
DEMetropolisZ(tune=tune_setting)
pass

def test_custom_proposal_dist(self):
with Model() as pmodel:
D = 3
Normal('n', 0, 2, shape=(D,))
trace = sample(
tune=100,
draws=50,
step=DEMetropolisZ(proposal_dist=NormalProposal),
cores=1,
chains=3,
discard_tuned_samples=False
)
pass


@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
class TestNutsCheckTrace:
def test_multiple_samplers(self, caplog):
Expand Down