Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/stan-dev/cmdstanpy
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Aug 4, 2020
2 parents 1bb4010 + e0cb9a2 commit ac80e60
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 35 deletions.
32 changes: 32 additions & 0 deletions cmdstanpy/install_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Optional command line arguments:
-v, --version : version, defaults to latest
-d, --dir : install directory, defaults to '~/.cmdstanpy
-c --compiler : add C++ compiler to path (Windows only)
"""
import argparse
import contextlib
Expand Down Expand Up @@ -186,6 +187,13 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--version', '-v')
parser.add_argument('--dir', '-d')
if platform.system() == 'Windows':
# use compiler installed with install_cxx_toolchain
# Install a new compiler if compiler not found
# Search order is RTools40, RTools35
parser.add_argument(
'--compiler', '-c', dest='compiler', action='store_true'
)
args = parser.parse_args(sys.argv[1:])

version = vars(args)['version']
Expand All @@ -199,6 +207,30 @@ def main():
validate_dir(install_dir)
print('Install directory: {}'.format(install_dir))

if platform.system() == 'Windows' and vars(args)['compiler']:
from .install_cxx_toolchain import (
main as _main_cxx,
is_installed as _is_installed_cxx,
)
from .utils import cxx_toolchain_path

cxx_loc = os.path.expanduser(os.path.join('~', '.cmdstanpy'))
compiler_found = False
for cxx_version in ['40', '35']:
if _is_installed_cxx(cxx_loc, cxx_version):
compiler_found = True
break
if not compiler_found:
print('Installing RTools40')
# copy argv and clear sys.argv
original_argv = sys.argv[:]
sys.argv = sys.argv[:1]
_main_cxx()
sys.argv = original_argv
cxx_version = '40'
# Add toolchain to $PATH
cxx_toolchain_path(cxx_version)

cmdstan_version = 'cmdstan-{}'.format(version)
with pushd(install_dir):
if not os.path.exists(cmdstan_version):
Expand Down
10 changes: 7 additions & 3 deletions cmdstanpy/install_cxx_toolchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ def install_mingw32_make(toolchain_loc):
list(
OrderedDict.fromkeys(
[
os.path.join(
toolchain_loc,
'mingw_64' if IS_64BITS else 'mingw_32',
'bin',
),
os.path.join(toolchain_loc, 'usr', 'bin'),
os.path.join(toolchain_loc, 'mingw64', 'bin'),
]
+ os.environ.get('PATH', '').split(';')
)
Expand Down Expand Up @@ -152,7 +156,7 @@ def install_mingw32_make(toolchain_loc):
def is_installed(toolchain_loc, version):
"""Returns True is toolchain is installed."""
if platform.system() == 'Windows':
if version == '3.5':
if version in ['35', '3.5']:
if not os.path.exists(os.path.join(toolchain_loc, 'bin')):
return False
return os.path.exists(
Expand All @@ -163,7 +167,7 @@ def is_installed(toolchain_loc, version):
'g++' + EXTENSION,
)
)
elif version == '4.0':
elif version in ['40', '4.0', '4']:
return os.path.exists(
os.path.join(
toolchain_loc,
Expand Down
14 changes: 13 additions & 1 deletion cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class CmdStanModel:
model given data.
- By default, compiles model on instantiation - override with argument
``compile=False``
- By default, property ``name`` corresponds to basename of the Stan program
or exe file - override with argument ``model_name=<name>``.
"""

