Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flake8 for test dir #179

Merged
merged 4 commits into from Nov 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Expand Up @@ -26,6 +26,7 @@ script:
- mkdir new_dir
- cd new_dir
- flake8 --extend-ignore=E127,E201,E202,E203,E231,E252,E266,E402,E999,F401,F841,W503,W605 --max-line-length=80 ../cmdstanpy
- flake8 --extend-ignore=E127,E201,E202,E203,E231,E252,E266,E402,E999,F841,W503,W605 --max-line-length=80 ../test
- pytest -v ../test
- python -m pip install -r ../requirements-optional.txt
- python ../test/example_script.py
Expand Down
4 changes: 2 additions & 2 deletions test/example_script.py
Expand Up @@ -3,9 +3,9 @@
import sys

# explicit import to test if it is installed
import tqdm
import tqdm # noqa

from cmdstanpy import CmdStanModel, cmdstan_path, set_make_env
from cmdstanpy import CmdStanModel, cmdstan_path


def run_bernoulli_fit():
Expand Down
22 changes: 13 additions & 9 deletions test/test_cmdstan_args.py
@@ -1,8 +1,6 @@
import os
import unittest

from cmdstanpy import TMPDIR

from cmdstanpy.cmdstan_args import (
Method,
SamplerArgs,
Expand Down Expand Up @@ -30,7 +28,6 @@ def test_args_algorithm_init_alpha(self):
args.validate()
cmd = args.compose(None, cmd=['output'])


self.assertIn('init_alpha=0.0002', ' '.join(cmd))
args = OptimizeArgs(init_alpha=-1.0)
self.assertRaises(ValueError, lambda: args.validate())
Expand Down Expand Up @@ -155,12 +152,14 @@ def test_adapt(self):
args = SamplerArgs(adapt_engaged=False)
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('method=sample algorithm=hmc adapt engaged=0', ' '.join(cmd))
self.assertIn('method=sample algorithm=hmc adapt engaged=0',
' '.join(cmd))

args = SamplerArgs(adapt_engaged=True)
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('method=sample algorithm=hmc adapt engaged=1', ' '.join(cmd))
self.assertIn('method=sample algorithm=hmc adapt engaged=1',
' '.join(cmd))

args = SamplerArgs()
args.validate(chains=4)
Expand All @@ -172,22 +171,26 @@ def test_metric(self):
args = SamplerArgs(metric='dense_e')
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('method=sample algorithm=hmc metric=dense_e', ' '.join(cmd))
self.assertIn('method=sample algorithm=hmc metric=dense_e',
' '.join(cmd))

args = SamplerArgs(metric='dense')
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('method=sample algorithm=hmc metric=dense_e', ' '.join(cmd))
self.assertIn('method=sample algorithm=hmc metric=dense_e',
' '.join(cmd))

args = SamplerArgs(metric='diag_e')
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('method=sample algorithm=hmc metric=diag_e', ' '.join(cmd))
self.assertIn('method=sample algorithm=hmc metric=diag_e',
' '.join(cmd))

args = SamplerArgs(metric='diag')
args.validate(chains=4)
cmd = args.compose(1, cmd=[])
self.assertIn('method=sample algorithm=hmc metric=diag_e', ' '.join(cmd))
self.assertIn('method=sample algorithm=hmc metric=diag_e',
' '.join(cmd))

args = SamplerArgs()
args.validate(chains=4)
Expand Down Expand Up @@ -511,6 +514,7 @@ def test_args_fitted_params(self):
self.assertIn('method=generate_quantities', ' '.join(cmd))
self.assertIn('fitted_params={}'.format(csv_files[0]), ' '.join(cmd))


class VariationalTest(unittest.TestCase):
def test_args_variational(self):
args = VariationalArgs()
Expand Down
8 changes: 2 additions & 6 deletions test/test_cxx_installation.py
@@ -1,11 +1,6 @@
import argparse
import os
import unittest
from unittest import mock
import platform
import shutil

from cmdstanpy import TMPDIR
from cmdstanpy import install_cxx_toolchain


Expand Down Expand Up @@ -37,7 +32,8 @@ def test_install_not_windows(self):

with self.assertRaisesRegex(
NotImplementedError,
r'Download for the C\+\+ toolchain on the current platform has not been implemented:\s*\S+',
r'Download for the C\+\+ toolchain on the current platform has not '
'been implemented:\s*\S+',
):
install_cxx_toolchain.main()

Expand Down
20 changes: 5 additions & 15 deletions test/test_generate_quantities.py
@@ -1,16 +1,8 @@
import os
import unittest

from cmdstanpy.cmdstan_args import Method, SamplerArgs, CmdStanArgs
from cmdstanpy.utils import EXTENSION
from cmdstanpy.cmdstan_args import Method
from cmdstanpy.model import CmdStanModel
from cmdstanpy.stanfit import RunSet
from contextlib import contextmanager
import logging
from multiprocessing import cpu_count
import numpy as np
import sys
from testfixtures import LogCapture

here = os.path.dirname(os.path.abspath(__file__))
datafiles_path = os.path.join(here, 'data')
Expand Down Expand Up @@ -62,15 +54,14 @@ def test_gen_quantities_csv_files(self):
bern_gqs.mcmc_sample.shape[1] +
bern_gqs.generated_quantities_pd.shape[1])



