diff --git a/executorlib/standalone/interactive/communication.py b/executorlib/standalone/interactive/communication.py
index 4a198882..4fa224ac 100644
--- a/executorlib/standalone/interactive/communication.py
+++ b/executorlib/standalone/interactive/communication.py
@@ -1,12 +1,16 @@
 import logging
 import sys
 from socket import gethostname
-from typing import Optional
+from typing import Any, Optional
 
 import cloudpickle
 import zmq
 
 
+class ExecutorlibSocketError(RuntimeError):
+    """Raised when the spawned process stops before a response is received."""
+
+
 class SocketInterface:
     """
     The SocketInterface is an abstraction layer on top of the zero message queue.
@@ -14,22 +18,29 @@ class SocketInterface:
     Args:
         spawner (executorlib.shared.spawner.BaseSpawner): Interface for starting the parallel process
         log_obj_size (boolean): Enable debug mode which reports the size of the communicated objects.
+        time_out_ms (int): Time out for waiting for a message on socket in milliseconds.
     """
 
-    def __init__(self, spawner=None, log_obj_size=False):
+    def __init__(
+        self, spawner=None, log_obj_size: bool = False, time_out_ms: int = 1000
+    ):
         """
         Initialize the SocketInterface.
 
         Args:
             spawner (executorlib.shared.spawner.BaseSpawner): Interface for starting the parallel process
+            log_obj_size (boolean): Enable debug mode which reports the size of the communicated objects.
+            time_out_ms (int): Time out for waiting for a message on socket in milliseconds.
         """
         self._context = zmq.Context()
         self._socket = self._context.socket(zmq.PAIR)
+        self._poller = zmq.Poller()
+        self._poller.register(self._socket, zmq.POLLIN)
         self._process = None
+        self._time_out_ms = time_out_ms
+        self._logger: Optional[logging.Logger] = None
         if log_obj_size:
             self._logger = logging.getLogger("executorlib")
-        else:
-            self._logger = None
         self._spawner = spawner
 
     def send_dict(self, input_dict: dict):
@@ -52,7 +63,23 @@ def receive_dict(self) -> dict:
         Returns:
             dict: dictionary with response received from the connected client
+
+        Raises:
+            ExecutorlibSocketError: If the spawned process stopped and no response is pending.
         """
-        data = self._socket.recv()
+        response_lst: list[tuple[Any, int]] = []
+        while not response_lst:
+            response_lst = self._poller.poll(self._time_out_ms)
+            # Only treat a stopped process as an error when no reply is pending: a
+            # worker may exit right after sending its response. Without a spawner
+            # there is nothing to check, so keep waiting for a message.
+            if (
+                not response_lst
+                and self._spawner is not None
+                and not self._spawner.poll()
+            ):
+                raise ExecutorlibSocketError(
+                    "The spawned process stopped before sending a response."
+                )
+        data = self._socket.recv(zmq.NOBLOCK)
         if self._logger is not None:
             self._logger.warning(
                 "Received dictionary of size: " + str(sys.getsizeof(data))
diff --git a/tests/test_standalone_interactive_communication.py b/tests/test_standalone_interactive_communication.py
index edb75845..ebf00ba1 100644
--- a/tests/test_standalone_interactive_communication.py
+++ b/tests/test_standalone_interactive_communication.py
@@ -12,6 +12,7 @@
     interface_send,
     interface_receive,
     SocketInterface,
+    ExecutorlibSocketError,
 )
 from executorlib.standalone.serialize import cloudpickle_register
 from executorlib.standalone.interactive.spawner import MpiExecSpawner
@@ -114,6 +115,35 @@ def test_interface_serial_with_debug(self):
         )
         interface.shutdown(wait=True)
 
+    def test_interface_serial_with_stopped_process(self):
+        cloudpickle_register(ind=1)
+        task_dict = {"fn": calc, "args": (), "kwargs": {"i": 2}}
+        interface = SocketInterface(
+            spawner=MpiExecSpawner(cwd=None, cores=1, openmpi_oversubscribe=False),
+            log_obj_size=True,
+        )
+        interface.bootup(
+            command_lst=[
+                sys.executable,
+                os.path.abspath(
+                    os.path.join(
+                        __file__,
+                        "..",
+                        "..",
+                        "executorlib",
+                        "backend",
+                        "interactive_serial.py",
+                    )
+                ),
+                "--zmqport",
+                str(interface.bind_to_random_port()),
+            ]
+        )
+        interface.send_dict(input_dict=task_dict)
+        interface._spawner._process.terminate()
+        with self.assertRaises(ExecutorlibSocketError):
+            interface.receive_dict()
+
 
 class TestZMQ(unittest.TestCase):
     def test_interface_receive(self):