From 9b2884ad76f90f647831580680d8590808bc0120 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Wed, 29 May 2024 11:36:12 +0200 Subject: [PATCH 1/2] Add validate backend fucntion --- pympipool/scheduler/__init__.py | 8 ++++---- pympipool/shared/inputcheck.py | 10 +++++++++- tests/test_shared_input_check.py | 4 ++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pympipool/scheduler/__init__.py b/pympipool/scheduler/__init__.py index 3138bc2f..23a1d720 100644 --- a/pympipool/scheduler/__init__.py +++ b/pympipool/scheduler/__init__.py @@ -12,8 +12,8 @@ check_threads_per_core, check_oversubscribe, check_executor, - check_backend, check_init_function, + validate_backend, validate_number_of_cores, ) from pympipool.scheduler.slurm import ( @@ -87,8 +87,8 @@ def create_executor( """ max_cores = validate_number_of_cores(max_cores=max_cores, max_workers=max_workers) check_init_function(block_allocation=block_allocation, init_function=init_function) - check_backend(backend=backend) - if backend == "flux" or (backend == "auto" and flux_installed): + backend = validate_backend(backend=backend, flux_installed=flux_installed, slurm_installed=slurm_installed) + if backend == "flux": check_oversubscribe(oversubscribe=oversubscribe) check_command_line_argument_lst( command_line_argument_lst=command_line_argument_lst @@ -114,7 +114,7 @@ def create_executor( executor=executor, hostname_localhost=hostname_localhost, ) - elif backend == "slurm" or (backend == "auto" and slurm_installed): + elif backend == "slurm": check_executor(executor=executor) if block_allocation: return PySlurmExecutor( diff --git a/pympipool/shared/inputcheck.py b/pympipool/shared/inputcheck.py index 11595dfe..036b8c20 100644 --- a/pympipool/shared/inputcheck.py +++ b/pympipool/shared/inputcheck.py @@ -67,7 +67,9 @@ def check_refresh_rate(refresh_rate: float): ) -def check_backend(backend: str): +def validate_backend( + backend: str, flux_installed: bool = False, slurm_installed: bool = False +) -> str: if backend not in ["auto", "mpi", "slurm", "flux"]: raise ValueError( 'The currently implemented backends are ["flux", "mpi", "slurm"]. ' @@ -75,6 +77,12 @@ def check_backend(backend: str): + backend + " is not a valid choice." ) + elif backend == "flux" or (backend == "auto" and flux_installed): + return "flux" + elif backend == "slurm" or (backend == "auto" and slurm_installed): + return "slurm" + else: + return "mpi" def check_init_function(block_allocation: bool, init_function: callable): diff --git a/tests/test_shared_input_check.py b/tests/test_shared_input_check.py index bf1076f2..65f8c2af 100644 --- a/tests/test_shared_input_check.py +++ b/tests/test_shared_input_check.py @@ -6,11 +6,11 @@ check_threads_per_core, check_oversubscribe, check_executor, - check_backend, check_init_function, check_refresh_rate, check_resource_dict, check_resource_dict_is_empty, + validate_backend, ) @@ -37,7 +37,7 @@ def test_check_executor(self): def test_check_backend(self): with self.assertRaises(ValueError): - check_backend(backend="test") + validate_backend(backend="test", slurm_installed=False, flux_installed=False) def test_check_init_function(self): with self.assertRaises(ValueError): From 5c05f5f31aae4727f1aae996873c3a58a884c89f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 May 2024 09:36:29 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pympipool/scheduler/__init__.py | 4 +++- tests/test_shared_input_check.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pympipool/scheduler/__init__.py b/pympipool/scheduler/__init__.py index 23a1d720..118e6b46 100644 --- a/pympipool/scheduler/__init__.py +++ b/pympipool/scheduler/__init__.py @@ -87,7 +87,9 @@ def create_executor( """ max_cores = validate_number_of_cores(max_cores=max_cores, max_workers=max_workers) check_init_function(block_allocation=block_allocation, init_function=init_function) - backend = validate_backend(backend=backend, flux_installed=flux_installed, slurm_installed=slurm_installed) + backend = validate_backend( + backend=backend, flux_installed=flux_installed, slurm_installed=slurm_installed + ) if backend == "flux": check_oversubscribe(oversubscribe=oversubscribe) check_command_line_argument_lst( diff --git a/tests/test_shared_input_check.py b/tests/test_shared_input_check.py index 65f8c2af..ad322193 100644 --- a/tests/test_shared_input_check.py +++ b/tests/test_shared_input_check.py @@ -37,7 +37,9 @@ def test_check_executor(self): def test_check_backend(self): with self.assertRaises(ValueError): - validate_backend(backend="test", slurm_installed=False, flux_installed=False) + validate_backend( + backend="test", slurm_installed=False, flux_installed=False + ) def test_check_init_function(self): with self.assertRaises(ValueError):