Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions executorlib/standalone/interactive/communication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import sys
from socket import gethostname
from typing import Any, Optional
from typing import Any, Callable, Optional

import cloudpickle
import zmq
Expand Down Expand Up @@ -43,6 +43,7 @@ def __init__(
self._logger = logging.getLogger("executorlib")
self._spawner = spawner
self._command_lst: list[str] = []
self._stop_function: Optional[Callable] = None

def send_dict(self, input_dict: dict):
"""
Expand Down Expand Up @@ -107,25 +108,35 @@ def bind_to_random_port(self) -> int:
def bootup(
self,
command_lst: list[str],
):
stop_function: Optional[Callable] = None,
) -> bool:
"""
Boot up the client process to connect to the SocketInterface.

Args:
command_lst (list): list of strings to start the client process
"""
self._command_lst = command_lst
self._spawner.bootup(
self._stop_function = stop_function
if not self._spawner.bootup(
command_lst=command_lst,
)
stop_function=stop_function,
):
self._reset_socket()
return False
return True

def restart(self):
"""
Restart the client process to onnect to the SocketInterface.
"""
self._spawner.bootup(
if not self._spawner.bootup(
command_lst=self._command_lst,
)
stop_function=self._stop_function,
):
self._reset_socket()
return False
return True

def shutdown(self, wait: bool = True):
"""
Expand All @@ -140,14 +151,17 @@ def shutdown(self, wait: bool = True):
input_dict={"shutdown": True, "wait": wait}
)
self._spawner.shutdown(wait=wait)
self._reset_socket()
return result

def _reset_socket(self):
if self._socket is not None:
self._socket.close()
if self._context is not None:
self._context.term()
self._process = None
self._socket = None
self._context = None
return result

def __del__(self):
"""
Expand All @@ -163,7 +177,8 @@ def interface_bootup(
hostname_localhost: Optional[bool] = None,
log_obj_size: bool = False,
worker_id: Optional[int] = None,
) -> SocketInterface:
stop_function: Optional[Callable] = None,
) -> Optional[SocketInterface]:
"""
Start interface for ZMQ communication

Expand Down Expand Up @@ -202,10 +217,13 @@ def interface_bootup(
"--zmqport",
str(interface.bind_to_random_port()),
]
interface.bootup(
if interface.bootup(
command_lst=command_lst,
)
return interface
stop_function=stop_function,
):
return interface
else:
return None


def interface_connect(host: str, port: str) -> tuple[zmq.Context, zmq.Socket]:
Expand Down
9 changes: 6 additions & 3 deletions executorlib/standalone/interactive/spawner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import subprocess
from abc import ABC, abstractmethod
from typing import Optional
from typing import Callable, Optional

MPI_COMMAND = "mpiexec"

Expand Down Expand Up @@ -29,7 +29,8 @@ def __init__(
def bootup(
self,
command_lst: list[str],
):
stop_function: Optional[Callable] = None,
) -> bool:
"""
Method to start the interface.

Expand Down Expand Up @@ -87,7 +88,8 @@ def __init__(
def bootup(
self,
command_lst: list[str],
):
stop_function: Optional[Callable] = None,
) -> bool:
"""
Method to start the subprocess interface.

Expand All @@ -101,6 +103,7 @@ def bootup(
cwd=self._cwd,
stdin=subprocess.DEVNULL,
)
return True

