diff --git a/pympipool/__init__.py b/pympipool/__init__.py index fbddb235..b560f164 100644 --- a/pympipool/__init__.py +++ b/pympipool/__init__.py @@ -1,6 +1,12 @@ from pympipool.share.pool import Pool, MPISpawnPool from pympipool.share.executor import Executor, PoolExecutor -from pympipool.share.communication import SocketInterface, connect_to_socket_interface +from pympipool.share.communication import ( + SocketInterface, + connect_to_socket_interface, + send_result, + close_connection, + receive_instruction, +) from pympipool.share.serial import cancel_items_in_queue from ._version import get_versions diff --git a/pympipool/executor/mpiexec.py b/pympipool/executor/mpiexec.py index 45c0767e..fb567833 100644 --- a/pympipool/executor/mpiexec.py +++ b/pympipool/executor/mpiexec.py @@ -3,7 +3,12 @@ import cloudpickle -from pympipool.share.communication import connect_to_socket_interface +from pympipool.share.communication import ( + connect_to_socket_interface, + send_result, + close_connection, + receive_instruction, +) from pympipool.share.parallel import call_funct, parse_arguments @@ -31,7 +36,7 @@ def main(): while True: # Read from socket if mpi_rank_zero: - input_dict = cloudpickle.loads(socket.recv()) + input_dict = receive_instruction(socket=socket) else: input_dict = None input_dict = MPI.COMM_WORLD.bcast(input_dict, root=0) @@ -39,9 +44,8 @@ def main(): # Parse input if "shutdown" in input_dict.keys() and input_dict["shutdown"]: if mpi_rank_zero: - socket.send(cloudpickle.dumps({"result": True})) - socket.close() - context.term() + send_result(socket=socket, result_dict={"result": True}) + close_connection(socket=socket, context=context) break elif ( "fn" in input_dict.keys() @@ -58,15 +62,14 @@ def main(): output_reply = output except Exception as error: if mpi_rank_zero: - socket.send( - cloudpickle.dumps( - {"error": error, "error_type": str(type(error))} - ) + send_result( + socket=socket, + result_dict={"error": error, "error_type": str(type(error))}, ) else: # Send output if mpi_rank_zero: - socket.send(cloudpickle.dumps({"result": output_reply})) + send_result(socket=socket, result_dict={"result": output_reply}) elif ( "init" in input_dict.keys() and input_dict["init"] diff --git a/pympipool/executor/mpipool.py b/pympipool/executor/mpipool.py index c855e062..87e0e8cf 100644 --- a/pympipool/executor/mpipool.py +++ b/pympipool/executor/mpipool.py @@ -3,7 +3,12 @@ import cloudpickle -from pympipool.share.communication import connect_to_socket_interface +from pympipool.share.communication import ( + connect_to_socket_interface, + send_result, + close_connection, + receive_instruction, +) from pympipool.share.parallel import ( parse_arguments, parse_socket_communication, @@ -30,20 +35,21 @@ def main(): while True: output = parse_socket_communication( executor=executor, - input_dict=cloudpickle.loads(socket.recv()), + input_dict=receive_instruction(socket=socket), future_dict=future_dict, cores_per_task=int(argument_dict["cores_per_task"]), ) if "exit" in output.keys() and output["exit"]: if "result" in output.keys(): - socket.send(cloudpickle.dumps({"result": output["result"]})) + send_result( + socket=socket, result_dict={"result": output["result"]} + ) else: - socket.send(cloudpickle.dumps({"result": True})) - socket.close() - context.term() + send_result(socket=socket, result_dict={"result": True}) + close_connection(socket=socket, context=context) break elif isinstance(output, dict): - socket.send(cloudpickle.dumps(output)) + send_result(socket=socket, result_dict=output) if __name__ == "__main__": diff --git a/pympipool/share/communication.py b/pympipool/share/communication.py index fa3c84f1..51ca4fe0 100644 --- a/pympipool/share/communication.py +++ b/pympipool/share/communication.py @@ -70,3 +70,16 @@ def connect_to_socket_interface(host, port): socket = context.socket(zmq.PAIR) socket.connect("tcp://" + host + ":" + port) return context, socket + + +def send_result(socket, result_dict): + socket.send(cloudpickle.dumps(result_dict)) + + +def receive_instruction(socket): + return cloudpickle.loads(socket.recv()) + + +def close_connection(socket, context): + socket.close() + context.term() diff --git a/pympipool/share/parallel.py b/pympipool/share/parallel.py index aa797985..4c06d29f 100644 --- a/pympipool/share/parallel.py +++ b/pympipool/share/parallel.py @@ -1,5 +1,4 @@ import inspect -import zmq from tqdm import tqdm diff --git a/tests/test_zmq.py b/tests/test_zmq.py index d63c338d..3d1e9302 100644 --- a/tests/test_zmq.py +++ b/tests/test_zmq.py @@ -1,7 +1,11 @@ import unittest import zmq -import cloudpickle -from pympipool.share.communication import connect_to_socket_interface +from pympipool.share.communication import ( + connect_to_socket_interface, + close_connection, + send_result, + receive_instruction +) class TestZMQ(unittest.TestCase): @@ -13,9 +17,7 @@ def test_initialize_zmq(self): socket_server = context_server.socket(zmq.PAIR) port = str(socket_server.bind_to_random_port("tcp://*")) context_client, socket_client = connect_to_socket_interface(host=host, port=port) - socket_server.send(cloudpickle.dumps(message)) - self.assertEqual(cloudpickle.loads(socket_client.recv()), message) - socket_client.close() - context_client.term() - socket_server.close() - context_server.term() + send_result(socket=socket_server, result_dict={"message": message}) + self.assertEqual(receive_instruction(socket=socket_client), {"message": message}) + close_connection(socket=socket_client, context=context_client) + close_connection(socket=socket_server, context=context_server)