Skip to content

Commit

Permalink
Merge b380975 into ab437e4
Browse files Browse the repository at this point in the history
  • Loading branch information
springcoil committed Dec 21, 2016
2 parents ab437e4 + b380975 commit 0c2e665
Showing 1 changed file with 78 additions and 42 deletions.
120 changes: 78 additions & 42 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
import unittest

from .checks import close_to
from .models import simple_categorical, mv_simple, mv_simple_discrete, simple_2model
from .models import mv_simple, mv_simple_discrete, simple_2model, simple_categorical
from .helpers import SeededTest
from pymc3.sampling import assign_step_methods, sample
from pymc3.model import Model
from pymc3.step_methods import (NUTS, BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
Metropolis, Slice, CompoundStep,
MultivariateNormalProposal, HamiltonianMC)
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical, Uniform

from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_almost_equal, assert_array_less
from scipy import stats
import numpy as np
from tqdm import tqdm

Expand Down Expand Up @@ -39,26 +41,26 @@ class TestStepMethods(object): # yield test doesn't work subclassing unittest.T
7.04959179e-01, 8.37863464e-01, -5.24200836e-01, 1.28261340e+00, 9.08774240e-01,
8.80566763e-01, 7.82911967e-01, 8.01843432e-01, 7.09251098e-01, 5.73803618e-01]),
HamiltonianMC: np.array([
-0.74925631, -0.2566773 , -2.12480977, 1.64328926, -1.39315913,
2.04200003, 0.00706711, 0.34240498, 0.44276674, -0.21368043,
-0.76398723, 1.19280082, -1.43030242, -0.44896107, 0.0547087 ,
-1.72170938, -0.20443956, 0.35432546, 1.77695096, -0.31053636,
-0.26729283, 1.26450201, 0.17049917, 0.27953939, -0.24185153,
0.95617117, -0.45707061, 0.75837366, -1.73391277, 1.63331612,
-0.68426038, 0.20499991, -0.43866983, 0.31080195, 0.47104548,
-0.50331753, 0.7821196 , -1.7544931 , 1.24106497, -1.0152971 ,
-0.01949091, -0.33151479, 0.19138253, 0.40349184, 0.31694823,
-0.01508142, -0.31330951, 0.40874228, 0.40874228, 0.58078882,
0.68378375, 0.84142914, 0.44756075, -0.87297183, 0.59695222,
1.96161733, -0.37126652, 0.27552912, 0.74547583, -0.16172925,
0.79969568, -0.20501522, -0.36181518, 0.13114261, -0.8461323 ,
-0.07749079, -0.07013026, 0.88022116, -0.5546825 , 0.25232708,
0.09483573, 0.84910913, 1.33348018, -1.1971401 , 0.49203123,
0.22365435, 1.3801812 , 0.06885929, 1.07115053, -1.52225141,
-0.74925631, -0.2566773, -2.12480977, 1.64328926, -1.39315913,
2.04200003, 0.00706711, 0.34240498, 0.44276674, -0.21368043,
-0.76398723, 1.19280082, -1.43030242, -0.44896107, 0.0547087,
-1.72170938, -0.20443956, 0.35432546, 1.77695096, -0.31053636,
-0.26729283, 1.26450201, 0.17049917, 0.27953939, -0.24185153,
0.95617117, -0.45707061, 0.75837366, -1.73391277, 1.63331612,
-0.68426038, 0.20499991, -0.43866983, 0.31080195, 0.47104548,
-0.50331753, 0.7821196, -1.7544931, 1.24106497, -1.0152971,
-0.01949091, -0.33151479, 0.19138253, 0.40349184, 0.31694823,
-0.01508142, -0.31330951, 0.40874228, 0.40874228, 0.58078882,
0.68378375, 0.84142914, 0.44756075, -0.87297183, 0.59695222,
1.96161733, -0.37126652, 0.27552912, 0.74547583, -0.16172925,
0.79969568, -0.20501522, -0.36181518, 0.13114261, -0.8461323,
-0.07749079, -0.07013026, 0.88022116, -0.5546825, 0.25232708,
0.09483573, 0.84910913, 1.33348018, -1.1971401, 0.49203123,
0.22365435, 1.3801812, 0.06885929, 1.07115053, -1.52225141,
1.50179721, -2.01528399, -1.31610679, -0.32298834, -0.80630885,
-0.6828592 , 0.2897919 , 1.64608125, -0.71793662, -0.5233058 ,
0.53549836, 0.61119221, 0.24235732, -1.3940593 , 0.28380114,
-0.22629978, -0.19318957, 1.12543101, -1.40328285, 0.21054137]),
-0.6828592, 0.2897919, 1.64608125, -0.71793662, -0.5233058,
0.53549836, 0.61119221, 0.24235732, -1.3940593, 0.28380114,
-0.22629978, -0.19318957, 1.12543101, -1.40328285, 0.21054137]),
Metropolis: np.array([
1.62434536, 1.01258895, 0.4844172, -0.58855142, 1.15626034, 0.39505344, 1.85716138,
-0.20297933, -0.20297933, -0.20297933, -0.20297933, -1.08083775, -1.08083775,
Expand All @@ -76,26 +78,26 @@ class TestStepMethods(object): # yield test doesn't work subclassing unittest.T
2.18960348, 2.18960348, 2.63096792, 2.53081269, 2.5482221, 1.42620337, 0.90910891,
-0.08791792, 0.40729341, 0.23259025, 0.23259025, 0.23259025, 2.76091595, 2.51228118]),
NUTS: np.array([
1.11832371, 1.11832371, 0.6296164 , -1.88725852, -0.28085593,
-0.28085593, 0.51246481, 0.51246481, 0.51524239, -1.07479086,
-1.42956404, -1.42956404, 0.236747 , -1.04721507, -0.9716711 ,
-0.9716711 , -0.65903886, 1.392075 , 0.97569367, 0.16332534,
0.16332534, -0.36465255, 1.4513665 , 0.60044829, 0.60044829,
-0.7582248 , -0.81863678, -0.8432519 , 0.4756505 , -1.8618275 ,
-1.28026218, -1.28026218, -1.39438141, -1.33699454, 1.65436385,
1.18307331, 1.18307331, -1.02586036, -0.49676467, -0.17723852,
-0.17723852, -0.46433406, 0.80562527, 0.0872473 , 0.0872473 ,
-0.5121292 , 0.2457691 , 0.2457691 , -0.28227229, -0.56863496,
-0.13361197, 0.40085491, -0.34774478, -0.34774478, -0.49552973,
-0.50378818, 0.26151237, 0.26151237, -0.3271885 , -0.33567672,
0.9339671 , 0.92457538, 0.92457538, 0.92079262, 0.07603108,
0.53798844, 0.5509594 , -0.10204449, -0.10204449, -0.27679836,
0.29559731, -0.63318597, 0.50441029, 0.50441029, -0.64408839,
0.85784078, 0.83528853, 0.32389337, 0.32389337, 0.32371787,
0.45515893, 0.42012555, 0.42012555, -0.02036946, -0.1275346 ,
0.1818874 , 1.28333928, 0.92705793, 0.92705793, 0.42301906,
1.97444363, 2.12202997, 1.79560373, 1.79560373, 1.658178 ,
1.64034201, 2.01737179, 1.45213152, 1.45213152, 1.4536979 ]),
1.11832371, 1.11832371, 0.6296164, -1.88725852, -0.28085593,
-0.28085593, 0.51246481, 0.51246481, 0.51524239, -1.07479086,
-1.42956404, -1.42956404, 0.236747, -1.04721507, -0.9716711,
-0.9716711, -0.65903886, 1.392075, 0.97569367, 0.16332534,
0.16332534, -0.36465255, 1.4513665, 0.60044829, 0.60044829,
-0.7582248, -0.81863678, -0.8432519, 0.4756505, -1.8618275,
-1.28026218, -1.28026218, -1.39438141, -1.33699454, 1.65436385,
1.18307331, 1.18307331, -1.02586036, -0.49676467, -0.17723852,
-0.17723852, -0.46433406, 0.80562527, 0.0872473, 0.0872473,
-0.5121292, 0.2457691, 0.2457691, -0.28227229, -0.56863496,
-0.13361197, 0.40085491, -0.34774478, -0.34774478, -0.49552973,
-0.50378818, 0.26151237, 0.26151237, -0.3271885, -0.33567672,
0.9339671, 0.92457538, 0.92457538, 0.92079262, 0.07603108,
0.53798844, 0.5509594, -0.10204449, -0.10204449, -0.27679836,
0.29559731, -0.63318597, 0.50441029, 0.50441029, -0.64408839,
0.85784078, 0.83528853, 0.32389337, 0.32389337, 0.32371787,
0.45515893, 0.42012555, 0.42012555, -0.02036946, -0.1275346,
0.1818874, 1.28333928, 0.92705793, 0.92705793, 0.42301906,
1.97444363, 2.12202997, 1.79560373, 1.79560373, 1.658178,
1.64034201, 2.01737179, 1.45213152, 1.45213152, 1.4536979]),
}

