diff --git a/cmdstanpy/install_cmdstan.py b/cmdstanpy/install_cmdstan.py index 82a42922..4324b1bb 100644 --- a/cmdstanpy/install_cmdstan.py +++ b/cmdstanpy/install_cmdstan.py @@ -4,6 +4,7 @@ Optional command line arguments: -v, --version : version, defaults to latest -d, --dir : install directory, defaults to '~/.cmdstanpy + -c --compiler : add C++ compiler to path (Windows only) """ import argparse import contextlib @@ -186,6 +187,13 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('--version', '-v') parser.add_argument('--dir', '-d') + if platform.system() == 'Windows': + # use compiler installed with install_cxx_toolchain + # Install a new compiler if compiler not found + # Search order is RTools40, RTools35 + parser.add_argument( + '--compiler', '-c', dest='compiler', action='store_true' + ) args = parser.parse_args(sys.argv[1:]) version = vars(args)['version'] @@ -199,6 +207,30 @@ def main(): validate_dir(install_dir) print('Install directory: {}'.format(install_dir)) + if platform.system() == 'Windows' and vars(args)['compiler']: + from .install_cxx_toolchain import ( + main as _main_cxx, + is_installed as _is_installed_cxx, + ) + from .utils import cxx_toolchain_path + + cxx_loc = os.path.expanduser(os.path.join('~', '.cmdstanpy')) + compiler_found = False + for cxx_version in ['40', '35']: + if _is_installed_cxx(cxx_loc, cxx_version): + compiler_found = True + break + if not compiler_found: + print('Installing RTools40') + # copy argv and clear sys.argv + original_argv = sys.argv[:] + sys.argv = sys.argv[:1] + _main_cxx() + sys.argv = original_argv + cxx_version = '40' + # Add toolchain to $PATH + cxx_toolchain_path(cxx_version) + cmdstan_version = 'cmdstan-{}'.format(version) with pushd(install_dir): if not os.path.exists(cmdstan_version): diff --git a/cmdstanpy/install_cxx_toolchain.py b/cmdstanpy/install_cxx_toolchain.py index 1500ac8a..3fe3ba63 100644 --- a/cmdstanpy/install_cxx_toolchain.py +++ b/cmdstanpy/install_cxx_toolchain.py @@ -110,8 +110,12 @@ def install_mingw32_make(toolchain_loc): list( OrderedDict.fromkeys( [ + os.path.join( + toolchain_loc, + 'mingw_64' if IS_64BITS else 'mingw_32', + 'bin', + ), os.path.join(toolchain_loc, 'usr', 'bin'), - os.path.join(toolchain_loc, 'mingw64', 'bin'), ] + os.environ.get('PATH', '').split(';') ) @@ -152,7 +156,7 @@ def install_mingw32_make(toolchain_loc): def is_installed(toolchain_loc, version): """Returns True is toolchain is installed.""" if platform.system() == 'Windows': - if version == '3.5': + if version in ['35', '3.5']: if not os.path.exists(os.path.join(toolchain_loc, 'bin')): return False return os.path.exists( @@ -163,7 +167,7 @@ def is_installed(toolchain_loc, version): 'g++' + EXTENSION, ) ) - elif version == '4.0': + elif version in ['40', '4.0', '4']: return os.path.exists( os.path.join( toolchain_loc, diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 1ee30b69..39c8b2ed 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -50,6 +50,8 @@ class CmdStanModel: model given data. - By default, compiles model on instantiation - override with argument ``compile=False`` + - By default, property ``name`` corresponds to basename of the Stan program + or exe file - override with argument ``model_name=``. """ def __init__( @@ -72,6 +74,12 @@ def __init__( self._logger = logger or get_logger() if model_name is not None: + if not model_name.strip(): + raise ValueError( + 'Invalid value for argument model name, found "{}"'.format( + model_name + ) + ) self._name = model_name.strip() if stan_file is None: @@ -158,7 +166,11 @@ def __repr__(self) -> str: @property def name(self) -> str: - """Stan program name; corresponds to bare filename, no extension.""" + """ + Model name used in output filename templates. Default is basename + of Stan program or exe file, unless specified in call to constructor + via argument `model_name`. + """ return self._name @property diff --git a/cmdstanpy/utils.py b/cmdstanpy/utils.py index d4a704a4..6e60e6a9 100644 --- a/cmdstanpy/utils.py +++ b/cmdstanpy/utils.py @@ -214,112 +214,125 @@ def cxx_toolchain_path(version: str = None) -> Tuple[str]: toolchain_root = '' if 'CMDSTAN_TOOLCHAIN' in os.environ: toolchain_root = os.environ['CMDSTAN_TOOLCHAIN'] - if os.path.exists(os.path.join(toolchain_root, 'mingw_64')): + if os.path.exists(os.path.join(toolchain_root, 'mingw64')): compiler_path = os.path.join( toolchain_root, - 'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32', + 'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32', 'bin', ) if os.path.exists(compiler_path): - tool_path = os.path.join(toolchain_root, 'bin') + tool_path = os.path.join(toolchain_root, 'usr', 'bin') if not os.path.exists(tool_path): tool_path = '' compiler_path = '' logger.warning( - 'Found invalid installion for RTools35 on %s', + 'Found invalid installion for RTools40 on %s', toolchain_root, ) toolchain_root = '' else: compiler_path = '' logger.warning( - 'Found invalid installion for RTools35 on %s', + 'Found invalid installion for RTools40 on %s', toolchain_root, ) toolchain_root = '' - elif os.path.exists(os.path.join(toolchain_root, 'mingw64')): + + elif os.path.exists(os.path.join(toolchain_root, 'mingw_64')): compiler_path = os.path.join( toolchain_root, - 'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32', + 'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32', 'bin', ) if os.path.exists(compiler_path): - tool_path = os.path.join(toolchain_root, 'usr', 'bin') + tool_path = os.path.join(toolchain_root, 'bin') if not os.path.exists(tool_path): tool_path = '' compiler_path = '' logger.warning( - 'Found invalid installion for RTools40 on %s', + 'Found invalid installion for RTools35 on %s', toolchain_root, ) toolchain_root = '' else: compiler_path = '' logger.warning( - 'Found invalid installion for RTools40 on %s', + 'Found invalid installion for RTools35 on %s', toolchain_root, ) toolchain_root = '' else: rtools_dir = os.path.expanduser( - os.path.join('~', '.cmdstanpy', 'RTools') + os.path.join('~', '.cmdstanpy', 'RTools40') ) if not os.path.exists(rtools_dir): - raise ValueError( - 'no RTools installation found, ' - 'run command line script "install_cxx_toolchain"' + rtools_dir = os.path.expanduser( + os.path.join('~', '.cmdstanpy', 'RTools35') ) + if not os.path.exists(rtools_dir): + rtools_dir = os.path.expanduser( + os.path.join('~', '.cmdstanpy', 'RTools') + ) + if not os.path.exists(rtools_dir): + raise ValueError( + 'no RTools installation found, ' + 'run command line script "install_cxx_toolchain"' + ) + else: + rtools_dir = os.path.expanduser(os.path.join('~', '.cmdstanpy')) + else: + rtools_dir = os.path.expanduser(os.path.join('~', '.cmdstanpy')) compiler_path = '' tool_path = '' - if version not in ('4', '40', '4.0') and os.path.exists( - os.path.join(rtools_dir, 'RTools35') + if version not in ('35', '3.5', '3') and os.path.exists( + os.path.join(rtools_dir, 'RTools40') ): - toolchain_root = os.path.join(rtools_dir, 'RTools35') + toolchain_root = os.path.join(rtools_dir, 'RTools40') compiler_path = os.path.join( toolchain_root, - 'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32', + 'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32', 'bin', ) if os.path.exists(compiler_path): - tool_path = os.path.join(toolchain_root, 'bin') + tool_path = os.path.join(toolchain_root, 'usr', 'bin') if not os.path.exists(tool_path): tool_path = '' compiler_path = '' logger.warning( - 'Found invalid installion for RTools35 on %s', + 'Found invalid installation for RTools40 on %s', toolchain_root, ) toolchain_root = '' else: compiler_path = '' logger.warning( - 'Found invalid installion for RTools35 on %s', + 'Found invalid installation for RTools40 on %s', toolchain_root, ) toolchain_root = '' if ( not toolchain_root or version in ('4', '40', '4.0') - ) and os.path.exists(os.path.join(rtools_dir, 'RTools40')): - toolchain_root = os.path.join(rtools_dir, 'RTools40') + ) and os.path.exists(os.path.join(rtools_dir, 'RTools35')): + toolchain_root = os.path.join(rtools_dir, 'RTools35') compiler_path = os.path.join( toolchain_root, - 'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32', + 'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32', 'bin', ) if os.path.exists(compiler_path): - tool_path = os.path.join(toolchain_root, 'usr', 'bin') + tool_path = os.path.join(toolchain_root, 'bin') if not os.path.exists(tool_path): tool_path = '' compiler_path = '' logger.warning( - 'Found invalid installion for RTools40 on %s', + 'Found invalid installation for RTools35 on %s', toolchain_root, ) toolchain_root = '' else: compiler_path = '' logger.warning( - 'Found invalid installion for RTools40 on %s', + 'Found invalid installation for RTools35 on %s', toolchain_root, ) toolchain_root = '' @@ -328,7 +341,7 @@ def cxx_toolchain_path(version: str = None) -> Tuple[str]: 'no C++ toolchain installation found, ' 'run command line script "install_cxx_toolchain"' ) - logger.info('Adds C++ toolchain to $PATH: %s', toolchain_root) + logger.info('Add C++ toolchain to $PATH: %s', toolchain_root) os.environ['PATH'] = ';'.join( list( OrderedDict.fromkeys( diff --git a/test/test_model.py b/test/test_model.py index fc563fbb..914aaba1 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -31,6 +31,7 @@ BERN_STAN = os.path.join(DATAFILES_PATH, 'bernoulli.stan') BERN_EXE = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION) +BERN_BASENAME = 'bernoulli' class CmdStanModelTest(unittest.TestCase): @@ -50,22 +51,24 @@ def show_cmdstan_version(self): self.assertTrue(True) def test_model_good(self): - # compile on instantiation + # compile on instantiation, override model name model = CmdStanModel(model_name='bern', stan_file=BERN_STAN) self.assertEqual(BERN_STAN, model.stan_file) self.assertTrue(model.exe_file.endswith(BERN_EXE.replace('\\', '/'))) self.assertEqual('bern', model.name) + # default model name + model = CmdStanModel(stan_file=BERN_STAN) + self.assertEqual(BERN_BASENAME, model.name) + # instantiate with existing exe model = CmdStanModel(stan_file=BERN_STAN, exe_file=BERN_EXE) self.assertEqual(BERN_STAN, model.stan_file) self.assertTrue(model.exe_file.endswith(BERN_EXE)) - self.assertEqual('bernoulli', model.name) # instantiate with existing exe only - no model model2 = CmdStanModel(exe_file=BERN_EXE) self.assertEqual(BERN_EXE, model2.exe_file) - self.assertEqual('bernoulli', model2.name) with self.assertRaises(RuntimeError): model2.code() with self.assertRaises(RuntimeError): @@ -82,6 +85,10 @@ def test_model_bad(self): CmdStanModel(stan_file=None, exe_file=None) with self.assertRaises(ValueError): CmdStanModel(model_name='bad') + with self.assertRaises(ValueError): + CmdStanModel(model_name='', stan_file=BERN_STAN) + with self.assertRaises(ValueError): + CmdStanModel(model_name=' ', stan_file=BERN_STAN) def test_stanc_options(self): opts = {