Skip to content

Commit

Permalink
Test examples in travis CI (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad committed May 13, 2019
1 parent 4781c5d commit ec450c6
Show file tree
Hide file tree
Showing 16 changed files with 121 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ __pycache__/

# data files
numpyro/examples/.data
numpyro/examples/.results
examples/.results
numpyro/.DS_Store

# test related
Expand Down
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ install:
- pip install -U pip
- pip install jaxlib
- pip install jax
- pip install .[test]
- pip install .[examples,test]
- pip freeze

branches:
Expand Down
Empty file added examples/__init__.py
Empty file.
7 changes: 3 additions & 4 deletions numpyro/examples/baseball.py → examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,9 @@ def partially_pooled_with_logit(at_bats, hits=None):


def run_inference(model, at_bats, hits, rng, args):
init_params, potential_fn, transform_fn = initialize_model(rng, model,
(at_bats, hits), {})
init_params, potential_fn, transform_fn = initialize_model(rng, model, at_bats, hits)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup_steps)
hmc_state = init_kernel(init_params, args.num_warmup)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
transform=lambda hmc_state: transform_fn(hmc_state.z))
return hmc_states
Expand Down Expand Up @@ -202,7 +201,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Baseball batting average using HMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup-steps", nargs='?', default=1500, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1500, type=int)
parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()
main(args)
14 changes: 6 additions & 8 deletions numpyro/examples/covtype.py → examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import time

import numpy as onp
from sklearn.datasets import fetch_covtype

import jax.numpy as np
from jax import random
from jax.config import config as jax_config

import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
Expand All @@ -32,11 +32,9 @@
-1.59496680e-01, -1.88516974e-01, -1.20889175e+00])}


# TODO: add to datasets.py so as to avoid dependency on scikit-learn
def load_dataset():
data = fetch_covtype()
features = data.data
labels = data.target
def _load_dataset():
_, fetch = load_dataset(COVTYPE, shuffle=False)
features, labels = fetch()

# normalize features and add intercept
features = (features - features.mean(0)) / features.std(0)
Expand All @@ -63,7 +61,7 @@ def model(data, labels):

def benchmark_hmc(args, features, labels):
trajectory_length = step_size * args.num_steps
_, potential_fn, _ = initialize_model(random.PRNGKey(1), model, (features, labels,), {})
_, potential_fn, _ = initialize_model(random.PRNGKey(1), model, features, labels)
init_kernel, sample_kernel = hmc(potential_fn, algo=args.algo)
t0 = time.time()
# TODO: Use init_params from `initialize_model` instead of fixed params.
Expand All @@ -84,7 +82,7 @@ def transform(state): return {'coefs': state.z['coefs'],

def main(args):
jax_config.update("jax_platform_name", args.device)
features, labels = load_dataset()
features, labels = _load_dataset()
benchmark_hmc(args, features, labels)


Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def main(args):
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
init_rng, sample_rng = random.split(random.PRNGKey(args.rng))
init_params, potential_fn, transform_fn = initialize_model(init_rng, model, (returns,), {})
init_params, potential_fn, transform_fn = initialize_model(init_rng, model, returns)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup_steps, rng=sample_rng)
hmc_state = init_kernel(init_params, args.num_warmup, rng=sample_rng)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
transform=lambda hmc_state: transform_fn(hmc_state.z))
print_results(hmc_states, dates)
Expand All @@ -79,7 +79,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument('-n', '--num-samples', nargs='?', default=3000, type=int)
parser.add_argument('--num-warmup-steps', nargs='?', default=1500, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1500, type=int)
parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
parser.add_argument('--rng', default=21, type=int, help='random number generator seed')
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion numpyro/examples/ucbadmit.py → examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def glmm(dept, male, applications, admit):

def run_inference(dept, male, applications, admit, rng, args):
init_params, potential_fn, transform_fn = initialize_model(
rng, glmm, (dept, male, applications, admit), {})
rng, glmm, dept, male, applications, admit)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup_steps)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
promote_shapes,
signed_stick_breaking_tril,
standard_gamma,
vec_to_tril_matrix,
vec_to_tril_matrix
)


Expand Down
86 changes: 55 additions & 31 deletions numpyro/examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,44 @@
from jax import device_put, lax
from jax.interpreters.xla import DeviceArray

DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__),
'.data'))
if 'CI' in os.environ:
DATA_DIR = os.path.expanduser('~/.data')
else:
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__),
'.data'))
os.makedirs(DATA_DIR, exist_ok=True)


dset = namedtuple('dset', ['name', 'urls'])


BASEBALL = dset('baseball', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/EfronMorrisBB.txt',
])


COVTYPE = dset('covtype', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/covtype.data.gz',
])


MNIST = dset('mnist', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz',
'https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz',
'https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/t10k-images-idx3-ubyte.gz',
'https://d2fefpcigoriu7.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz',
])