def __init__(
Expand All @@ -72,6 +74,12 @@ def __init__(
self._logger = logger or get_logger()

if model_name is not None:
if not model_name.strip():
raise ValueError(
'Invalid value for argument model name, found "{}"'.format(
model_name
)
)
self._name = model_name.strip()

if stan_file is None:
Expand Down Expand Up @@ -158,7 +166,11 @@ def __repr__(self) -> str:

@property
def name(self) -> str:
"""Stan program name; corresponds to bare filename, no extension."""
"""
Model name used in output filename templates. Default is basename
of Stan program or exe file, unless specified in call to constructor
via argument `model_name`.
"""
return self._name

@property
Expand Down
69 changes: 41 additions & 28 deletions cmdstanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,112 +214,125 @@ def cxx_toolchain_path(version: str = None) -> Tuple[str]:
toolchain_root = ''
if 'CMDSTAN_TOOLCHAIN' in os.environ:
toolchain_root = os.environ['CMDSTAN_TOOLCHAIN']
if os.path.exists(os.path.join(toolchain_root, 'mingw_64')):
if os.path.exists(os.path.join(toolchain_root, 'mingw64')):
compiler_path = os.path.join(
toolchain_root,
'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32',
'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32',
'bin',
)
if os.path.exists(compiler_path):
tool_path = os.path.join(toolchain_root, 'bin')
tool_path = os.path.join(toolchain_root, 'usr', 'bin')
if not os.path.exists(tool_path):
tool_path = ''
compiler_path = ''
logger.warning(
'Found invalid installion for RTools35 on %s',
'Found invalid installion for RTools40 on %s',
toolchain_root,
)
toolchain_root = ''
else:
compiler_path = ''
logger.warning(
'Found invalid installion for RTools35 on %s',
'Found invalid installion for RTools40 on %s',
toolchain_root,
)
toolchain_root = ''
elif os.path.exists(os.path.join(toolchain_root, 'mingw64')):

elif os.path.exists(os.path.join(toolchain_root, 'mingw_64')):
compiler_path = os.path.join(
toolchain_root,
'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32',
'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32',
'bin',
)
if os.path.exists(compiler_path):
tool_path = os.path.join(toolchain_root, 'usr', 'bin')
tool_path = os.path.join(toolchain_root, 'bin')
if not os.path.exists(tool_path):
tool_path = ''
compiler_path = ''
logger.warning(
'Found invalid installion for RTools40 on %s',
'Found invalid installion for RTools35 on %s',
toolchain_root,
)
toolchain_root = ''
else:
compiler_path = ''
logger.warning(
'Found invalid installion for RTools40 on %s',
'Found invalid installion for RTools35 on %s',
toolchain_root,
)
toolchain_root = ''
else:
rtools_dir = os.path.expanduser(
os.path.join('~', '.cmdstanpy', 'RTools')
os.path.join('~', '.cmdstanpy', 'RTools40')
)
if not os.path.exists(rtools_dir):
raise ValueError(
'no RTools installation found, '
'run command line script "install_cxx_toolchain"'
rtools_dir = os.path.expanduser(
os.path.join('~', '.cmdstanpy', 'RTools35')
)
if not os.path.exists(rtools_dir):
rtools_dir = os.path.expanduser(
os.path.join('~', '.cmdstanpy', 'RTools')
)
if not os.path.exists(rtools_dir):
raise ValueError(
'no RTools installation found, '
'run command line script "install_cxx_toolchain"'
)
else:
rtools_dir = os.path.expanduser(os.path.join('~', '.cmdstanpy'))
else:
rtools_dir = os.path.expanduser(os.path.join('~', '.cmdstanpy'))
compiler_path = ''
tool_path = ''
if version not in ('4', '40', '4.0') and os.path.exists(
os.path.join(rtools_dir, 'RTools35')
if version not in ('35', '3.5', '3') and os.path.exists(
os.path.join(rtools_dir, 'RTools40')
):
toolchain_root = os.path.join(rtools_dir, 'RTools35')
toolchain_root = os.path.join(rtools_dir, 'RTools40')
compiler_path = os.path.join(
toolchain_root,
'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32',
'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32',
'bin',
)
if os.path.exists(compiler_path):
tool_path = os.path.join(toolchain_root, 'bin')
tool_path = os.path.join(toolchain_root, 'usr', 'bin')
if not os.path.exists(tool_path):
tool_path = ''
compiler_path = ''
logger.warning(
'Found invalid installion for RTools35 on %s',
'Found invalid installation for RTools40 on %s',
toolchain_root,
)
toolchain_root = ''
else:
compiler_path = ''
logger.warning(
'Found invalid installion for RTools35 on %s',
'Found invalid installation for RTools40 on %s',
toolchain_root,
)
toolchain_root = ''
if (
not toolchain_root or version in ('4', '40', '4.0')
) and os.path.exists(os.path.join(rtools_dir, 'RTools40')):
toolchain_root = os.path.join(rtools_dir, 'RTools40')
) and os.path.exists(os.path.join(rtools_dir, 'RTools35')):
toolchain_root = os.path.join(rtools_dir, 'RTools35')
compiler_path = os.path.join(
toolchain_root,
'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32',
'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32',
'bin',
)
if os.path.exists(compiler_path):
tool_path = os.path.join(toolchain_root, 'usr', 'bin')
tool_path = os.path.join(toolchain_root, 'bin')
if not os.path.exists(tool_path):
tool_path = ''
compiler_path = ''
logger.warning(
'Found invalid installion for RTools40 on %s',
'Found invalid installation for RTools35 on %s',
toolchain_root,
)
toolchain_root = ''
else:
compiler_path = ''
logger.warning(
'Found invalid installion for RTools40 on %s',
'Found invalid installation for RTools35 on %s',
toolchain_root,
)
toolchain_root = ''
Expand All @@ -328,7 +341,7 @@ def cxx_toolchain_path(version: str = None) -> Tuple[str]:
'no C++ toolchain installation found, '
'run command line script "install_cxx_toolchain"'
)
logger.info('Adds C++ toolchain to $PATH: %s', toolchain_root)
logger.info('Add C++ toolchain to $PATH: %s', toolchain_root)
os.environ['PATH'] = ';'.join(
list(
OrderedDict.fromkeys(
Expand Down
13 changes: 10 additions & 3 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

BERN_STAN = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
BERN_EXE = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
BERN_BASENAME = 'bernoulli'


class CmdStanModelTest(unittest.TestCase):
Expand All @@ -50,22 +51,24 @@ def show_cmdstan_version(self):
self.assertTrue(True)

def test_model_good(self):
# compile on instantiation
# compile on instantiation, override model name
model = CmdStanModel(model_name='bern', stan_file=BERN_STAN)
self.assertEqual(BERN_STAN, model.stan_file)
self.assertTrue(model.exe_file.endswith(BERN_EXE.replace('\\', '/')))
self.assertEqual('bern', model.name)

# default model name
model = CmdStanModel(stan_file=BERN_STAN)
self.assertEqual(BERN_BASENAME, model.name)

# instantiate with existing exe
model = CmdStanModel(stan_file=BERN_STAN, exe_file=BERN_EXE)
self.assertEqual(BERN_STAN, model.stan_file)
self.assertTrue(model.exe_file.endswith(BERN_EXE))
self.assertEqual('bernoulli', model.name)

# instantiate with existing exe only - no model
model2 = CmdStanModel(exe_file=BERN_EXE)
self.assertEqual(BERN_EXE, model2.exe_file)
self.assertEqual('bernoulli', model2.name)
with self.assertRaises(RuntimeError):
model2.code()
with self.assertRaises(RuntimeError):
Expand All @@ -82,6 +85,10 @@ def test_model_bad(self):
CmdStanModel(stan_file=None, exe_file=None)
with self.assertRaises(ValueError):
CmdStanModel(model_name='bad')
with self.assertRaises(ValueError):
CmdStanModel(model_name='', stan_file=BERN_STAN)
with self.assertRaises(ValueError):
CmdStanModel(model_name=' ', stan_file=BERN_STAN)

def test_stanc_options(self):
opts = {
Expand Down

0 comments on commit ac80e60

Please sign in to comment.