diff --git a/pympipool/__init__.py b/pympipool/__init__.py index 6acc2e66..f6e0ee3d 100644 --- a/pympipool/__init__.py +++ b/pympipool/__init__.py @@ -4,6 +4,7 @@ from pympipool.shell.executor import SubprocessExecutor from pympipool.shell.interactive import ShellExecutor from pympipool.shared.dependencies import ExecutorWithDependencies +from pympipool.shared.inputcheck import check_refresh_rate as _check_refresh_rate __version__ = get_versions()["version"] @@ -149,10 +150,7 @@ def __new__( refresh_rate=refresh_rate, ) else: - if refresh_rate != 0.01: - raise ValueError( - "The sleep_interval parameter is only used when disable_dependencies=False." - ) + _check_refresh_rate(refresh_rate=refresh_rate) return create_executor( max_cores=max_cores, cores_per_worker=cores_per_worker, diff --git a/pympipool/scheduler/__init__.py b/pympipool/scheduler/__init__.py index 10d6b798..ff8d2f00 100644 --- a/pympipool/scheduler/__init__.py +++ b/pympipool/scheduler/__init__.py @@ -12,6 +12,8 @@ check_threads_per_core, check_oversubscribe, check_executor, + check_backend, + check_init_function, ) from pympipool.scheduler.slurm import ( PySlurmExecutor, @@ -79,16 +81,9 @@ def create_executor( command_line_argument_lst (list): Additional command line arguments for the srun call (SLURM only) """ - if not block_allocation and init_function is not None: - raise ValueError("") - if backend not in ["auto", "mpi", "slurm", "flux"]: - raise ValueError( - 'The currently implemented backends are ["flux", "mpi", "slurm"]. ' - 'Alternatively, you can select "auto", the default option, to automatically determine the backend. But ' - + backend - + " is not a valid choice." - ) - elif backend == "flux" or (backend == "auto" and flux_installed): + check_init_function(block_allocation=block_allocation, init_function=init_function) + check_backend(backend=backend) + if backend == "flux" or (backend == "auto" and flux_installed): check_oversubscribe(oversubscribe=oversubscribe) check_command_line_argument_lst( command_line_argument_lst=command_line_argument_lst diff --git a/pympipool/shared/inputcheck.py b/pympipool/shared/inputcheck.py index e4a53870..69e6f87e 100644 --- a/pympipool/shared/inputcheck.py +++ b/pympipool/shared/inputcheck.py @@ -56,3 +56,25 @@ def check_resource_dict_is_empty(resource_dict): raise ValueError( "When block_allocation is enabled, the resource requirements have to be defined on the executor level." ) + + +def check_refresh_rate(refresh_rate): + if refresh_rate != 0.01: + raise ValueError( + "The sleep_interval parameter is only used when disable_dependencies=False." + ) + + +def check_backend(backend): + if backend not in ["auto", "mpi", "slurm", "flux"]: + raise ValueError( + 'The currently implemented backends are ["flux", "mpi", "slurm"]. ' + 'Alternatively, you can select "auto", the default option, to automatically determine the backend. But ' + + backend + + " is not a valid choice." + ) + + +def check_init_function(block_allocation, init_function): + if not block_allocation and init_function is not None: + raise ValueError("") diff --git a/tests/test_shared_input_check.py b/tests/test_shared_input_check.py new file mode 100644 index 00000000..bf1076f2 --- /dev/null +++ b/tests/test_shared_input_check.py @@ -0,0 +1,59 @@ +import unittest + +from pympipool.shared.inputcheck import ( + check_command_line_argument_lst, + check_gpus_per_worker, + check_threads_per_core, + check_oversubscribe, + check_executor, + check_backend, + check_init_function, + check_refresh_rate, + check_resource_dict, + check_resource_dict_is_empty, +) + + +class TestInputCheck(unittest.TestCase): + def test_check_command_line_argument_lst(self): + with self.assertRaises(ValueError): + check_command_line_argument_lst(command_line_argument_lst=["a"]) + + def test_check_gpus_per_worker(self): + with self.assertRaises(TypeError): + check_gpus_per_worker(gpus_per_worker=1) + + def test_check_threads_per_core(self): + with self.assertRaises(TypeError): + check_threads_per_core(threads_per_core=2) + + def test_check_oversubscribe(self): + with self.assertRaises(ValueError): + check_oversubscribe(oversubscribe=True) + + def test_check_executor(self): + with self.assertRaises(ValueError): + check_executor(executor=1) + + def test_check_backend(self): + with self.assertRaises(ValueError): + check_backend(backend="test") + + def test_check_init_function(self): + with self.assertRaises(ValueError): + check_init_function(init_function=1, block_allocation=False) + + def test_check_refresh_rate(self): + with self.assertRaises(ValueError): + check_refresh_rate(refresh_rate=1) + + def test_check_resource_dict(self): + def simple_function(resource_dict): + return resource_dict + + with self.assertRaises(ValueError): + check_resource_dict(function=simple_function) + + def test_check_resource_dict_is_empty(self): + with self.assertRaises(ValueError): + check_resource_dict_is_empty(resource_dict={"a": 1})