BASEBALL = dset('baseball', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/EfronMorrisBB.txt',

SP500 = dset('SP500', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/SP500.csv',
])


UCBADMIT = dset('ucbadmit', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/UCBadmit.csv',
])

SP500 = dset('SP500', [
'https://d2fefpcigoriu7.cloudfront.net/datasets/SP500.csv',
])


def _download(dset):
for url in dset.urls:
Expand All @@ -49,6 +60,37 @@ def _download(dset):
print('Download complete.')


def _load_baseball():
    """Load the baseball dataset stored in ``EfronMorrisBB.txt``.

    Calls ``_download(BASEBALL)`` first, then parses the tab-separated file
    under ``DATA_DIR``. Returns a dict with ``'train'`` and ``'test'``
    entries; each is a ``(counts, player_names)`` pair, where ``counts`` is
    a stacked array of ``[at_bats, hits]`` rows (the ``At-Bats``/``Hits``
    columns for train, the ``SeasonAt-Bats``/``SeasonHits`` columns for test).
    """
    _download(BASEBALL)

    def train_test_split(file):
        # Accumulate one entry per player row, in file order.
        names, early_rows, season_rows = [], [], []
        with open(file, 'r') as handle:
            reader = csv.DictReader(handle, delimiter='\t', quoting=csv.QUOTE_NONE)
            for record in reader:
                full_name = record['FirstName'] + ' ' + record['LastName']
                names.append(full_name)
                at_bats, hits = record['At-Bats'], record['Hits']
                early_rows.append(np.array([int(at_bats), int(hits)]))
                season_at_bats, season_hits = record['SeasonAt-Bats'], record['SeasonHits']
                season_rows.append(np.array([int(season_at_bats), int(season_hits)]))
        return np.stack(early_rows), np.stack(season_rows), np.array(names)

    data_file = os.path.join(DATA_DIR, 'EfronMorrisBB.txt')
    train, test, player_names = train_test_split(data_file)
    splits = {'train': (train, player_names),
              'test': (test, player_names)}
    return splits


def _load_covtype():
    """Load the covtype dataset from ``covtype.data.gz`` under ``DATA_DIR``.

    Calls ``_download(COVTYPE)`` first, then parses the gzipped,
    comma-delimited file. Returns a dict with a single ``'train'`` entry:
    ``(features, labels)``, where ``features`` is every column but the last
    and ``labels`` is the last column cast to ``int32``.
    """
    _download(COVTYPE)

    file_path = os.path.join(DATA_DIR, 'covtype.data.gz')
    # Use a context manager so the gzip handle is closed as soon as parsing
    # finishes, instead of being left open until garbage collection.
    with gzip.GzipFile(file_path) as f:
        data = np.genfromtxt(f, delimiter=',')

    return {
        'train': (data[:, :-1], data[:, -1].astype(np.int32))
    }


def _load_mnist():
_download(MNIST)

Expand All @@ -70,26 +112,6 @@ def read_img(file):
'test': (read_img(files[2]), read_label(files[3]))}


def _load_baseball():
_download(BASEBALL)

def train_test_split(file):
train, test, player_names = [], [], []
with open(file, 'r') as f:
csv_reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
for row in csv_reader:
player_names.append(row['FirstName'] + ' ' + row['LastName'])
at_bats, hits = row['At-Bats'], row['Hits']
train.append(np.array([int(at_bats), int(hits)]))
season_at_bats, season_hits = row['SeasonAt-Bats'], row['SeasonHits']
test.append(np.array([int(season_at_bats), int(season_hits)]))
return np.stack(train), np.stack(test), np.array(player_names)

train, test, player_names = train_test_split(os.path.join(DATA_DIR, 'EfronMorrisBB.txt'))
return {'train': (train, player_names),
'test': (test, player_names)}


def _load_sp500():
_download(SP500)

Expand Down Expand Up @@ -126,10 +148,12 @@ def _load_ucbadmit():


def _load(dset):
if dset == MNIST:
return _load_mnist()
elif dset == BASEBALL:
if dset == BASEBALL:
return _load_baseball()
elif dset == COVTYPE:
return _load_covtype()
elif dset == MNIST:
return _load_mnist()
elif dset == SP500:
return _load_sp500()
elif dset == UCBADMIT:
Expand Down
2 changes: 1 addition & 1 deletion numpyro/hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def transform_fn(transforms, params, invert=False):
for k, v in params.items()}


def initialize_model(rng, model, model_args, model_kwargs):
def initialize_model(rng, model, *model_args, **model_kwargs):
model = seed(model, rng)
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
sample_sites = {k: v for k, v in model_trace.items() if v['type'] == 'sample' and not v['is_observed']}
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'doc': ['sphinx', 'sphinx_rtd_theme'],
'test': ['flake8', 'pytest>=4.1'],
'dev': ['ipython'],
'examples': ['matplotlib'],
},
tests_require=['flake8', 'pytest>=4.1'],
keywords='probabilistic machine learning bayesian statistics',
Expand Down
25 changes: 16 additions & 9 deletions test/test_example_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
import jax.numpy as np
from jax import lax

