Merged
21 changes: 3 additions & 18 deletions executorlib/task_scheduler/interactive/shared.py
@@ -77,47 +77,38 @@ def execute_tasks(
if error_log_file is not None:
task_dict["error_log_file"] = error_log_file
if cache_directory is None:
_execute_task_without_cache(
interface=interface, task_dict=task_dict, future_queue=future_queue
)
_execute_task_without_cache(interface=interface, task_dict=task_dict)
else:
_execute_task_with_cache(
interface=interface,
task_dict=task_dict,
future_queue=future_queue,
cache_directory=cache_directory,
cache_key=cache_key,
)
_task_done(future_queue=future_queue)


def _execute_task_without_cache(
interface: SocketInterface, task_dict: dict, future_queue: queue.Queue
):
def _execute_task_without_cache(interface: SocketInterface, task_dict: dict):
"""
Execute the task in the task_dict by communicating it via the interface.

Args:
interface (SocketInterface): socket interface for zmq communication
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
future_queue (Queue): Queue for receiving new tasks.
"""
f = task_dict.pop("future")
if not f.done() and f.set_running_or_notify_cancel():
try:
f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
except Exception as thread_exception:
interface.shutdown(wait=True)
_task_done(future_queue=future_queue)
f.set_exception(exception=thread_exception)
else:
_task_done(future_queue=future_queue)

Comment on lines 100 to 107
⚠️ Potential issue

Unify exception propagation; re-raise after setting Future exception

Without re-raising here, the outer loop keeps running with a shut-down interface, causing follow-up failures. Make this consistent with the cache path.

     if not f.done() and f.set_running_or_notify_cancel():
         try:
             f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
         except Exception as thread_exception:
             interface.shutdown(wait=True)
             f.set_exception(exception=thread_exception)
+            raise thread_exception
📝 Committable suggestion


Suggested change
f = task_dict.pop("future")
if not f.done() and f.set_running_or_notify_cancel():
try:
f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
except Exception as thread_exception:
interface.shutdown(wait=True)
_task_done(future_queue=future_queue)
f.set_exception(exception=thread_exception)
else:
_task_done(future_queue=future_queue)
f = task_dict.pop("future")
if not f.done() and f.set_running_or_notify_cancel():
try:
f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
except Exception as thread_exception:
interface.shutdown(wait=True)
f.set_exception(exception=thread_exception)
raise thread_exception
🤖 Prompt for AI Agents
In executorlib/task_scheduler/interactive/shared.py around lines 100 to 107, the
except block shuts down the interface and sets the exception on the Future but does
not re-raise, which leaves the outer loop running against a shut-down interface;
update the except block to call interface.shutdown(wait=True), set the exception on
the Future, and then re-raise the caught thread_exception so the surrounding caller
observes the failure, matching the cache path behavior.
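
A minimal standalone sketch of the pattern this comment asks for (hypothetical names, not executorlib code): the helper records the failure on the Future and then re-raises, so the caller that drove the task also observes the error instead of continuing against a shut-down interface.

from concurrent.futures import Future


class _FailingInterface:
    """Stand-in for a SocketInterface: always fails, then records the shutdown."""

    def __init__(self):
        self.closed = False

    def send_and_receive_dict(self, input_dict):
        raise TypeError("missing required argument")

    def shutdown(self, wait=True):
        self.closed = True


def _execute_task(interface, task_dict):
    f = task_dict.pop("future")
    if not f.done() and f.set_running_or_notify_cancel():
        try:
            f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
        except Exception as thread_exception:
            interface.shutdown(wait=True)
            f.set_exception(exception=thread_exception)
            raise thread_exception  # propagate so the caller stops, matching the cache path


if __name__ == "__main__":
    interface = _FailingInterface()
    future = Future()
    try:
        _execute_task(interface, {"fn": sum, "args": (), "kwargs": {}, "future": future})
    except TypeError:
        pass  # the caller (or a test) sees the raise directly ...
    assert interface.closed
    assert isinstance(future.exception(), TypeError)  # ... and the Future still carries it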


def _execute_task_with_cache(
interface: SocketInterface,
task_dict: dict,
future_queue: queue.Queue,
cache_directory: str,
cache_key: Optional[str] = None,
):
@@ -128,7 +119,6 @@ def _execute_task_with_cache(
interface (SocketInterface): socket interface for zmq communication
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
future_queue (Queue): Queue for receiving new tasks.
cache_directory (str): The directory to store cache files.
cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be
overwritten by setting the cache_key.
@@ -155,16 +145,11 @@ def _execute_task_with_cache(
f.set_result(result)
except Exception as thread_exception:
interface.shutdown(wait=True)
_task_done(future_queue=future_queue)
f.set_exception(exception=thread_exception)
raise thread_exception
else:
_task_done(future_queue=future_queue)
else:
_, _, result = get_output(file_name=file_name)
future = task_dict["future"]
future.set_result(result)
_task_done(future_queue=future_queue)


def _task_done(future_queue: queue.Queue):
40 changes: 21 additions & 19 deletions tests/test_mpiexecspawner.py
@@ -443,13 +443,13 @@ def test_execute_task_failed_no_argument(self):
q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f})
q.put({"shutdown": True, "wait": True})
cloudpickle_register(ind=1)
execute_tasks(
future_queue=q,
cores=1,
openmpi_oversubscribe=False,
spawner=MpiExecSpawner,
)
with self.assertRaises(TypeError):
execute_tasks(
future_queue=q,
cores=1,
openmpi_oversubscribe=False,
spawner=MpiExecSpawner,
)
f.result()
q.join()

@@ -459,13 +459,13 @@ def test_execute_task_failed_wrong_argument(self):
q.put({"fn": calc_array, "args": (), "kwargs": {"j": 4}, "future": f})
q.put({"shutdown": True, "wait": True})
cloudpickle_register(ind=1)
execute_tasks(
future_queue=q,
cores=1,
openmpi_oversubscribe=False,
spawner=MpiExecSpawner,
)
with self.assertRaises(TypeError):
execute_tasks(
future_queue=q,
cores=1,
openmpi_oversubscribe=False,
spawner=MpiExecSpawner,
)
f.result()
q.join()

@@ -533,13 +533,15 @@ def test_execute_task_cache_failed_no_argument(self):
f = Future()
q = Queue()
q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f})
q.put({"shutdown": True, "wait": True})
cloudpickle_register(ind=1)
execute_tasks(
future_queue=q,
cores=1,
openmpi_oversubscribe=False,
spawner=MpiExecSpawner,
cache_directory="executorlib_cache",
)
with self.assertRaises(TypeError):
execute_tasks(
future_queue=q,
cores=1,
openmpi_oversubscribe=False,
spawner=MpiExecSpawner,
cache_directory="executorlib_cache",
)
f.result()
q.join()
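
The switch to assertRaises follows directly from the re-raise behavior: execute_tasks now propagates the worker-side TypeError instead of swallowing it. A condensed sketch of the updated expectation, reusing the q, f, and MpiExecSpawner fixtures from the test above; the final check on the Future is illustrative and not part of this PR.

with self.assertRaises(TypeError):
    execute_tasks(
        future_queue=q,
        cores=1,
        openmpi_oversubscribe=False,
        spawner=MpiExecSpawner,
    )
# Illustrative extra check: the same error should also be recorded on the submitted Future.
self.assertIsInstance(f.exception(), TypeError)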