diff --git a/executorlib/standalone/interactive/communication.py b/executorlib/standalone/interactive/communication.py index b5af3c56..b0c4bc39 100644 --- a/executorlib/standalone/interactive/communication.py +++ b/executorlib/standalone/interactive/communication.py @@ -1,7 +1,7 @@ import logging import sys from socket import gethostname -from typing import Any, Optional +from typing import Any, Callable, Optional import cloudpickle import zmq @@ -43,6 +43,7 @@ def __init__( self._logger = logging.getLogger("executorlib") self._spawner = spawner self._command_lst: list[str] = [] + self._stop_function: Optional[Callable] = None def send_dict(self, input_dict: dict): """ @@ -107,7 +108,8 @@ def bind_to_random_port(self) -> int: def bootup( self, command_lst: list[str], - ): + stop_function: Optional[Callable] = None, + ) -> bool: """ Boot up the client process to connect to the SocketInterface. @@ -115,17 +117,26 @@ def bootup( command_lst (list): list of strings to start the client process """ self._command_lst = command_lst - self._spawner.bootup( + self._stop_function = stop_function + if not self._spawner.bootup( command_lst=command_lst, - ) + stop_function=stop_function, + ): + self._reset_socket() + return False + return True def restart(self): """ Restart the client process to onnect to the SocketInterface. """ - self._spawner.bootup( + if not self._spawner.bootup( command_lst=self._command_lst, - ) + stop_function=self._stop_function, + ): + self._reset_socket() + return False + return True def shutdown(self, wait: bool = True): """ @@ -140,6 +151,10 @@ def shutdown(self, wait: bool = True): input_dict={"shutdown": True, "wait": wait} ) self._spawner.shutdown(wait=wait) + self._reset_socket() + return result + + def _reset_socket(self): if self._socket is not None: self._socket.close() if self._context is not None: @@ -147,7 +162,6 @@ def shutdown(self, wait: bool = True): self._process = None self._socket = None self._context = None - return result def __del__(self): """ @@ -163,7 +177,8 @@ def interface_bootup( hostname_localhost: Optional[bool] = None, log_obj_size: bool = False, worker_id: Optional[int] = None, -) -> SocketInterface: + stop_function: Optional[Callable] = None, +) -> Optional[SocketInterface]: """ Start interface for ZMQ communication @@ -202,10 +217,13 @@ def interface_bootup( "--zmqport", str(interface.bind_to_random_port()), ] - interface.bootup( + if interface.bootup( command_lst=command_lst, - ) - return interface + stop_function=stop_function, + ): + return interface + else: + return None def interface_connect(host: str, port: str) -> tuple[zmq.Context, zmq.Socket]: diff --git a/executorlib/standalone/interactive/spawner.py b/executorlib/standalone/interactive/spawner.py index 4a5cb390..ce90052b 100644 --- a/executorlib/standalone/interactive/spawner.py +++ b/executorlib/standalone/interactive/spawner.py @@ -1,7 +1,7 @@ import os import subprocess from abc import ABC, abstractmethod -from typing import Optional +from typing import Callable, Optional MPI_COMMAND = "mpiexec" @@ -29,7 +29,8 @@ def __init__( def bootup( self, command_lst: list[str], - ): + stop_function: Optional[Callable] = None, + ) -> bool: """ Method to start the interface. @@ -87,7 +88,8 @@ def __init__( def bootup( self, command_lst: list[str], - ): + stop_function: Optional[Callable] = None, + ) -> bool: """ Method to start the subprocess interface. @@ -101,6 +103,7 @@ def bootup( cwd=self._cwd, stdin=subprocess.DEVNULL, ) + return True def generate_command(self, command_lst: list[str]) -> list[str]: """ diff --git a/executorlib/task_scheduler/interactive/blockallocation.py b/executorlib/task_scheduler/interactive/blockallocation.py index 96cec2c1..2e1d1f02 100644 --- a/executorlib/task_scheduler/interactive/blockallocation.py +++ b/executorlib/task_scheduler/interactive/blockallocation.py @@ -12,6 +12,8 @@ from executorlib.task_scheduler.base import TaskSchedulerBase from executorlib.task_scheduler.interactive.shared import execute_tasks +_task_schedulder_dict: dict = {} + class BlockAllocationTaskScheduler(TaskSchedulerBase): """ @@ -61,11 +63,18 @@ def __init__( executor_kwargs["queue_join_on_shutdown"] = False self._process_kwargs = executor_kwargs self._max_workers = max_workers + self_id = id(self) + self._self_id = self_id + _task_schedulder_dict[self._self_id] = False self._set_process( process=[ Thread( target=execute_tasks, - kwargs=executor_kwargs | {"worker_id": worker_id}, + kwargs=executor_kwargs + | { + "worker_id": worker_id, + "stop_function": lambda: _task_schedulder_dict[self_id], + }, ) for worker_id in range(self._max_workers) ], @@ -155,7 +164,9 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False): if self._future_queue is not None: if cancel_futures: cancel_items_in_queue(que=self._future_queue) + self._shutdown_flag = True if isinstance(self._process, list): + _task_schedulder_dict[self._self_id] = True for _ in range(len(self._process)): self._future_queue.put({"shutdown": True, "wait": wait}) if wait: diff --git a/executorlib/task_scheduler/interactive/fluxspawner.py b/executorlib/task_scheduler/interactive/fluxspawner.py index 5a35dd5c..378bbe92 100644 --- a/executorlib/task_scheduler/interactive/fluxspawner.py +++ b/executorlib/task_scheduler/interactive/fluxspawner.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Callable, Optional import flux import flux.job @@ -75,7 +75,8 @@ def __init__( def bootup( self, command_lst: list[str], - ): + stop_function: Optional[Callable] = None, + ) -> bool: """ Boot up the client process to connect to the SocketInterface. @@ -126,6 +127,7 @@ def bootup( ) else: self._future = self._flux_executor.submit(jobspec=jobspec) + return True def shutdown(self, wait: bool = True): """ diff --git a/executorlib/task_scheduler/interactive/pysqaspawner.py b/executorlib/task_scheduler/interactive/pysqaspawner.py index 73a8cb87..31f57c8b 100644 --- a/executorlib/task_scheduler/interactive/pysqaspawner.py +++ b/executorlib/task_scheduler/interactive/pysqaspawner.py @@ -58,6 +58,7 @@ def __init__( def bootup( self, command_lst: list[str], + stop_function: Optional[Callable] = None, ): """ Method to start the subprocess interface. @@ -76,7 +77,10 @@ def bootup( ) while True: if self._check_process_helper(command_lst=command_lst): - break + return True + elif stop_function is not None and stop_function(): + self.shutdown(wait=True) + return False else: sleep(1) # Wait for the process to start diff --git a/executorlib/task_scheduler/interactive/shared.py b/executorlib/task_scheduler/interactive/shared.py index 883c3dac..fea9f86a 100644 --- a/executorlib/task_scheduler/interactive/shared.py +++ b/executorlib/task_scheduler/interactive/shared.py @@ -28,6 +28,7 @@ def execute_tasks( log_obj_size: bool = False, error_log_file: Optional[str] = None, worker_id: Optional[int] = None, + stop_function: Optional[Callable] = None, **kwargs, ) -> None: """ @@ -63,15 +64,17 @@ def execute_tasks( hostname_localhost=hostname_localhost, log_obj_size=log_obj_size, worker_id=worker_id, + stop_function=stop_function, ) - if init_function is not None: + if init_function is not None and interface is not None: interface.send_dict( input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}} ) while True: task_dict = future_queue.get() if "shutdown" in task_dict and task_dict["shutdown"]: - interface.shutdown(wait=task_dict["wait"]) + if interface is not None: + interface.shutdown(wait=task_dict["wait"]) _task_done(future_queue=future_queue) if queue_join_on_shutdown: future_queue.join() @@ -79,23 +82,31 @@ def execute_tasks( elif "fn" in task_dict and "future" in task_dict: if error_log_file is not None: task_dict["error_log_file"] = error_log_file - if cache_directory is None: - _execute_task_without_cache( - interface=interface, task_dict=task_dict, future_queue=future_queue + if cache_directory is None and interface is not None: + result_flag = _execute_task_without_cache( + interface=interface, + task_dict=task_dict, + future_queue=future_queue, ) - else: - _execute_task_with_cache( + elif cache_directory is not None and interface is not None: + result_flag = _execute_task_with_cache( interface=interface, task_dict=task_dict, future_queue=future_queue, cache_directory=cache_directory, cache_key=cache_key, ) + else: + raise ValueError() + if not result_flag: + if queue_join_on_shutdown: + future_queue.join() + break def _execute_task_without_cache( interface: SocketInterface, task_dict: dict, future_queue: queue.Queue -): +) -> bool: """ Execute the task in the task_dict by communicating it via the interface. @@ -114,13 +125,14 @@ def _execute_task_without_cache( _reset_task_dict( future_obj=f, future_queue=future_queue, task_dict=task_dict ) - interface.restart() + return interface.restart() else: interface.shutdown(wait=True) _task_done(future_queue=future_queue) f.set_exception(exception=thread_exception) else: _task_done(future_queue=future_queue) + return True def _execute_task_with_cache( @@ -129,7 +141,7 @@ def _execute_task_with_cache( future_queue: queue.Queue, cache_directory: str, cache_key: Optional[str] = None, -): +) -> bool: """ Execute the task in the task_dict by communicating it via the interface using the cache in the cache directory. @@ -167,7 +179,7 @@ def _execute_task_with_cache( _reset_task_dict( future_obj=f, future_queue=future_queue, task_dict=task_dict ) - interface.restart() + return interface.restart() else: interface.shutdown(wait=True) _task_done(future_queue=future_queue) @@ -180,6 +192,7 @@ def _execute_task_with_cache( future = task_dict["future"] future.set_result(result) _task_done(future_queue=future_queue) + return True def _task_done(future_queue: queue.Queue):