from numpyro.examples.datasets import BASEBALL, MNIST, SP500, load_dataset
from numpyro.examples.datasets import BASEBALL, COVTYPE, MNIST, SP500, load_dataset


def test_baseball_data_load():
    """The first fetched batch of the baseball train split has 18 rows."""
    init_fn, fetch_fn = load_dataset(BASEBALL, split='train', shuffle=False)
    _, idx = init_fn()
    batch = fetch_fn(0, idx)
    expected_rows = 18
    # First element: two columns per row; second element: one entry per row.
    assert np.shape(batch[0]) == (expected_rows, 2)
    assert np.shape(batch[1]) == (expected_rows,)


def test_covtype_data_load():
    """Fetching covtype without batching yields the full 581012-row arrays."""
    _, fetch_fn = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch_fn()
    num_rows, num_features = 581012, 54
    assert np.shape(features) == (num_rows, num_features)
    assert np.shape(labels) == (num_rows,)


def test_mnist_data_load():
Expand All @@ -14,14 +29,6 @@ def mean_pixels(i, mean_pix):
assert lax.fori_loop(0, num_batches, mean_pixels, np.float32(0.)) / num_batches < 0.15


def test_baseball_data_load():
init, fetch = load_dataset(BASEBALL, split='train', shuffle=False)
num_batches, idx = init()
dataset = fetch(0, idx)
assert np.shape(dataset[0]) == (18, 2)
assert np.shape(dataset[1]) == (18,)


def test_sp500_data_load():
_, fetch = load_dataset(SP500, split='train', shuffle=False)
date, value = fetch()
Expand Down
27 changes: 27 additions & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import sys
from subprocess import check_call

import pytest

# Absolute path of the directory containing this test module.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
# The `examples` directory that sits next to the tests directory.
EXAMPLES_DIR = os.path.join(os.path.dirname(TESTS_DIR), 'examples')


# One entry per example run: the script file name (relative to EXAMPLES_DIR)
# followed by its command-line arguments, separated by whitespace.
# `test_cpu` splits each entry on whitespace, so arguments must not contain spaces.
EXAMPLES = [
'baseball.py --num-samples 100 --num-warmup 100',
'covtype.py --algo hmc --num-samples 10',
'minipyro.py',
'stochastic_volatility.py --num-samples 100 --num-warmup 100',
'ucbadmit.py',
'vae.py -n 1',
]


@pytest.mark.parametrize('example', EXAMPLES)
def test_cpu(example):
    """Run one example script in a subprocess with the current interpreter.

    ``example`` is a whitespace-separated command line whose first token is
    the script name inside ``EXAMPLES_DIR``. ``check_call`` raises if the
    script exits with a non-zero status, which fails the test.
    """
    print('Running:\npython examples/{}'.format(example))
    tokens = example.split()
    script_path = os.path.join(EXAMPLES_DIR, tokens[0])
    command = [sys.executable, script_path]
    command += tokens[1:]
    check_call(command)
11 changes: 5 additions & 6 deletions test/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def model(labels):
logits = np.sum(coefs * data, axis=-1)
return sample('obs', dist.Bernoulli(logits=logits), obs=labels)

init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, labels)
init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
hmc_state = init_kernel(init_params,
trajectory_length=10,
Expand All @@ -70,7 +70,7 @@ def model(data):

true_probs = np.array([0.9, 0.1])
data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), size=(1000, 2))
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, (data,), {})
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, data)
init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
hmc_state = init_kernel(init_params,
trajectory_length=1.,
Expand All @@ -94,7 +94,7 @@ def model(data):

true_probs = np.array([0.1, 0.6, 0.3])
data = dist.Categorical(true_probs).sample(random.PRNGKey(1), size=(2000,))
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, (data,), {})
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, data)
init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
hmc_state = init_kernel(init_params,
trajectory_length=1.,
Expand Down Expand Up @@ -126,8 +126,7 @@ def model(data):
12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
5, 14, 13, 22,
])
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model,
(count_data,), {})
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, count_data)
init_kernel, sample_kernel = hmc(potential_fn)
hmc_state = init_kernel(init_params, num_warmup_steps=warmup_steps)
hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
Expand All @@ -154,7 +153,7 @@ def model(data):
sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

data = {'n': 5000000, 'x': 3849}
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, (data,), {})
init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, data)
init_kernel, sample_kernel = hmc(potential_fn)
hmc_state = init_kernel(init_params, num_warmup_steps=warmup_steps)
hmc_states = fori_collect(num_samples, sample_kernel, hmc_state,
Expand Down

0 comments on commit ec450c6

Please sign in to comment.