def test_sample_exact(self):
Expand Down Expand Up @@ -238,3 +240,37 @@ def test_binomial(self):
Binomial('x', 10, 0.5)
steps = assign_step_methods(model, [])
self.assertIsInstance(steps, Metropolis)


class TestSampleEstimates(SeededTest):
def test_parameter_estimate(self):
alpha_true, sigma_true = 1, 0.5
beta_true = np.array([1, 2.5])

size = 100

X1 = np.random.randn(size)
X2 = np.random.randn(size) * 0.2
Y = alpha_true + beta_true[0] * X1 + beta_true[1] * X2 + np.random.randn(size) * sigma_true

with Model() as model:
alpha = Normal('alpha', mu=0, sd=10)
beta = Normal('beta', mu=0, sd=10, shape=2)
sigma = Uniform('sigma', lower=0.0, upper=1.0)
mu = alpha + beta[0] * X1 + beta[1] * X2
Y_obs = Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

for step_method in (NUTS(), Metropolis(),
[Slice([alpha, sigma]), Metropolis([beta])]):
trace = sample(1000, step=step_method, progressbar=False)

assert np.isclose(np.median(trace.beta, 0), beta_true, rtol=0.1).all()
assert np.isclose(np.median(trace.alpha), alpha_true, rtol=0.1)
assert np.isclose(np.median(trace.sigma), sigma_true, rtol=0.1)
np.random.seed(987654321)
test_normal = stats.kstest(trace.alpha, 'norm', alternative='greater')
test_normal_beta = stats.kstest(trace.beta[0], 'norm', alternative='greater')
test_uniform = stats.kstest(trace.sigma, 'uniform', alternative='greater')
assert np.less(np.median(test_normal[1]), 0.05)
assert np.less(np.median(test_normal_beta[1]) / 2, 0.5)
assert np.less(np.median(test_uniform[1]), 0.05)

0 comments on commit 0c2e665

Please sign in to comment.