Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions executorlib/task_scheduler/interactive/blockallocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,10 @@ def _execute_multiple_tasks(
future_queue.join()
break
elif "fn" in task_dict and "future" in task_dict:
f = task_dict.pop("future")
execute_task_dict(
task_dict=task_dict,
future_obj=f,
interface=interface,
cache_directory=cache_directory,
cache_key=cache_key,
Expand Down
8 changes: 7 additions & 1 deletion executorlib/task_scheduler/interactive/onetoone.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import queue
from concurrent.futures import Future
from threading import Thread
from typing import Optional

Expand Down Expand Up @@ -186,6 +187,7 @@ def _wrap_execute_task_in_separate_process(
dictionary containing the future objects and the number of cores they require
"""
resource_dict = task_dict.pop("resource_dict").copy()
f = task_dict.pop("future")
if "cores" not in resource_dict or (
resource_dict["cores"] == 1 and executor_kwargs["cores"] >= 1
):
Expand All @@ -197,14 +199,15 @@ def _wrap_execute_task_in_separate_process(
max_cores=max_cores,
max_workers=max_workers,
)
active_task_dict[task_dict["future"]] = slots_required
active_task_dict[f] = slots_required
task_kwargs = executor_kwargs.copy()
task_kwargs.update(resource_dict)
task_kwargs.update(
{
"task_dict": task_dict,
"spawner": spawner,
"hostname_localhost": hostname_localhost,
"future_obj": f,
}
)
process = Thread(
Expand All @@ -217,6 +220,7 @@ def _wrap_execute_task_in_separate_process(

def _execute_task_in_thread(
task_dict: dict,
future_obj: Future,
cores: int = 1,
spawner: type[BaseSpawner] = MpiExecSpawner,
hostname_localhost: Optional[bool] = None,
Expand All @@ -233,6 +237,7 @@ def _execute_task_in_thread(
Args:
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
future_obj (Future): A Future representing the given call.
cores (int): defines the total number of MPI ranks to use
spawner (BaseSpawner): Spawner to start process on selected compute resources
hostname_localhost (boolean): use localhost instead of the hostname to establish the zmq connection. In the
Expand All @@ -253,6 +258,7 @@ def _execute_task_in_thread(
"""
execute_task_dict(
task_dict=task_dict,
future_obj=future_obj,
interface=interface_bootup(
command_lst=get_interactive_execute_command(
cores=cores,
Expand Down
79 changes: 46 additions & 33 deletions executorlib/task_scheduler/interactive/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import queue
import time
from concurrent.futures import Future
from typing import Optional

from executorlib.standalone.interactive.communication import SocketInterface
Expand All @@ -10,6 +11,7 @@

def execute_task_dict(
task_dict: dict,
future_obj: Future,
interface: SocketInterface,
cache_directory: Optional[str] = None,
cache_key: Optional[str] = None,
Expand All @@ -21,52 +23,65 @@ def execute_task_dict(
Args:
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
future_obj (Future): A Future representing the given call.
interface (SocketInterface): socket interface for zmq communication
cache_directory (str, optional): The directory to store cache files. Defaults to "executorlib_cache".
cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be
overwritten by setting the cache_key.
error_log_file (str): Name of the error log file to use for storing exceptions raised by the Python functions
submitted to the Executor.
"""
if error_log_file is not None:
task_dict["error_log_file"] = error_log_file
if cache_directory is None:
_execute_task_without_cache(interface=interface, task_dict=task_dict)
else:
_execute_task_with_cache(
interface=interface,
task_dict=task_dict,
cache_directory=cache_directory,
cache_key=cache_key,
)
if not future_obj.done() and future_obj.set_running_or_notify_cancel():
if error_log_file is not None:
task_dict["error_log_file"] = error_log_file
if cache_directory is None:
_execute_task_without_cache(
interface=interface, task_dict=task_dict, future_obj=future_obj
)
else:
_execute_task_with_cache(
interface=interface,
task_dict=task_dict,
cache_directory=cache_directory,
cache_key=cache_key,
future_obj=future_obj,
)


def task_done(future_queue: queue.Queue):
    """
    Signal completion of the most recently fetched task on *future_queue*.

    The ValueError that ``queue.Queue.task_done`` raises when it is called
    more times than there were items is deliberately ignored, so this helper
    is safe to call even after the queue has already been fully balanced.

    Args:
        future_queue (queue.Queue): Queue of task dictionaries waiting for execution.
    """
    try:
        future_queue.task_done()
    except ValueError:
        # More task_done() calls than items — treat as a no-op.
        pass


def _execute_task_without_cache(interface: SocketInterface, task_dict: dict):
def _execute_task_without_cache(
interface: SocketInterface, task_dict: dict, future_obj: Future
):
"""
Execute the task in the task_dict by communicating it via the interface.

Args:
interface (SocketInterface): socket interface for zmq communication
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
future_obj (Future): A Future representing the given call.
"""
f = task_dict.pop("future")
if not f.done() and f.set_running_or_notify_cancel():
try:
f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
except Exception as thread_exception:
interface.shutdown(wait=True)
f.set_exception(exception=thread_exception)
try:
future_obj.set_result(interface.send_and_receive_dict(input_dict=task_dict))
except Exception as thread_exception:
interface.shutdown(wait=True)
future_obj.set_exception(exception=thread_exception)


def _execute_task_with_cache(
interface: SocketInterface,
task_dict: dict,
future_obj: Future,
cache_directory: str,
cache_key: Optional[str] = None,
):
Expand All @@ -77,6 +92,7 @@ def _execute_task_with_cache(
interface (SocketInterface): socket interface for zmq communication
task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys
{"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
future_obj (Future): A Future representing the given call.
cache_directory (str): The directory to store cache files.
cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be
overwritten by setting the cache_key.
Expand All @@ -92,19 +108,16 @@ def _execute_task_with_cache(
)
file_name = os.path.abspath(os.path.join(cache_directory, task_key + "_o.h5"))
if file_name not in get_cache_files(cache_directory=cache_directory):
f = task_dict.pop("future")
if f.set_running_or_notify_cancel():
try:
time_start = time.time()
result = interface.send_and_receive_dict(input_dict=task_dict)
data_dict["output"] = result
data_dict["runtime"] = time.time() - time_start
dump(file_name=file_name, data_dict=data_dict)
f.set_result(result)
except Exception as thread_exception:
interface.shutdown(wait=True)
f.set_exception(exception=thread_exception)
try:
time_start = time.time()
result = interface.send_and_receive_dict(input_dict=task_dict)
data_dict["output"] = result
data_dict["runtime"] = time.time() - time_start
dump(file_name=file_name, data_dict=data_dict)
future_obj.set_result(result)
except Exception as thread_exception:
interface.shutdown(wait=True)
future_obj.set_exception(exception=thread_exception)
else:
_, _, result = get_output(file_name=file_name)
future = task_dict["future"]
future.set_result(result)
future_obj.set_result(result)
Comment on lines 121 to +123
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Cache-hit path ignores cancellation; may raise InvalidStateError

When the result is served from cache, the Future is completed unconditionally. If the Future was cancelled/prematurely completed, this can raise InvalidStateError and violate expected cancellation semantics. Mirror the non-cache path by honoring set_running_or_notify_cancel().

Apply:

-    else:
-        _, _, result = get_output(file_name=file_name)
-        future_obj.set_result(result)
+    else:
+        _, _, result = get_output(file_name=file_name)
+        if future_obj.set_running_or_notify_cancel():
+            future_obj.set_result(result)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
else:
_, _, result = get_output(file_name=file_name)
future = task_dict["future"]
future.set_result(result)
future_obj.set_result(result)
else:
_, _, result = get_output(file_name=file_name)
if future_obj.set_running_or_notify_cancel():
future_obj.set_result(result)
🤖 Prompt for AI Agents
In executorlib/task_scheduler/interactive/shared.py around lines 122 to 124, the
cache-hit branch unconditionally calls future_obj.set_result(result) which can
raise InvalidStateError if the Future was cancelled; mirror the non-cache path
by first calling future_obj.set_running_or_notify_cancel() and if it returns
False stop/return without setting the result, otherwise call set_result(result)
(optionally catch InvalidStateError as a safeguard).

Loading