Skip to content
32 changes: 23 additions & 9 deletions pympipool/share/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@


class SocketInterface(object):
def __init__(self):
def __init__(self, queue_adapter=None, queue_adapter_kwargs=None):
self._context = zmq.Context()
self._socket = self._context.socket(zmq.PAIR)
self._process = None
self._queue_adapter = queue_adapter
self._queue_adapter_kwargs = queue_adapter_kwargs

def send_dict(self, input_dict):
self._socket.send(cloudpickle.dumps(input_dict))
Expand All @@ -28,14 +30,22 @@ def send_and_receive_dict(self, input_dict):
def bind_to_random_port(self):
return self._socket.bind_to_random_port("tcp://*")

def bootup(self, command_lst, cwd=None):
self._process = subprocess.Popen(
args=command_lst,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.PIPE,
cwd=cwd,
)
def bootup(self, command_lst, cwd=None, cores=None):
if self._queue_adapter is not None:
self._queue_adapter.submit_job(
working_directory=cwd,
cores=cores,
command=" ".join(command_lst),
**self._queue_adapter_kwargs
)
else:
self._process = subprocess.Popen(
args=command_lst,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.PIPE,
cwd=cwd,
)

def shutdown(self, wait=True):
result = None
Expand All @@ -44,6 +54,10 @@ def shutdown(self, wait=True):
input_dict={"shutdown": True, "wait": wait}
)
self._process_close(wait=wait)
elif self._queue_adapter is not None and self._socket is not None:
result = self.send_and_receive_dict(
input_dict={"shutdown": True, "wait": wait}
)
if self._socket is not None:
self._socket.close()
if self._context is not None:
Expand Down
8 changes: 8 additions & 0 deletions pympipool/share/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def __init__(
enable_flux_backend=False,
init_function=None,
cwd=None,
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__()
self._process = Thread(
Expand All @@ -106,6 +108,8 @@ def __init__(
"oversubscribe": oversubscribe,
"enable_flux_backend": enable_flux_backend,
"cwd": cwd,
"queue_adapter": queue_adapter,
"queue_adapter_kwargs": queue_adapter_kwargs,
},
)
self._process.start()
Expand All @@ -123,6 +127,8 @@ def __init__(
enable_flux_backend=False,
cwd=None,
sleep_interval=0.1,
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__()
self._process = Thread(
Expand All @@ -134,6 +140,8 @@ def __init__(
"enable_flux_backend": enable_flux_backend,
"cwd": cwd,
"sleep_interval": sleep_interval,
"queue_adapter": queue_adapter,
"queue_adapter_kwargs": queue_adapter_kwargs,
},
)
self._process.start()
22 changes: 18 additions & 4 deletions pympipool/share/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@


class PoolBase(object):
def __init__(self):
def __init__(self, queue_adapter=None, queue_adapter_kwargs=None):
self._future_dict = {}
self._interface = SocketInterface()
self._interface = SocketInterface(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
cloudpickle_register(ind=3)

def __enter__(self):
Expand Down Expand Up @@ -51,8 +53,12 @@ def __init__(
oversubscribe=False,
enable_flux_backend=False,
cwd=None,
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__()
super().__init__(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
self._interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=self._interface.bind_to_random_port(),
Expand All @@ -61,8 +67,10 @@ def __init__(
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_mpi4py_backend=True,
enable_multi_host=queue_adapter is not None,
),
cwd=cwd,
cores=max_workers,
)

def map(self, func, iterable, chunksize=None):
Expand Down Expand Up @@ -145,8 +153,12 @@ def __init__(
ranks_per_task=1,
oversubscribe=False,
cwd=None,
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__()
super().__init__(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
self._interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=self._interface.bind_to_random_port(),
Expand All @@ -155,8 +167,10 @@ def __init__(
oversubscribe=oversubscribe,
enable_flux_backend=False,
enable_mpi4py_backend=True,
enable_multi_host=queue_adapter is not None,
),
cwd=cwd,
cores=max_ranks,
)

def map(self, func, iterable, chunksize=None):
Expand Down
27 changes: 23 additions & 4 deletions pympipool/share/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def command_line_options(
oversubscribe=False,
enable_flux_backend=False,
enable_mpi4py_backend=True,
enable_multi_host=False,
):
if enable_flux_backend:
command_lst = ["flux", "run"]
Expand All @@ -34,7 +35,7 @@ def command_line_options(
else:
command_lst += ["-n", str(cores), "python"]
command_lst += [path]
if enable_flux_backend:
if enable_flux_backend or enable_multi_host:
command_lst += [
"--host",
hostname,
Expand All @@ -60,6 +61,7 @@ def get_parallel_subprocess_command(
oversubscribe=False,
enable_flux_backend=False,
enable_mpi4py_backend=True,
enable_multi_host=False,
):
if enable_mpi4py_backend:
executable = "mpipool.py"
Expand All @@ -74,6 +76,7 @@ def get_parallel_subprocess_command(
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_mpi4py_backend=enable_mpi4py_backend,
enable_multi_host=enable_multi_host,
)
return command_lst

Expand Down Expand Up @@ -116,9 +119,17 @@ def execute_serial_tasks_loop(interface, future_queue, future_dict, sleep_interv


def execute_parallel_tasks(
future_queue, cores, oversubscribe=False, enable_flux_backend=False, cwd=None
future_queue,
cores,
oversubscribe=False,
enable_flux_backend=False,
cwd=None,
queue_adapter=None,
queue_adapter_kwargs=None,
):
interface = SocketInterface()
interface = SocketInterface(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=interface.bind_to_random_port(),
Expand All @@ -127,8 +138,10 @@ def execute_parallel_tasks(
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_mpi4py_backend=False,
enable_multi_host=queue_adapter is not None,
),
cwd=cwd,
cores=cores,
)
execute_parallel_tasks_loop(interface=interface, future_queue=future_queue)

Expand All @@ -140,9 +153,13 @@ def execute_serial_tasks(
enable_flux_backend=False,
cwd=None,
sleep_interval=0.1,
queue_adapter=None,
queue_adapter_kwargs=None,
):
future_dict = {}
interface = SocketInterface()
interface = SocketInterface(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=interface.bind_to_random_port(),
Expand All @@ -151,8 +168,10 @@ def execute_serial_tasks(
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_mpi4py_backend=True,
enable_multi_host=queue_adapter is not None,
),
cwd=cwd,
cores=cores,
)
execute_serial_tasks_loop(
interface=interface,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestInterface(unittest.TestCase):
def test_interface(self):
cloudpickle_register(ind=1)
task_dict = {"fn": calc, 'args': (), "kwargs": {"i": 2}}
interface = SocketInterface()
interface = SocketInterface(queue_adapter=None, queue_adapter_kwargs=None)
interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=interface.bind_to_random_port(),
Expand Down