Skip to content

Commit

Permalink
tests: port test_linear_regression.py
Browse files Browse the repository at this point in the history
  • Loading branch information
riddell-stan committed Dec 5, 2018
1 parent d61542a commit 01fd4b1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
3 changes: 0 additions & 3 deletions tests/test_basic_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import stan


program_code = "parameters {real y;} model {y ~ normal(0,1);}"


program_code = """
data {
int<lower=0> N;
Expand Down
41 changes: 41 additions & 0 deletions tests/test_linear_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import pytest

import stan

np.random.seed(1)

program_code = """
data {
int<lower=0> N;
int<lower=0> p;
matrix[N,p] x;
vector[N] y;
}
parameters {
vector[p] beta;
real<lower=0> sigma;
}
model {
y ~ normal(x * beta, sigma);
}
"""

n, p = 10000, 3
X = np.random.normal(size=(n, p))
X = (X - np.mean(X, axis=0)) / np.std(X, ddof=1, axis=0, keepdims=True)
beta_true = (1, 3, 5)
y = np.dot(X, beta_true) + np.random.normal(size=n)

data = {"N": n, "p": p, "x": X, "y": y}


@pytest.fixture(scope="module")
def posterior():
return stan.build(program_code, data=data)


def test_linear_regression(posterior):
fit = posterior.sample(num_chains=4)
assert 0 < fit["sigma"].mean() < 2
assert np.allclose(fit["beta"].mean(axis=1), beta_true, atol=0.05)

0 comments on commit 01fd4b1

Please sign in to comment.