def test_gen_quantities_csv_files_bad(self):
stan = os.path.join(datafiles_path, 'bernoulli_ppc.stan')
model = CmdStanModel(stan_file=stan)
jdata = os.path.join(datafiles_path, 'bernoulli.data.json')

# synthesize list of filenames
goodfiles_path = os.path.join(datafiles_path, 'runset-bad', 'bad-draws-bern')
goodfiles_path = os.path.join(datafiles_path, 'runset-bad',
'bad-draws-bern')
csv_files = []
for i in range(4):
csv_files.append('{}-{}.csv'.format(goodfiles_path, i+1))
Expand All @@ -81,7 +72,6 @@ def test_gen_quantities_csv_files_bad(self):
mcmc_sample=csv_files
)


def test_gen_quanties_mcmc_sample(self):
stan = os.path.join(datafiles_path, 'bernoulli.stan')
bern_model = CmdStanModel(stan_file=stan)
Expand Down Expand Up @@ -123,7 +113,8 @@ def test_gen_quanties_mcmc_sample(self):
'y_rep.10',
]
self.assertEqual(bern_gqs.column_names, tuple(column_names))
self.assertEqual(bern_fit.get_drawset().shape, bern_gqs.mcmc_sample.shape)
self.assertEqual(bern_fit.get_drawset().shape,
bern_gqs.mcmc_sample.shape)
self.assertEqual(bern_gqs.sample_plus_quantities.shape[1],
bern_gqs.mcmc_sample.shape[1] +
bern_gqs.generated_quantities_pd.shape[1])
Expand All @@ -144,6 +135,5 @@ def test_sample_plus_quantities_dedup(self):
bern_gqs.mcmc_sample.shape[1])



if __name__ == '__main__':
unittest.main()
8 changes: 1 addition & 7 deletions test/test_model.py
Expand Up @@ -2,15 +2,11 @@
import pytest
import unittest

from cmdstanpy.cmdstan_args import Method, SamplerArgs, CmdStanArgs
from cmdstanpy.utils import EXTENSION
from cmdstanpy.model import CmdStanModel
from contextlib import contextmanager

from cmdstanpy.utils import cmdstan_path

import sys

here = os.path.dirname(os.path.abspath(__file__))
datafiles_path = os.path.join(here, 'data')

Expand Down Expand Up @@ -47,7 +43,7 @@ def show_cmdstan_version(self):
def test_model_good(self):
stan = os.path.join(datafiles_path, 'bernoulli.stan')
exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)

# compile on instantiation
model = CmdStanModel(stan_file=stan)
self.assertEqual(stan, model.stan_file)
Expand Down Expand Up @@ -94,7 +90,6 @@ def test_model_bad(self):
with self.assertRaises(Exception):
model = CmdStanModel(stan_file=stan)


def test_repr(self):
stan = os.path.join(datafiles_path, 'bernoulli.stan')
model = CmdStanModel(stan_file=stan)
Expand Down Expand Up @@ -153,6 +148,5 @@ def test_model_compile_includes(self):
self.assertTrue(model3.exe_file.endswith(exe.replace('\\', '/')))



if __name__ == '__main__':
unittest.main()
17 changes: 5 additions & 12 deletions test/test_optimize.py
Expand Up @@ -3,16 +3,10 @@
import unittest
import json

from cmdstanpy.cmdstan_args import Method, OptimizeArgs, CmdStanArgs
from cmdstanpy.cmdstan_args import OptimizeArgs, CmdStanArgs
from cmdstanpy.utils import EXTENSION
from cmdstanpy.model import CmdStanModel
from cmdstanpy.stanfit import RunSet, CmdStanMLE
from contextlib import contextmanager
import logging
from multiprocessing import cpu_count
import numpy as np
import sys
from testfixtures import LogCapture

here = os.path.dirname(os.path.abspath(__file__))
datafiles_path = os.path.join(here, 'data')
Expand Down Expand Up @@ -54,7 +48,6 @@ def test_set_mle_attrs(self):
self.assertEqual(mle.column_names,('lp__', 'x', 'y'))
self.assertAlmostEqual(mle.optimized_params_dict['x'], 1, places=3)
self.assertAlmostEqual(mle.optimized_params_dict['y'], 1, places=3)



