Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pympipool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pympipool.shared.communication import (
SocketInterface,
connect_to_socket_interface,
send_result,
close_connection,
receive_instruction,
interface_connect,
interface_bootup,
interface_send,
interface_shutdown,
interface_receive,
)
from pympipool.interfaces.taskbroker import HPCExecutor
from pympipool.interfaces.taskexecutor import Executor
Expand Down
20 changes: 10 additions & 10 deletions pympipool/backend/mpiexec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import cloudpickle

from pympipool.shared.communication import (
connect_to_socket_interface,
send_result,
close_connection,
receive_instruction,
interface_connect,
interface_send,
interface_shutdown,
interface_receive,
)
from pympipool.shared.backend import call_funct, parse_arguments

Expand All @@ -26,7 +26,7 @@ def main():

argument_dict = parse_arguments(argument_lst=sys.argv)
if mpi_rank_zero:
context, socket = connect_to_socket_interface(
context, socket = interface_connect(
host=argument_dict["host"], port=argument_dict["zmqport"]
)
else:
Expand All @@ -43,16 +43,16 @@ def main():
while True:
# Read from socket
if mpi_rank_zero:
input_dict = receive_instruction(socket=socket)
input_dict = interface_receive(socket=socket)
else:
input_dict = None
input_dict = MPI.COMM_WORLD.bcast(input_dict, root=0)

# Parse input
if "shutdown" in input_dict.keys() and input_dict["shutdown"]:
if mpi_rank_zero:
send_result(socket=socket, result_dict={"result": True})
close_connection(socket=socket, context=context)
interface_send(socket=socket, result_dict={"result": True})
interface_shutdown(socket=socket, context=context)
break
elif (
"fn" in input_dict.keys()
Expand All @@ -69,14 +69,14 @@ def main():
output_reply = output
except Exception as error:
if mpi_rank_zero:
send_result(
interface_send(
socket=socket,
result_dict={"error": error, "error_type": str(type(error))},
)
else:
# Send output
if mpi_rank_zero:
send_result(socket=socket, result_dict={"result": output_reply})
interface_send(socket=socket, result_dict={"result": output_reply})
elif (
"init" in input_dict.keys()
and input_dict["init"]
Expand Down
20 changes: 10 additions & 10 deletions pympipool/legacy/backend/mpipool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import cloudpickle

from pympipool.shared.communication import (
connect_to_socket_interface,
send_result,
close_connection,
receive_instruction,
interface_connect,
interface_send,
interface_shutdown,
interface_receive,
)
from pympipool.legacy.shared.backend import parse_socket_communication, parse_arguments

Expand Down Expand Up @@ -36,27 +36,27 @@ def main():
path=sys.path, # required for flux interface - otherwise the current path is not included in the python path
) as executor:
if executor is not None:
context, socket = connect_to_socket_interface(
context, socket = interface_connect(
host=argument_dict["host"], port=argument_dict["zmqport"]
)
while True:
output = parse_socket_communication(
executor=executor,
input_dict=receive_instruction(socket=socket),
input_dict=interface_receive(socket=socket),
future_dict=future_dict,
cores_per_task=int(argument_dict["cores_per_task"]),
)
if "exit" in output.keys() and output["exit"]:
if "result" in output.keys():
send_result(
interface_send(
socket=socket, result_dict={"result": output["result"]}
)
else:
send_result(socket=socket, result_dict={"result": True})
close_connection(socket=socket, context=context)
interface_send(socket=socket, result_dict={"result": True})
interface_shutdown(socket=socket, context=context)
break
elif isinstance(output, dict):
send_result(socket=socket, result_dict=output)
interface_send(socket=socket, result_dict=output)


if __name__ == "__main__":
Expand Down
61 changes: 25 additions & 36 deletions pympipool/legacy/interfaces/pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC

from pympipool.shared.communication import SocketInterface
from pympipool.shared.communication import interface_bootup
from pympipool.shared.taskexecutor import cloudpickle_register
from pympipool.legacy.shared.interface import get_parallel_subprocess_command
from pympipool.legacy.shared.interface import get_pool_command


class PoolBase(ABC):
Expand All @@ -11,11 +11,9 @@ class PoolBase(ABC):
alone. Rather it implements the __enter__(), __exit__() and shutdown() function shared between the derived classes.
"""

def __init__(self, queue_adapter=None, queue_adapter_kwargs=None):
def __init__(self):
self._future_dict = {}
self._interface = SocketInterface(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
self._interface = None
cloudpickle_register(ind=3)

def __enter__(self):
Expand Down Expand Up @@ -71,23 +69,17 @@ def __init__(
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
self._interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=self._interface.bind_to_random_port(),
cores=max_workers,
cores_per_task=1,
gpus_per_task=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_slurm_backend=enable_slurm_backend,
enable_mpi4py_backend=True,
enable_multi_host=queue_adapter is not None,
),
super().__init__()
self._interface = interface_bootup(
command_lst=get_pool_command(cores_total=max_workers, ranks_per_task=1)[0],
cwd=cwd,
cores=max_workers,
gpus_per_core=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_slurm_backend=enable_slurm_backend,
queue_adapter=queue_adapter,
queue_adapter_kwargs=queue_adapter_kwargs,
)

def map(self, func, iterable, chunksize=None):
Expand Down Expand Up @@ -178,23 +170,20 @@ def __init__(
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
super().__init__()
command_lst, cores = get_pool_command(
cores_total=max_ranks, ranks_per_task=ranks_per_task
)
self._interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=self._interface.bind_to_random_port(),
cores=max_ranks,
cores_per_task=ranks_per_task,
gpus_per_task=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=False,
enable_slurm_backend=False,
enable_mpi4py_backend=True,
enable_multi_host=queue_adapter is not None,
),
self._interface = interface_bootup(
command_lst=command_lst,
cwd=cwd,
cores=max_ranks,
cores=cores,
gpus_per_core=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=False,
enable_slurm_backend=False,
queue_adapter=queue_adapter,
queue_adapter_kwargs=queue_adapter_kwargs,
)

def map(self, func, iterable, chunksize=None):
Expand Down
Loading