Defunct processes (#39)
* MPIRE now handles defunct child processes properly, instead of deadlocking (`#34`_)
* Added benchmark highlights to README (`#38`_)

Co-authored-by: sybrenjansen <sybren.jansen@gmail.com>
sybrenjansen committed Apr 25, 2022
1 parent f238b70 commit a6659ac
Showing 7 changed files with 146 additions and 17 deletions.
8 changes: 8 additions & 0 deletions README.rst
@@ -197,6 +197,14 @@ MPIRE has been benchmarked on three different benchmarks: numerical computation,
initialization. More details on these benchmarks can be found in this `blog post`_. All code for these benchmarks can
be found in this project_.

In short, the main reasons why MPIRE is faster are (see the sketch below):

- When ``fork`` is available we can make use of copy-on-write shared objects, which removes the need to copy objects
  that are shared across child processes
- Workers can hold state across multiple tasks, so you can choose to load a big file or set up resources only once
  per worker
- Automatic task chunking
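
As a rough illustration of the second point, here is a minimal sketch using MPIRE's documented worker-state API
(``use_worker_state``, ``worker_init`` and the ``worker_state`` argument; the file name is made up)::

    from mpire import WorkerPool

    def init(worker_state):
        # Runs once per worker: load the big file a single time
        with open('big_file.txt') as f:
            worker_state['contents'] = f.read()

    def count_word(worker_state, word):
        # Every task on this worker reuses the contents loaded in init()
        return worker_state['contents'].count(word)

    if __name__ == '__main__':
        with WorkerPool(n_jobs=4, use_worker_state=True) as pool:
            counts = pool.map(count_word, ['foo', 'bar', 'baz'], worker_init=init)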

The following graph shows the average normalized results of all three benchmarks. Results for individual benchmarks
can be found in the `blog post`_. The benchmarks were run on a Linux machine with 20 cores, with hyperthreading disabled
and 200GB of RAM. For each task, experiments were run with different numbers of processes/workers and results were
10 changes: 10 additions & 0 deletions docs/changelog.rst
@@ -1,6 +1,16 @@
Changelog
=========

Master
------

* MPIRE now handles defunct child processes properly, instead of deadlocking (`#34`_)
* Added benchmark highlights to README (`#38`_)

.. _#34: https://github.com/Slimmer-AI/mpire/issues/34
.. _#38: https://github.com/Slimmer-AI/mpire/issues/38


2.3.4
-----

13 changes: 12 additions & 1 deletion mpire/comms.py
@@ -58,8 +58,9 @@ def __init__(self, ctx: mp.context.BaseContext, n_jobs: int) -> None:
# Array where the child processes can request a restart
self._worker_done_array = None

# List of Event objects to indicate whether workers are alive
# List of Event objects to indicate whether workers are alive, together with accompanying locks
self._workers_dead = None
self._workers_dead_locks = None

# Queue where the child processes can pass on an encountered exception
self._exception_queue = None
@@ -112,6 +113,7 @@ def init_comms(self, has_worker_exit: bool, has_progress_bar: bool) -> None:
self._worker_done_array = self.ctx.Array('b', self.n_jobs, lock=False)
self._workers_dead = [self.ctx.Event() for _ in range(self.n_jobs)]
[worker_dead.set() for worker_dead in self._workers_dead]
self._workers_dead_locks = [self.ctx.Lock() for _ in range(self.n_jobs)]

# Exception related
self._exception_queue = self.ctx.JoinableQueue()
@@ -499,6 +501,15 @@ def reset_worker_restart(self, worker_id) -> None:
"""
self._worker_done_array[worker_id] = False

def get_worker_dead_lock(self, worker_id: int) -> mp.Lock:
"""
Returns the worker dead lock for a specific worker

:param worker_id: Worker ID
:return: Lock object
"""
return self._workers_dead_locks[worker_id]

def signal_worker_alive(self, worker_id: int) -> None:
"""
Indicate that a worker is alive
31 changes: 24 additions & 7 deletions mpire/pool.py
@@ -148,12 +148,15 @@ def _start_workers(self, progress_bar: bool) -> None:
self._workers.append(self._start_worker(worker_id))
logger.debug("Workers created")

def _restart_workers(self) -> List[Any]:
def _check_worker_status(self) -> List[Any]:
"""
Restarts workers that need to be restarted.
Checks the worker status:
- If the worker is supposed to be alive, but isn't, terminate.
- Restarts workers that need to be restarted.

:return: List of unordered results produced by workers
"""
# Check restarts
obtained_results = []
for worker_id in self._worker_comms.get_worker_restarts():
# Obtain results from exit results queue (should be done before joining the worker)
@@ -178,6 +181,20 @@ def _restart_workers(self) -> List[Any]:
# Start new worker
self._workers[worker_id] = self._start_worker(worker_id)

# Check that workers that are supposed to be alive actually are. If not, a worker died unexpectedly.
# Note that a worker can be alive while its alive status is still False. This doesn't really matter: in that case
# the worker simply hasn't signaled itself alive yet, but the OS confirms it's running. The only sign that
# something bad happened is a worker that is supposed to be alive but, according to the OS, is not.
for worker_id in range(self.pool_params.n_jobs):
with self._worker_comms.get_worker_dead_lock(worker_id):
worker_died = self._worker_comms.is_worker_alive(worker_id) and not self._workers[worker_id].is_alive()
if worker_died:
# We need to add an exception if we're using the progress bar handler
if self._worker_comms.has_progress_bar():
self._worker_comms.add_exception(RuntimeError, f"Worker-{worker_id} died unexpectedly")
self.terminate()
raise RuntimeError(f"Worker-{worker_id} died unexpectedly")

return obtained_results

def _start_worker(self, worker_id: int) -> mp.Process:
@@ -547,8 +564,8 @@ def imap_unordered(self, func: Callable, iterable_of_args: Union[Sized, Iterable
except queue.Empty:
pass

# Restart workers if necessary. This can yield intermediate results
for results in self._restart_workers():
# Check worker status (e.g., restarts). This can yield intermediate results
for results in self._check_worker_status():
yield from results
n_active -= 1

@@ -560,8 +577,8 @@ def imap_unordered(self, func: Callable, iterable_of_args: Union[Sized, Iterable
except queue.Empty:
pass

# Restart workers if necessary. This can yield intermediate results
for results in self._restart_workers():
# Check worker status (e.g., restarts). This can yield intermediate results
for results in self._check_worker_status():
yield from results
n_active -= 1

@@ -668,7 +685,7 @@ def stop_and_join(self, progress_bar_handler: Optional[ProgressBarHandler] = Non
t.join(timeout=0.01)
if not t.is_alive():
break
self._restart_workers()
self._check_worker_status()
logger.debug("Done joining task queues")

# When an exception occurred in the above process (i.e., the worker init function raises), we need to handle
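The check added in ``_check_worker_status`` above can be reproduced in isolation. A minimal, self-contained sketch
of the pattern with plain ``multiprocessing`` (illustrative names, not MPIRE's internals): holding the per-worker
lock makes the flag read and the OS-level liveness check one consistent snapshot, so a worker that dies cleanly
(clearing its flag under the lock before exiting) is never misreported as having died unexpectedly::

    import multiprocessing as mp
    import os
    import time

    def worker(alive_flag, lock, crash):
        with lock:
            alive_flag.set()           # "I'm alive", signaled under the lock
        try:
            if crash:
                os._exit(1)            # simulate dying without any cleanup
            time.sleep(0.2)            # normal work
        finally:
            with lock:
                alive_flag.clear()     # clean shutdown clears the flag

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        for crash in (False, True):
            alive, lock = ctx.Event(), ctx.Lock()
            p = ctx.Process(target=worker, args=(alive, lock, crash))
            p.start()
            time.sleep(0.5)            # give the worker time to run (and possibly die)
            with lock:
                died_unexpectedly = alive.is_set() and not p.is_alive()
            p.join()
            print(f"crash={crash}: died unexpectedly -> {died_unexpectedly}")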
12 changes: 7 additions & 5 deletions mpire/worker.py
@@ -107,7 +107,7 @@ def _exit_gracefully(self, *_) -> None:
self.is_running = False
raise StopWorker

def _exit_gracefully_windows(self):
def _exit_gracefully_windows(self) -> None:
"""
Windows doesn't fully support signals as Unix-based systems do. Therefore, we have to work around it. This
function is started in a thread. We wait for a kill signal (Event object) and interrupt the main thread if we
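Aside: the workaround described above can be sketched standalone. ``_thread.interrupt_main()`` is the
standard-library call that raises ``KeyboardInterrupt`` in the main thread; the scaffolding below is illustrative,
not MPIRE's actual code::

    import _thread
    import threading
    import time

    kill_signal = threading.Event()

    def watchdog():
        kill_signal.wait()             # block until a shutdown is requested
        _thread.interrupt_main()       # raise KeyboardInterrupt in the main thread

    threading.Thread(target=watchdog, daemon=True).start()
    threading.Timer(0.5, kill_signal.set).start()   # simulate a kill signal arriving

    try:
        while True:
            time.sleep(0.1)            # main work loop
    except KeyboardInterrupt:
        print("interrupted, shutting down gracefully")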
@@ -134,7 +134,8 @@ def run(self) -> None:
t = Thread(target=self._exit_gracefully_windows)
t.start()

self.worker_comms.signal_worker_alive(self.worker_id)
with self.worker_comms.get_worker_dead_lock(self.worker_id):
self.worker_comms.signal_worker_alive(self.worker_id)

# Set tqdm and dashboard connection details. This is needed for nested pools and in the case forkserver or
# spawn is used as start method
@@ -235,7 +236,8 @@ def run(self) -> None:
self.worker_comms.signal_worker_restart(self.worker_id)

finally:
self.worker_comms.signal_worker_dead(self.worker_id)
with self.worker_comms.get_worker_dead_lock(self.worker_id):
self.worker_comms.signal_worker_dead(self.worker_id)

def _get_func(self, additional_args: List) -> Callable:
"""
@@ -327,7 +329,7 @@ def _run_safely(self, func: Callable, exception_args: Optional[Any] = None,
# The main process tells us to stop working, shutting down
raise

except Exception as err:
except (Exception, SystemExit) as err:
# An exception occurred inside the provided function. Let the signal handler know it shouldn't raise any
# StopWorker exceptions from the parent process anymore, we got this.
with self.is_running_lock:
@@ -344,7 +346,7 @@
# Carry on
return results, False

def _raise(self, args: Any, no_args: bool, err: Exception) -> None:
def _raise(self, args: Any, no_args: bool, err: Union[Exception, SystemExit]) -> None:
"""
Create exception and pass it to the parent process. Let other processes know an exception is set
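The switch to ``except (Exception, SystemExit)`` in ``_run_safely`` matters because ``SystemExit`` (raised by
``exit()`` and ``sys.exit()``) derives from ``BaseException``, not ``Exception``, so a plain ``except Exception``
lets it propagate and silently kill the worker. A quick standalone demonstration::

    import sys

    def run_old(func):
        try:
            return func(), False
        except Exception as err:              # SystemExit slips through here
            return err, True

    def run_new(func):
        try:
            return func(), False
        except (Exception, SystemExit) as err:
            return err, True

    print(run_new(lambda: sys.exit(1)))       # (SystemExit(1), True) -- caught
    print(run_old(lambda: sys.exit(1)))       # uncaught: SystemExit ends the script here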
7 changes: 7 additions & 0 deletions tests/test_comms.py
@@ -54,6 +54,7 @@ def test_init_comms(self):
self.assertListEqual(comms._exit_results_queues, [])
self.assertIsNone(comms._all_exit_results_obtained)
self.assertIsNone(comms._worker_done_array)
self.assertIsNone(comms._workers_dead_locks)
self.assertIsNone(comms._workers_dead)
self.assertIsNone(comms._exception_queue)
self.assertIsInstance(comms.exception_lock, lock_type)
@@ -82,6 +83,9 @@ def test_init_comms(self):
for worker_dead in comms._workers_dead:
self.assertIsInstance(worker_dead, event_type)
self.assertTrue(worker_dead.is_set())
self.assertEqual(len(comms._workers_dead_locks), n_jobs)
for worker_dead_lock in comms._workers_dead_locks:
self.assertIsInstance(worker_dead_lock, lock_type)
self.assertIsInstance(comms._exception_queue, joinable_queue_type)
self.assertFalse(comms._exception_thrown.is_set())
self.assertFalse(comms._kill_signal_received.is_set())
@@ -142,6 +146,9 @@ def test_init_comms(self):
for worker_dead in comms._workers_dead:
self.assertIsInstance(worker_dead, event_type)
self.assertTrue(worker_dead.is_set())
self.assertEqual(len(comms._workers_dead_locks), n_jobs)
for worker_dead_lock in comms._workers_dead_locks:
self.assertIsInstance(worker_dead_lock, lock_type)
self.assertIsInstance(comms._exception_queue, joinable_queue_type)
self.assertFalse(comms._exception_thrown.is_set())
self.assertFalse(comms._kill_signal_received.is_set())
82 changes: 78 additions & 4 deletions tests/test_pool.py
@@ -1,10 +1,12 @@
import logging
import os
import signal
import types
import unittest
import warnings
from itertools import product, repeat
from multiprocessing import Barrier, Value
from threading import Thread
from unittest.mock import patch

import numpy as np
@@ -25,7 +27,7 @@ def square(idx, x):
return idx, x * x


def extremely_large_output(idx, x):
def extremely_large_output(idx, _):
return idx, os.urandom(1024 * 1024)


@@ -691,9 +693,9 @@ def test_start_methods(self):
pool.map(self._square_daemon, ((X,) for X in repeat(self.test_data, 3)), chunk_size=1)

@staticmethod
def _square_daemon(X):
def _square_daemon(x):
with WorkerPool(n_jobs=4) as pool:
return pool.map(square, X, chunk_size=1)
return pool.map(square, x, chunk_size=1)


class CPUPinningTest(unittest.TestCase):
@@ -722,7 +724,7 @@ def test_cpu_pinning(self):
(4, [[0, 3]], [[0, 3], [0, 3], [0, 3], [0, 3]])]:
# The test has been designed for a system with at least 4 cores. We'll skip those test cases where the CPU
# IDs exceed the number of CPUs.
if cpu_ids is not None and np.array(cpu_ids).max() >= cpu_count():
if cpu_ids is not None and np.array(cpu_ids).max(initial=0) >= cpu_count():
continue

with self.subTest(n_jobs=n_jobs, cpu_ids=cpu_ids), patch('mpire.pool.set_cpu_affinity') as p, \
@@ -1156,6 +1158,52 @@ def test_start_methods(self):
with self.subTest(function='square_raises_on_idx', map='imap'), self.assertRaises(ValueError):
list(pool.imap_unordered(self._square_raises_on_idx, self.test_data, progress_bar=progress_bar))

def test_defunct_processes_exit(self):
"""
Tests if MPIRE correctly shuts down after a process becomes defunct using exit()
"""
print()
for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
(3, True, 1),
(3, False, 3)]:
for start_method in TEST_START_METHODS:
# Progress bar on Windows + threading is not supported right now
if RUNNING_WINDOWS and start_method == 'threading' and progress_bar:
continue
self.logger.debug(f"========== {start_method}, {n_jobs}, {progress_bar}, {worker_lifespan} ==========")
with self.subTest(n_jobs=n_jobs, progress_bar=progress_bar, worker_lifespan=worker_lifespan,
start_method=start_method), self.assertRaises(SystemExit), \
WorkerPool(n_jobs=n_jobs, start_method=start_method) as pool:
pool.map(self._exit, range(100), progress_bar=progress_bar, worker_lifespan=worker_lifespan)

def test_defunct_processes_kill(self):
"""
Tests if MPIRE correctly shuts down after one process becomes defunct because it is forcibly terminated.
We kill worker 0. To make sure it's alive first, the worker sets an event object and then enters an infinite
loop. The kill thread waits until the event is set and then terminates the worker. The other workers are also
guaranteed to have done something, so we can test what happens during restarts
"""
print()
for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
(3, True, 1),
(3, False, 3)]:
for start_method in TEST_START_METHODS:
# Can't kill threads
if start_method == 'threading':
continue

self.logger.debug(f"========== {start_method}, {n_jobs}, {progress_bar}, {worker_lifespan} ==========")
with self.subTest(n_jobs=n_jobs, progress_bar=progress_bar, worker_lifespan=worker_lifespan,
start_method=start_method), self.assertRaises(RuntimeError), \
WorkerPool(n_jobs=n_jobs, pass_worker_id=True, start_method=start_method) as pool:
events = [pool.ctx.Event() for _ in range(n_jobs)]
kill_thread = Thread(target=self._kill_process, args=(events[0], pool))
kill_thread.start()
pool.set_shared_objects(events)
pool.map(self._worker_0_sleeps_others_square, range(1000), progress_bar=progress_bar,
worker_lifespan=worker_lifespan, chunk_size=1)

@staticmethod
def _square_raises(_, x):
raise ValueError(x)
@@ -1166,3 +1214,29 @@ def _square_raises_on_idx(idx, x):
raise ValueError(x)
else:
return idx, x * x

@staticmethod
def _exit(_):
exit()

@staticmethod
def _worker_0_sleeps_others_square(worker_id, events, x):
"""
Worker 0 waits until the other workers have at least spun up, then sets its own event and spins forever
"""
if worker_id == 0:
[event.wait() for event in events[1:]]
events[0].set()
while True:
pass
else:
events[worker_id].set()
return x * x

@staticmethod
def _kill_process(event, pool):
"""
Wait for the event to be set, then terminate worker 0
"""
event.wait()
pool._workers[0].terminate()