class OptimizeTest(unittest.TestCase):
Expand Down Expand Up @@ -112,7 +105,6 @@ def test_optimize_good_dict(self):
self.assertAlmostEqual(mle.optimized_params_np[0], -5, places=2)
self.assertAlmostEqual(mle.optimized_params_np[1], 0.2, places=3)


def test_optimize_rosenbrock(self):
stan = os.path.join(datafiles_path, 'optimize', 'rosenbrock.stan')
rose_model = CmdStanModel(stan_file=stan)
Expand All @@ -127,12 +119,13 @@ def test_optimize_rosenbrock(self):
self.assertAlmostEqual(mle.optimized_params_dict['x'], 1, places=3)
self.assertAlmostEqual(mle.optimized_params_dict['y'], 1, places=3)


def test_optimize_bad(self):
stan = os.path.join(datafiles_path, 'optimize', 'exponential_boundary.stan')
stan = os.path.join(datafiles_path, 'optimize',
'exponential_boundary.stan')
exp_bound_model = CmdStanModel(stan_file=stan)
no_data = {}
with self.assertRaisesRegex(Exception, 'Error during optimizing, error code 70'):
with self.assertRaisesRegex(Exception,
'Error during optimizing, error code 70'):
exp_bound_model.optimize(
data=no_data,
seed=1239812093,
Expand Down
10 changes: 5 additions & 5 deletions test/test_sample.py
Expand Up @@ -16,6 +16,7 @@
goodfiles_path = os.path.join(datafiles_path, 'runset-good')
badfiles_path = os.path.join(datafiles_path, 'runset-bad')


class SampleTest(unittest.TestCase):

@pytest.fixture(scope="class", autouse=True)
Expand Down Expand Up @@ -129,9 +130,6 @@ def test_init_types(self):
inits=-1
)




def test_bernoulli_bad(self):
stan = os.path.join(datafiles_path, 'bernoulli.stan')
bern_model = CmdStanModel(stan_file=stan)
Expand Down Expand Up @@ -189,7 +187,7 @@ def test_fixed_param_good(self):
self.assertTrue(os.path.exists(txt_file))

self.assertEqual(datagen_fit.runset.chains, 1)

column_names = [
'lp__',
'accept_stat__',
Expand Down Expand Up @@ -288,7 +286,9 @@ def test_bernoulli_file_with_space(self):
self.test_bernoulli_good('bernoulli with space in name.stan')

def test_bernoulli_path_with_space(self):
self.test_bernoulli_good('path with space/bernoulli_path_with_space.stan')
self.test_bernoulli_good('path with space/'
'bernoulli_path_with_space.stan')


class CmdStanMCMCTest(unittest.TestCase):
def test_validate_good_run(self):
Expand Down
7 changes: 2 additions & 5 deletions test/test_utils.py
Expand Up @@ -2,7 +2,6 @@
import os
import unittest
import platform
import tempfile
import shutil
import string
import random
Expand All @@ -11,7 +10,6 @@

from cmdstanpy import TMPDIR
from cmdstanpy.utils import (
EXTENSION,
cmdstan_path,
set_cmdstan_path,
validate_cmdstan_path,
Expand Down Expand Up @@ -150,7 +148,6 @@ def test_jsondump(self):
self.assertEqual(json.load(fp), dict_zero_matrix)



class ReadStanCsvTest(unittest.TestCase):
def test_check_sampler_csv_1(self):
csv_good = os.path.join(datafiles_path, 'bernoulli_output_1.csv')
Expand Down Expand Up @@ -380,7 +377,7 @@ def test_roundtrip_metric(self):
os.remove(dfile_tmp)

def test_parse_rdump_value(self):
s1 = 'structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),.Dim=c(2,8))'
s1 = 'structure(c(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16),.Dim=c(2,8))'
v_s1 = parse_rdump_value(s1)
self.assertEqual(v_s1.shape, (2, 8))
self.assertEqual(v_s1[1, 0], 2)
Expand All @@ -390,7 +387,7 @@ def test_parse_rdump_value(self):
v_s2 = parse_rdump_value(s2)
self.assertEqual(v_s2.shape, (1, 16))

s3 = 'structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),.Dim = c(8, 2))'
s3 = 'structure(c(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16),.Dim=c(8,2))'
v_s3 = parse_rdump_value(s3)
self.assertEqual(v_s3.shape, (8, 2))
self.assertEqual(v_s3[1, 0], 2)
Expand Down