diff --git a/pympipool/share/communication.py b/pympipool/share/communication.py index 51ca4fe0..ac29e4f0 100644 --- a/pympipool/share/communication.py +++ b/pympipool/share/communication.py @@ -5,10 +5,12 @@ class SocketInterface(object): - def __init__(self): + def __init__(self, queue_adapter=None, queue_adapter_kwargs=None): self._context = zmq.Context() self._socket = self._context.socket(zmq.PAIR) self._process = None + self._queue_adapter = queue_adapter + self._queue_adapter_kwargs = queue_adapter_kwargs def send_dict(self, input_dict): self._socket.send(cloudpickle.dumps(input_dict)) @@ -28,14 +30,22 @@ def send_and_receive_dict(self, input_dict): def bind_to_random_port(self): return self._socket.bind_to_random_port("tcp://*") - def bootup(self, command_lst, cwd=None): - self._process = subprocess.Popen( - args=command_lst, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.PIPE, - cwd=cwd, - ) + def bootup(self, command_lst, cwd=None, cores=None): + if self._queue_adapter is not None: + self._queue_adapter.submit_job( + working_directory=cwd, + cores=cores, + command=" ".join(command_lst), + **self._queue_adapter_kwargs + ) + else: + self._process = subprocess.Popen( + args=command_lst, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + cwd=cwd, + ) def shutdown(self, wait=True): result = None @@ -44,6 +54,10 @@ def shutdown(self, wait=True): input_dict={"shutdown": True, "wait": wait} ) self._process_close(wait=wait) + elif self._queue_adapter is not None and self._socket is not None: + result = self.send_and_receive_dict( + input_dict={"shutdown": True, "wait": wait} + ) if self._socket is not None: self._socket.close() if self._context is not None: diff --git a/pympipool/share/executor.py b/pympipool/share/executor.py index 8851c67f..6b8c6aaa 100644 --- a/pympipool/share/executor.py +++ b/pympipool/share/executor.py @@ -96,6 +96,8 @@ def __init__( enable_flux_backend=False, init_function=None, cwd=None, + queue_adapter=None, + queue_adapter_kwargs=None, ): super().__init__() self._process = Thread( @@ -106,6 +108,8 @@ def __init__( "oversubscribe": oversubscribe, "enable_flux_backend": enable_flux_backend, "cwd": cwd, + "queue_adapter": queue_adapter, + "queue_adapter_kwargs": queue_adapter_kwargs, }, ) self._process.start() @@ -123,6 +127,8 @@ def __init__( enable_flux_backend=False, cwd=None, sleep_interval=0.1, + queue_adapter=None, + queue_adapter_kwargs=None, ): super().__init__() self._process = Thread( @@ -134,6 +140,8 @@ def __init__( "enable_flux_backend": enable_flux_backend, "cwd": cwd, "sleep_interval": sleep_interval, + "queue_adapter": queue_adapter, + "queue_adapter_kwargs": queue_adapter_kwargs, }, ) self._process.start() diff --git a/pympipool/share/pool.py b/pympipool/share/pool.py index bbf8d610..dea4bf92 100644 --- a/pympipool/share/pool.py +++ b/pympipool/share/pool.py @@ -3,9 +3,11 @@ class PoolBase(object): - def __init__(self): + def __init__(self, queue_adapter=None, queue_adapter_kwargs=None): self._future_dict = {} - self._interface = SocketInterface() + self._interface = SocketInterface( + queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs + ) cloudpickle_register(ind=3) def __enter__(self): @@ -51,8 +53,12 @@ def __init__( oversubscribe=False, enable_flux_backend=False, cwd=None, + queue_adapter=None, + queue_adapter_kwargs=None, ): - super().__init__() + super().__init__( + queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs + ) self._interface.bootup( command_lst=get_parallel_subprocess_command( port_selected=self._interface.bind_to_random_port(), @@ -61,8 +67,10 @@ def __init__( oversubscribe=oversubscribe, enable_flux_backend=enable_flux_backend, enable_mpi4py_backend=True, + enable_multi_host=queue_adapter is not None, ), cwd=cwd, + cores=max_workers, ) def map(self, func, iterable, chunksize=None): @@ -145,8 +153,12 @@ def __init__( ranks_per_task=1, oversubscribe=False, cwd=None, + queue_adapter=None, + queue_adapter_kwargs=None, ): - super().__init__() + super().__init__( + queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs + ) self._interface.bootup( command_lst=get_parallel_subprocess_command( port_selected=self._interface.bind_to_random_port(), @@ -155,8 +167,10 @@ def __init__( oversubscribe=oversubscribe, enable_flux_backend=False, enable_mpi4py_backend=True, + enable_multi_host=queue_adapter is not None, ), cwd=cwd, + cores=max_ranks, ) def map(self, func, iterable, chunksize=None): diff --git a/pympipool/share/serial.py b/pympipool/share/serial.py index a17d91a5..f7d56e5f 100644 --- a/pympipool/share/serial.py +++ b/pympipool/share/serial.py @@ -18,6 +18,7 @@ def command_line_options( oversubscribe=False, enable_flux_backend=False, enable_mpi4py_backend=True, + enable_multi_host=False, ): if enable_flux_backend: command_lst = ["flux", "run"] @@ -34,7 +35,7 @@ def command_line_options( else: command_lst += ["-n", str(cores), "python"] command_lst += [path] - if enable_flux_backend: + if enable_flux_backend or enable_multi_host: command_lst += [ "--host", hostname, @@ -60,6 +61,7 @@ def get_parallel_subprocess_command( oversubscribe=False, enable_flux_backend=False, enable_mpi4py_backend=True, + enable_multi_host=False, ): if enable_mpi4py_backend: executable = "mpipool.py" @@ -74,6 +76,7 @@ def get_parallel_subprocess_command( oversubscribe=oversubscribe, enable_flux_backend=enable_flux_backend, enable_mpi4py_backend=enable_mpi4py_backend, + enable_multi_host=enable_multi_host, ) return command_lst @@ -116,9 +119,17 @@ def execute_serial_tasks_loop(interface, future_queue, future_dict, sleep_interv def execute_parallel_tasks( - future_queue, cores, oversubscribe=False, enable_flux_backend=False, cwd=None + future_queue, + cores, + oversubscribe=False, + enable_flux_backend=False, + cwd=None, + queue_adapter=None, + queue_adapter_kwargs=None, ): - interface = SocketInterface() + interface = SocketInterface( + queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs + ) interface.bootup( command_lst=get_parallel_subprocess_command( port_selected=interface.bind_to_random_port(), @@ -127,8 +138,10 @@ def execute_parallel_tasks( oversubscribe=oversubscribe, enable_flux_backend=enable_flux_backend, enable_mpi4py_backend=False, + enable_multi_host=queue_adapter is not None, ), cwd=cwd, + cores=cores, ) execute_parallel_tasks_loop(interface=interface, future_queue=future_queue) @@ -140,9 +153,13 @@ def execute_serial_tasks( enable_flux_backend=False, cwd=None, sleep_interval=0.1, + queue_adapter=None, + queue_adapter_kwargs=None, ): future_dict = {} - interface = SocketInterface() + interface = SocketInterface( + queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs + ) interface.bootup( command_lst=get_parallel_subprocess_command( port_selected=interface.bind_to_random_port(), @@ -151,8 +168,10 @@ def execute_serial_tasks( oversubscribe=oversubscribe, enable_flux_backend=enable_flux_backend, enable_mpi4py_backend=True, + enable_multi_host=queue_adapter is not None, ), cwd=cwd, + cores=cores, ) execute_serial_tasks_loop( interface=interface, diff --git a/tests/test_interface.py b/tests/test_interface.py index 621ddd8b..6ae8aa97 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -12,7 +12,7 @@ class TestInterface(unittest.TestCase): def test_interface(self): cloudpickle_register(ind=1) task_dict = {"fn": calc, 'args': (), "kwargs": {"i": 2}} - interface = SocketInterface() + interface = SocketInterface(queue_adapter=None, queue_adapter_kwargs=None) interface.bootup( command_lst=get_parallel_subprocess_command( port_selected=interface.bind_to_random_port(),