Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions executorlib/standalone/interactive/communication.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,46 @@
import logging
import sys
from socket import gethostname
from typing import Optional
from typing import Any, Optional

import cloudpickle
import zmq


class ExecutorlibSocketError(RuntimeError):
    """
    Error raised by SocketInterface.receive_dict() when the spawner's poll() returns a falsy value while the
    interface is still waiting for a message on the socket.
    """


class SocketInterface:
"""
The SocketInterface is an abstraction layer on top of the zero message queue.
Args:
spawner (executorlib.shared.spawner.BaseSpawner): Interface for starting the parallel process
log_obj_size (boolean): Enable debug mode which reports the size of the communicated objects.
time_out_ms (int): Time out for waiting for a message on socket in milliseconds.
"""

def __init__(self, spawner=None, log_obj_size=False):
def __init__(
self, spawner=None, log_obj_size: bool = False, time_out_ms: int = 1000
):
"""
Initialize the SocketInterface.
Args:
spawner (executorlib.shared.spawner.BaseSpawner): Interface for starting the parallel process
log_obj_size (boolean): Enable debug mode which reports the size of the communicated objects.
time_out_ms (int): Time out for waiting for a message on socket in milliseconds.
"""
self._context = zmq.Context()
self._socket = self._context.socket(zmq.PAIR)
self._poller = zmq.Poller()
self._poller.register(self._socket, zmq.POLLIN)
self._process = None
self._time_out_ms = time_out_ms
self._logger: Optional[logging.Logger] = None
if log_obj_size:
self._logger = logging.getLogger("executorlib")
else:
self._logger = None
self._spawner = spawner

def send_dict(self, input_dict: dict):
Expand All @@ -52,7 +63,12 @@ def receive_dict(self) -> dict:
Returns:
dict: dictionary with response received from the connected client
"""
data = self._socket.recv()
response_lst: list[tuple[Any, int]] = []
while len(response_lst) == 0:
response_lst = self._poller.poll(self._time_out_ms)
if not self._spawner.poll():
raise ExecutorlibSocketError()
data = self._socket.recv(zmq.NOBLOCK)
if self._logger is not None:
self._logger.warning(
"Received dictionary of size: " + str(sys.getsizeof(data))
Expand Down
30 changes: 30 additions & 0 deletions tests/test_standalone_interactive_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
interface_send,
interface_receive,
SocketInterface,
ExecutorlibSocketError,
)
from executorlib.standalone.serialize import cloudpickle_register
from executorlib.standalone.interactive.spawner import MpiExecSpawner
Expand Down Expand Up @@ -114,6 +115,35 @@ def test_interface_serial_with_debug(self):
)
interface.shutdown(wait=True)

def test_interface_serial_with_stopped_process(self):
    """
    Check that receive_dict() raises an ExecutorlibSocketError after the spawned process was terminated.
    """
    cloudpickle_register(ind=1)
    # Path of the serial backend script, resolved relative to this test file.
    backend_script = os.path.abspath(
        os.path.join(
            __file__, "..", "..", "executorlib", "backend", "interactive_serial.py"
        )
    )
    spawner = MpiExecSpawner(cwd=None, cores=1, openmpi_oversubscribe=False)
    interface = SocketInterface(spawner=spawner, log_obj_size=True)
    zmq_port = interface.bind_to_random_port()
    interface.bootup(
        command_lst=[sys.executable, backend_script, "--zmqport", str(zmq_port)]
    )
    interface.send_dict(input_dict={"fn": calc, "args": (), "kwargs": {"i": 2}})
    # Terminate the process right after submitting the task, so the interface is left waiting for a reply.
    interface._spawner._process.terminate()
    with self.assertRaises(ExecutorlibSocketError):
        interface.receive_dict()

Comment on lines +118 to +146
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Likely inverted liveness check in receive_dict; test may mask the bug.

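Assuming self._spawner.poll() follows subprocess.Popen.poll() semantics (None while the process is running, the return code once it has exited), receive_dict() currently raises when the spawner is alive and may loop forever when it exited with a non-zero code (common after terminate()). This can cause this test to hang instead of raising, or to raise for the wrong reason.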

  1. Proposed fix (outside this file) in executorlib/standalone/interactive/communication.py:
 def receive_dict(self) -> dict:
     """
     Receive a dictionary from a connected client process.
     """
-    response_lst: list[tuple[Any, int]] = []
-    while len(response_lst) == 0:
-        response_lst = self._poller.poll(self._time_out_ms)
-        if not self._spawner.poll():
-            raise ExecutorlibSocketError()
+    response_lst: list[tuple[Any, int]] = []
+    while True:
+        # If the spawner exited (any return code), fail fast.
+        if self._spawner.poll() is not None:
+            raise ExecutorlibSocketError("Spawner terminated while waiting for response")
+        response_lst = self._poller.poll(self._time_out_ms)
+        if response_lst:
+            break
     data = self._socket.recv(zmq.NOBLOCK)
  2. Verify the current logic and ensure the fix is applied:

🏁 Script executed:

#!/bin/bash
set -euo pipefail
# Show the receive_dict implementation and nearby lines
rg -n -C3 $'def\\s+receive_dict\\b' executorlib/standalone/interactive/communication.py
# Detect the inverted check; this should return no matches after the fix
rg -n $'if\\s+not\\s+self\\._spawner\\.poll\\(\\)\\s*:' executorlib/standalone/interactive/communication.py || true

Length of output: 506


Fix inverted liveness check in receive_dict
executorlib/standalone/interactive/communication.py:69 still uses

if not self._spawner.poll():
    raise ExecutorlibSocketError()

This is inverted – it hangs when the spawner exits. Change to:

if self._spawner.poll() is not None:
    raise ExecutorlibSocketError("Spawner terminated while waiting for response")

to fail fast on termination.

🤖 Prompt for AI Agents
In executorlib/standalone/interactive/communication.py around line 69, the
liveness check is inverted: currently it raises when poll() is falsy and thus
hangs when the spawner exits; change the condition to check for termination
using if self._spawner.poll() is not None and raise ExecutorlibSocketError with
a clear message like "Spawner terminated while waiting for response" so the code
fails fast when the spawner has exited.


class TestZMQ(unittest.TestCase):
def test_interface_receive(self):
Expand Down
Loading