diff --git a/pympipool/__init__.py b/pympipool/__init__.py index e344cbe0..b0dc5e64 100644 --- a/pympipool/__init__.py +++ b/pympipool/__init__.py @@ -1,11 +1,15 @@ from ._version import get_versions from pympipool.mpi.executor import PyMPIExecutor -from pympipool.slurm.executor import PySlurmExecutor try: # The PyFluxExecutor requires flux-core to be installed. from pympipool.flux.executor import PyFluxExecutor except ImportError: pass +try: # The PySlurmExecutor requires the srun command to be available. + from pympipool.slurm.executor import PySlurmExecutor +except ImportError: + pass + __version__ = get_versions()["version"] del get_versions diff --git a/pympipool/shared/interface.py b/pympipool/shared/interface.py index e64493c3..0fee6905 100644 --- a/pympipool/shared/interface.py +++ b/pympipool/shared/interface.py @@ -2,6 +2,10 @@ import subprocess +MPI_COMMAND = "mpiexec" +SLURM_COMMAND = "srun" + + class BaseInterface(ABC): def __init__(self, cwd, cores=1, oversubscribe=False): self._cwd = cwd @@ -94,7 +98,7 @@ def generate_command(self, command_lst): def generate_mpiexec_command(cores, oversubscribe=False): - command_prepend_lst = ["mpiexec", "-n", str(cores)] + command_prepend_lst = [MPI_COMMAND, "-n", str(cores)] if oversubscribe: command_prepend_lst += ["--oversubscribe"] return command_prepend_lst @@ -103,7 +107,7 @@ def generate_mpiexec_command(cores, oversubscribe=False): def generate_slurm_command( cores, cwd, threads_per_core=1, gpus_per_core=0, oversubscribe=False ): - command_prepend_lst = ["srun", "-n", str(cores)] + command_prepend_lst = [SLURM_COMMAND, "-n", str(cores)] if cwd is not None: command_prepend_lst += ["-D", cwd] if threads_per_core > 1: diff --git a/pympipool/slurm/executor.py b/pympipool/slurm/executor.py index 2adf0264..33688506 100644 --- a/pympipool/slurm/executor.py +++ b/pympipool/slurm/executor.py @@ -1,13 +1,21 @@ +import shutil +import subprocess + + from pympipool.shared.executorbase import ( cloudpickle_register, execute_parallel_tasks, ExecutorBase, executor_broker, ) -from pympipool.shared.interface import SrunInterface +from pympipool.shared.interface import SrunInterface, SLURM_COMMAND from pympipool.shared.thread import RaisingThread +if shutil.which(SLURM_COMMAND) is None: + raise ImportError("SLURM command " + SLURM_COMMAND + " not found.") + + class PySlurmExecutor(ExecutorBase): """ Args: