Skip to content

Commit

Permalink
Fixed a bug where starting multiple apply_async tasks with a task t…
Browse files Browse the repository at this point in the history
…imeout didn't interrupt all tasks when the timeout was reached. Fixes #98
  • Loading branch information
sybrenjansen committed Oct 27, 2023
1 parent 2e7655d commit a1d07d7
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 48 deletions.
6 changes: 5 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ Unreleased
----------

* Excluded the ``tests`` folder from MPIRE distributions (`#89`_)
* Added a workaround for semaphore leakage on macOS and fixed a bug when working in a fork context while the system default is spawn (`#92`_)
* Added a workaround for semaphore leakage on macOS and fixed a bug when working in a fork context while the system
default is spawn (`#92`_)
* Fix progressbar percentage on dashboard (`#101`_)
* Fixed a bug where starting multiple ``apply_async`` tasks with a task timeout didn't interrupt all tasks when the
timeout was reached (`#98`_)

.. _#89: https://github.com/sybrenjansen/mpire/issues/89
.. _#92: https://github.com/sybrenjansen/mpire/issues/92
.. _#98: https://github.com/sybrenjansen/mpire/issues/98
.. _#101: https://github.com/sybrenjansen/mpire/pull/101


Expand Down
7 changes: 6 additions & 1 deletion docs/usage/apply.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,9 @@ Timeouts

The ``apply`` family of functions also has ``task_timeout``, ``worker_init_timeout`` and ``worker_exit_timeout``
arguments. These are timeouts for the task, the ``worker_init`` function and the ``worker_exit`` function, respectively.
They work completely the same as for the ``map`` functions. See :ref:`timeouts` for more information.
They work similarly to those for the ``map`` functions.

When a single task times out, only that task is cancelled. The other tasks will continue to run. When a worker init or
exit times out, the entire pool is stopped.

See :ref:`timeouts` for more information.
7 changes: 6 additions & 1 deletion mpire/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@


class StopWorker(Exception):
""" Exception used to kill workers from the main process """
""" Exception used to kill a worker """
pass


class InterruptWorker(Exception):
    """ Exception used to interrupt the task a worker is currently running, without stopping the worker itself """


Expand Down
18 changes: 14 additions & 4 deletions mpire/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,18 +346,28 @@ def _get_task_config(_job_id) -> Tuple[str, Optional[float], Callable[[int, floa
else:
timeout_func_name, timeout_var, has_timed_out_func = _get_task_config(job_id)

# If timeout has expired, then send kill signal and set job to failed
# If timeout has expired set job to failed
if timeout_var is not None and has_timed_out_func(worker_id, timeout_var):
self._worker_comms.signal_exception_thrown(job_id)

# If we're dealing with a map/init/exit task, send a kill signal to all workers. Otherwise, we're
# dealing with an apply task and we only interrupt that one
kill_pool = (
job_id in {INIT_FUNC, EXIT_FUNC} or
isinstance(self._cache[job_id], UnorderedAsyncResultIterator)
)
if kill_pool:
self._worker_comms.signal_exception_thrown(job_id)
self._send_kill_signal_to_worker(worker_id)
err = TimeoutError(f"Worker-{worker_id} {timeout_func_name} timed out (timeout={timeout_var})")

# When a worker_init times out, the pool shuts down and we set all tasks that haven't completed yet
# to failed
err = TimeoutError(f"Worker-{worker_id} {timeout_func_name} timed out (timeout={timeout_var})")
job_ids = (set(self._cache.keys()) - {MAIN_PROCESS, EXIT_FUNC}) if job_id == INIT_FUNC else {job_id}
for job_id in job_ids:
self._cache[job_id]._set(False, err)
return

if kill_pool:
return

# Check this every once in a while
time.sleep(0.1)
Expand Down
84 changes: 43 additions & 41 deletions mpire/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
WorkerComms)
from mpire.context import FORK_AVAILABLE, MP_CONTEXTS, RUNNING_WINDOWS
from mpire.dashboard.connection_utils import DashboardConnectionDetails, set_dashboard_connection
from mpire.exception import CannotPickleExceptionError, StopWorker
from mpire.exception import CannotPickleExceptionError, InterruptWorker, StopWorker
from mpire.insights import WorkerInsights
from mpire.params import WorkerMapParams, WorkerPoolParams
from mpire.tqdm_utils import TqdmConnectionDetails, TqdmManager
Expand Down Expand Up @@ -76,6 +76,7 @@ def __init__(self, worker_id: int, pool_params: WorkerPoolParams, map_params: Wo
self.progress_bar_n_tasks_completed = 0
self.max_task_duration_last_updated = datetime.now()
self.max_task_duration_list = self.worker_insights.get_max_task_duration_list(self.worker_id)
self.is_apply_func = False
self.last_job_id = None
self.init_func_completed = False

Expand Down Expand Up @@ -150,12 +151,15 @@ def run(self) -> None:
if self.map_params.worker_init and self._run_init_func():
return

# We only set the is_apply_func flag when we are not running the init/exit functions
self.is_apply_func = is_apply_func

results = []
for args in next_chunked_args:

# Try to run this function and save results
results_part, success, send_results, should_shut_down = self._run_func(
apply_func if is_apply_func else func, job_id, args, is_apply_func
apply_func if is_apply_func else func, job_id, args
)
if should_shut_down:
return
Expand All @@ -173,6 +177,7 @@ def run(self) -> None:

# In case an exception occurred and we need to return, we want to call task_done no matter what
finally:
self.is_apply_func = False
self.worker_comms.task_done(self.worker_id)

# Update task insights
Expand Down Expand Up @@ -231,22 +236,27 @@ def _on_kill_exit_gracefully(self, *_) -> None:

def _on_exception_exit_gracefully(self, *_) -> None:
"""
This function is called when the main process sends a kill signal to this process. This can only mean another
child process encountered an error which means we should exit.
This function is called when the main process sends a kill signal to this process. This can mean two things:
- Another child process encountered an error in either the init/exit or map function which means we should exit
- The current task timed out and we should interrupt it
This signal is only sent when either the user defined function, worker init or worker exit function is running.
In such cases, a StopWorker or InterruptWorker exception is raised, which is caught by the ``_run_safely()``
function, so we can either quit gracefully or continue with the next task.
"""
self.worker_comms.signal_kill_signal_received()
raise StopWorker
exception_job_id = self.worker_comms.get_worker_working_on_job(self.worker_comms.exception_thrown_by())
if exception_job_id in {INIT_FUNC, EXIT_FUNC} or not self.is_apply_func:
self.worker_comms.signal_kill_signal_received()
raise StopWorker
else:
raise InterruptWorker

def _on_exception_exit_gracefully_windows(self) -> None:
"""
Windows doesn't fully support signals as Unix-based systems do. Therefore, we have to work around it. This
function is started in a thread. We wait for a kill signal (Event object) and interrupt the main thread if we
got it (derived from https://stackoverflow.com/a/40281422). This will raise a KeyboardInterrupt, which is then
caught by the signal handler, which in turn checks if we need to raise a StopWorker.
caught by the signal handler, which in turn checks if we need to raise a StopWorker or InterruptWorker.
Note: functions that release the GIL won't be interrupted by this procedure (e.g., time.sleep). If graceful
shutdown takes too long the process will be terminated by the main process.
Expand Down Expand Up @@ -359,24 +369,22 @@ def _init_func():
if self.map_params.worker_init_timeout is not None:
try:
self.worker_comms.signal_worker_init_started(self.worker_id)
_, _, _, should_shut_down = self._run_safely(_init_func, INIT_FUNC, is_apply_func=False)
_, _, _, should_shut_down = self._run_safely(_init_func, INIT_FUNC)
finally:
self.worker_comms.signal_worker_init_completed(self.worker_id)
else:
_, _, _, should_shut_down = self._run_safely(_init_func, INIT_FUNC, is_apply_func=False)
_, _, _, should_shut_down = self._run_safely(_init_func, INIT_FUNC)

self.init_func_completed = True
return should_shut_down

def _run_func(self, func: Callable, job_id: Optional[int], args: Optional[List],
is_apply_func: bool) -> Tuple[Any, bool, bool, bool]:
def _run_func(self, func: Callable, job_id: Optional[int], args: Optional[List]) -> Tuple[Any, bool, bool, bool]:
"""
Runs the main function when provided.
:param func: Function to call
:param job_id: Job ID
:param args: Args to pass to the function
:param is_apply_func: Whether this is an apply function
:return: Tuple containing results from the function and boolean values indicating whether the function was run
successfully, whether the results should be sent on the queue, and whether the worker needs to shut
down
Expand All @@ -386,16 +394,16 @@ def _run_func(self, func: Callable, job_id: Optional[int], args: Optional[List],
self.last_job_id = job_id

def _func():
with TimeIt(self.worker_insights.worker_working_time, self.worker_id,
self.max_task_duration_list, lambda: self._format_args(args, is_apply_func, separator=' | ')):
_results = func(*args) if is_apply_func else func(args)
with TimeIt(self.worker_insights.worker_working_time, self.worker_id, self.max_task_duration_list,
lambda: self._format_args(args, separator=' | ')):
_results = func(*args) if self.is_apply_func else func(args)
self.worker_insights.update_n_completed_tasks(self.worker_id)
return _results

# Update timeout info
try:
self.worker_comms.signal_worker_task_started(self.worker_id)
results, success, send_results, should_shut_down = self._run_safely(_func, job_id, args, is_apply_func)
results, success, send_results, should_shut_down = self._run_safely(_func, job_id, args)
finally:
self.worker_comms.signal_worker_task_completed(self.worker_id)

Expand All @@ -418,30 +426,28 @@ def _exit_func():
if self.map_params.worker_exit_timeout is not None:
try:
self.worker_comms.signal_worker_exit_started(self.worker_id)
results, success, send_results, should_shut_down = self._run_safely(_exit_func, EXIT_FUNC,
is_apply_func=False)
results, success, send_results, should_shut_down = self._run_safely(_exit_func, EXIT_FUNC)
finally:
self.worker_comms.signal_worker_exit_completed(self.worker_id)
else:
results, success, send_results, should_shut_down = self._run_safely(_exit_func, EXIT_FUNC,
is_apply_func=False)
results, success, send_results, should_shut_down = self._run_safely(_exit_func, EXIT_FUNC)

if should_shut_down:
return True
elif send_results:
self.worker_comms.add_results(self.worker_id, [(EXIT_FUNC, True, results)])
return False

def _run_safely(self, func: Callable, job_id: Optional[int], exception_args: Optional[Any] = None,
is_apply_func: bool = False) -> Tuple[Any, bool, bool, bool]:
def _run_safely(
self, func: Callable, job_id: Optional[int], exception_args: Optional[Any] = None
) -> Tuple[Any, bool, bool, bool]:
"""
A rather complex locking and exception mechanism is used here so we can make sure we only raise an exception
when we should. See `_exit_gracefully` for more information.
:param func: Function to run
:param job_id: Job ID
:param exception_args: Arguments to pass to `_format_args` when an exception occurred
:param is_apply_func: Whether this is an apply function
:return: Tuple containing results from the function and boolean values indicating whether the function was run
successfully, whether the results should be sent on the queue, and whether the worker needs to shut
down
Expand All @@ -458,29 +464,28 @@ def _run_safely(self, func: Callable, job_id: Optional[int], exception_args: Opt
results = func()
self.worker_comms.set_worker_running_task(self.worker_id, False)

except StopWorker:
# The main process tells us to stop working. When we were running an apply function, we're just going to
# continue. Otherwise, we're shutting down
if is_apply_func:
return None, False, False, False
raise
except InterruptWorker:
# The main process tells us to interrupt the current task. This means a timeout has expired and we
# need to stop this task and continue with the next one
return None, False, False, False

except (Exception, SystemExit) as err:
# An exception occurred inside the provided function. Let the signal handler know it shouldn't raise any
# StopWorker exceptions from the parent process anymore, we got this.
# StopWorker or InterruptWorker exceptions from the parent process anymore, we got this.
self.worker_comms.set_worker_running_task(self.worker_id, False)

if is_apply_func:
if self.is_apply_func:
# Obtain exception and send it back as normal results. The first False indicates the job has failed
exception = self._get_exception(exception_args, True, err)
exception = self._get_exception(exception_args, err)
return exception, False, True, False
else:
# Pass exception to parent process and stop
self._raise(job_id, exception_args, err)
raise StopWorker

except StopWorker:
# Stop working
# Either the main process tells us to stop working and kill all workers, or an exception occurred in this
# worker and we need to stop.
return None, False, False, True

# Carry on
Expand All @@ -502,24 +507,22 @@ def _raise(self, job_id: Optional[int], args: Optional[Any], err: Union[Exceptio
self.worker_comms.signal_exception_thrown(job_id)

# Get exception and traceback string
exception = self._get_exception(args, False, err)
exception = self._get_exception(args, err)

# Add exception
self.worker_comms.add_results(self.worker_id, [(job_id, False, exception)])

def _get_exception(self, args: Optional[Any], is_apply_func: bool,
err: Union[Exception, SystemExit]) -> Tuple[type, Tuple, Dict, str]:
def _get_exception(self, args: Optional[Any], err: Union[Exception, SystemExit]) -> Tuple[type, Tuple, Dict, str]:
"""
Try to pickle the exception and create a traceback string
:param args: Function arguments where exception was raised
:param is_apply_func: Whether this is an apply function
:param err: Exception that was raised
:return: Tuple containing the exception type, args, state, and a traceback string
"""
# Create traceback string
traceback_str = f"\n\nException occurred in Worker-{self.worker_id} with the following arguments:\n" \
f"{self._format_args(args, is_apply_func)}\n{traceback.format_exc()}"
f"{self._format_args(args)}\n{traceback.format_exc()}"

# Sometimes an exception cannot be pickled (i.e., we get the _pickle.PickleError: Can't pickle
# <class ...>: it's not the same object as ...). We check that here by trying the pickle.dumps manually.
Expand All @@ -534,17 +537,16 @@ def _get_exception(self, args: Optional[Any], is_apply_func: bool,

return type(err), err.args, err.__dict__, traceback_str

def _format_args(self, args: Optional[Any], is_apply_func: bool = False, separator: str = '\n') -> str:
def _format_args(self, args: Optional[Any], separator: str = '\n') -> str:
"""
Format the function arguments to a string form.
:param args: Function arguments
:param is_apply_func: Whether this is an apply function
:param separator: String to use as separator between arguments
:return: String containing the task arguments
"""
# Determine function arguments
if is_apply_func:
if self.is_apply_func:
func_args, func_kwargs = args
else:
func_args = args[1] if args and self.worker_comms.keep_order() else args
Expand Down
21 changes: 21 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,19 @@ def test_worker_exit_timeout(self):
pool.apply(self._f1, self.test_data[0], worker_exit=self._exit2, worker_exit_timeout=0.01)
pool.stop_and_join()

def test_apply_async_multiple_task_timeout(self):
    """
    Submit six apply_async() tasks with a short task timeout. Even-indexed tasks are expected to return their
    input, odd-indexed tasks are expected to raise a TimeoutError when their result is fetched.
    """
    print()
    for start_method in tqdm(TEST_START_METHODS):
        with WorkerPool(n_jobs=3, start_method=start_method) as pool:
            async_results = [pool.apply_async(self._f3, (task_idx,), task_timeout=0.1) for task_idx in range(6)]
            for task_idx, async_result in enumerate(async_results):
                if task_idx % 2 != 0:
                    # Odd tasks sleep in _f3 and should have been interrupted by the timeout
                    with self.assertRaises(TimeoutError):
                        async_result.get()
                else:
                    self.assertEqual(async_result.get(), task_idx)

@staticmethod
def _init1():
pass
Expand All @@ -1581,6 +1594,14 @@ def _exit1():
def _exit2():
time.sleep(1)

@staticmethod
def _f3(x):
if x % 2 == 0:
return x
else:
time.sleep(1)
return x


class OrderTasksTest(unittest.TestCase):
"""
Expand Down

0 comments on commit a1d07d7

Please sign in to comment.