def generate_command(self, command_lst: list[str]) -> list[str]:
"""
Expand Down
13 changes: 12 additions & 1 deletion executorlib/task_scheduler/interactive/blockallocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from executorlib.task_scheduler.base import TaskSchedulerBase
from executorlib.task_scheduler.interactive.shared import execute_tasks

_task_schedulder_dict: dict = {}


class BlockAllocationTaskScheduler(TaskSchedulerBase):
"""
Expand Down Expand Up @@ -61,11 +63,18 @@ def __init__(
executor_kwargs["queue_join_on_shutdown"] = False
self._process_kwargs = executor_kwargs
self._max_workers = max_workers
self_id = id(self)
self._self_id = self_id
_task_schedulder_dict[self._self_id] = False
self._set_process(
process=[
Thread(
target=execute_tasks,
kwargs=executor_kwargs | {"worker_id": worker_id},
kwargs=executor_kwargs
| {
"worker_id": worker_id,
"stop_function": lambda: _task_schedulder_dict[self_id],
},
)
for worker_id in range(self._max_workers)
],
Expand Down Expand Up @@ -155,7 +164,9 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
if self._future_queue is not None:
if cancel_futures:
cancel_items_in_queue(que=self._future_queue)
self._shutdown_flag = True
if isinstance(self._process, list):
_task_schedulder_dict[self._self_id] = True
for _ in range(len(self._process)):
self._future_queue.put({"shutdown": True, "wait": wait})
if wait:
Expand Down
6 changes: 4 additions & 2 deletions executorlib/task_scheduler/interactive/fluxspawner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Callable, Optional

import flux
import flux.job
Expand Down Expand Up @@ -75,7 +75,8 @@ def __init__(
def bootup(
self,
command_lst: list[str],
):
stop_function: Optional[Callable] = None,
) -> bool:
"""
Boot up the client process to connect to the SocketInterface.

Expand Down Expand Up @@ -126,6 +127,7 @@ def bootup(
)
else:
self._future = self._flux_executor.submit(jobspec=jobspec)
return True

def shutdown(self, wait: bool = True):
"""
Expand Down
6 changes: 5 additions & 1 deletion executorlib/task_scheduler/interactive/pysqaspawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
def bootup(
self,
command_lst: list[str],
stop_function: Optional[Callable] = None,
):
"""
Method to start the subprocess interface.
Expand All @@ -76,7 +77,10 @@ def bootup(
)
while True:
if self._check_process_helper(command_lst=command_lst):
break
return True
elif stop_function is not None and stop_function():
self.shutdown(wait=True)
return False
else:
sleep(1) # Wait for the process to start

Expand Down
35 changes: 24 additions & 11 deletions executorlib/task_scheduler/interactive/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def execute_tasks(
log_obj_size: bool = False,
error_log_file: Optional[str] = None,
worker_id: Optional[int] = None,
stop_function: Optional[Callable] = None,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -63,39 +64,49 @@ def execute_tasks(
hostname_localhost=hostname_localhost,
log_obj_size=log_obj_size,
worker_id=worker_id,
stop_function=stop_function,
)
if init_function is not None:
if init_function is not None and interface is not None:
interface.send_dict(
input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}}
)
while True:
task_dict = future_queue.get()
if "shutdown" in task_dict and task_dict["shutdown"]:
interface.shutdown(wait=task_dict["wait"])
if interface is not None:
interface.shutdown(wait=task_dict["wait"])
_task_done(future_queue=future_queue)
if queue_join_on_shutdown:
future_queue.join()
break
elif "fn" in task_dict and "future" in task_dict:
if error_log_file is not None:
task_dict["error_log_file"] = error_log_file
if cache_directory is None:
_execute_task_without_cache(
interface=interface, task_dict=task_dict, future_queue=future_queue
if cache_directory is None and interface is not None:
result_flag = _execute_task_without_cache(
interface=interface,
task_dict=task_dict,
future_queue=future_queue,
)
else:
_execute_task_with_cache(
elif cache_directory is not None and interface is not None:
result_flag = _execute_task_with_cache(
interface=interface,
task_dict=task_dict,
future_queue=future_queue,
cache_directory=cache_directory,
cache_key=cache_key,
)
else:
raise ValueError()
if not result_flag:
if queue_join_on_shutdown:
future_queue.join()
break


def _execute_task_without_cache(
interface: SocketInterface, task_dict: dict, future_queue: queue.Queue
):
) -> bool:
"""
Execute the task in the task_dict by communicating it via the interface.

Expand All @@ -114,13 +125,14 @@ def _execute_task_without_cache(
_reset_task_dict(
future_obj=f, future_queue=future_queue, task_dict=task_dict
)
interface.restart()
return interface.restart()
else:
interface.shutdown(wait=True)
_task_done(future_queue=future_queue)
f.set_exception(exception=thread_exception)
else:
_task_done(future_queue=future_queue)
return True


def _execute_task_with_cache(
Expand All @@ -129,7 +141,7 @@ def _execute_task_with_cache(
future_queue: queue.Queue,
cache_directory: str,
cache_key: Optional[str] = None,
):
) -> bool:
"""
Execute the task in the task_dict by communicating it via the interface using the cache in the cache directory.

Expand Down Expand Up @@ -167,7 +179,7 @@ def _execute_task_with_cache(
_reset_task_dict(
future_obj=f, future_queue=future_queue, task_dict=task_dict
)
interface.restart()
return interface.restart()
else:
interface.shutdown(wait=True)
_task_done(future_queue=future_queue)
Expand All @@ -180,6 +192,7 @@ def _execute_task_with_cache(
future = task_dict["future"]
future.set_result(result)
_task_done(future_queue=future_queue)
return True


def _task_done(future_queue: queue.Queue):
Expand Down
Loading