Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pympipool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pympipool.shell.executor import SubprocessExecutor
from pympipool.shell.interactive import ShellExecutor
from pympipool.shared.dependencies import ExecutorWithDependencies
from pympipool.shared.inputcheck import check_refresh_rate as _check_refresh_rate


__version__ = get_versions()["version"]
Expand Down Expand Up @@ -149,10 +150,7 @@ def __new__(
refresh_rate=refresh_rate,
)
else:
if refresh_rate != 0.01:
raise ValueError(
"The sleep_interval parameter is only used when disable_dependencies=False."
)
_check_refresh_rate(refresh_rate=refresh_rate)
return create_executor(
max_cores=max_cores,
cores_per_worker=cores_per_worker,
Expand Down
15 changes: 5 additions & 10 deletions pympipool/scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
check_threads_per_core,
check_oversubscribe,
check_executor,
check_backend,
check_init_function,
)
from pympipool.scheduler.slurm import (
PySlurmExecutor,
Expand Down Expand Up @@ -79,16 +81,9 @@ def create_executor(
command_line_argument_lst (list): Additional command line arguments for the srun call (SLURM only)

"""
if not block_allocation and init_function is not None:
raise ValueError("")
if backend not in ["auto", "mpi", "slurm", "flux"]:
raise ValueError(
'The currently implemented backends are ["flux", "mpi", "slurm"]. '
'Alternatively, you can select "auto", the default option, to automatically determine the backend. But '
+ backend
+ " is not a valid choice."
)
elif backend == "flux" or (backend == "auto" and flux_installed):
check_init_function(block_allocation=block_allocation, init_function=init_function)
check_backend(backend=backend)
if backend == "flux" or (backend == "auto" and flux_installed):
check_oversubscribe(oversubscribe=oversubscribe)
check_command_line_argument_lst(
command_line_argument_lst=command_line_argument_lst
Expand Down
22 changes: 22 additions & 0 deletions pympipool/shared/inputcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,25 @@ def check_resource_dict_is_empty(resource_dict):
raise ValueError(
"When block_allocation is enabled, the resource requirements have to be defined on the executor level."
)


def check_refresh_rate(refresh_rate):
if refresh_rate != 0.01:
raise ValueError(
"The sleep_interval parameter is only used when disable_dependencies=False."
)


def check_backend(backend):
if backend not in ["auto", "mpi", "slurm", "flux"]:
raise ValueError(
'The currently implemented backends are ["flux", "mpi", "slurm"]. '
'Alternatively, you can select "auto", the default option, to automatically determine the backend. But '
+ backend
+ " is not a valid choice."
)


def check_init_function(block_allocation, init_function):
if not block_allocation and init_function is not None:
raise ValueError("")
59 changes: 59 additions & 0 deletions tests/test_shared_input_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import unittest

from pympipool.shared.inputcheck import (
check_command_line_argument_lst,
check_gpus_per_worker,
check_threads_per_core,
check_oversubscribe,
check_executor,
check_backend,
check_init_function,
check_refresh_rate,
check_resource_dict,
check_resource_dict_is_empty,
)


class TestInputCheck(unittest.TestCase):
def test_check_command_line_argument_lst(self):
with self.assertRaises(ValueError):
check_command_line_argument_lst(command_line_argument_lst=["a"])

def test_check_gpus_per_worker(self):
with self.assertRaises(TypeError):
check_gpus_per_worker(gpus_per_worker=1)

def test_check_threads_per_core(self):
with self.assertRaises(TypeError):
check_threads_per_core(threads_per_core=2)

def test_check_oversubscribe(self):
with self.assertRaises(ValueError):
check_oversubscribe(oversubscribe=True)

def test_check_executor(self):
with self.assertRaises(ValueError):
check_executor(executor=1)

def test_check_backend(self):
with self.assertRaises(ValueError):
check_backend(backend="test")

def test_check_init_function(self):
with self.assertRaises(ValueError):
check_init_function(init_function=1, block_allocation=False)

def test_check_refresh_rate(self):
with self.assertRaises(ValueError):
check_refresh_rate(refresh_rate=1)

def test_check_resource_dict(self):
def simple_function(resource_dict):
return resource_dict

with self.assertRaises(ValueError):
check_resource_dict(function=simple_function)

def test_check_resource_dict_is_empty(self):
with self.assertRaises(ValueError):
check_resource_dict_is_empty(resource_dict={"a": 1})