24 changes: 21 additions & 3 deletions test/test_multiprocessing_spawn.py
@@ -15,6 +15,8 @@
NO_MULTIPROCESSING_SPAWN,
run_tests,
TestCase,
parametrize,
instantiate_parametrized_tests
)

def _test_success_func(i):
@@ -92,6 +94,7 @@ def _test_nested(i, pids_queue, nested_child_sleep, start_method):
# Kill self. This should take down the child processes as well.
os.kill(os.getpid(), signal.SIGTERM)

@instantiate_parametrized_tests
class _TestMultiProcessing:
start_method = None

@@ -143,13 +146,28 @@ def test_terminate_signal(self):
with self.assertRaisesRegex(Exception, message):
mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method)

def test_terminate_exit(self):
@parametrize("grace_period", [None, 5])
def test_terminate_exit(self, grace_period):
exitcode = 123
ctx = mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method, join=False)
pid1 = ctx.processes[1].pid
with self.assertRaisesRegex(
Exception,
"process 0 terminated with exit code %d" % exitcode,
):
mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method)
), self.assertLogs(level='WARNING') as logs:
while not ctx.join(grace_period=grace_period):
pass
if grace_period is None:
# pid1 is killed by signal.
expected_log = "Terminating process %d via signal" % pid1
self.assertIn(expected_log, logs.records[0].getMessage())
else:
# pid1 exits on its own.
self.assertFalse(logs.records)

# Check that no processes are left.
for p in ctx.processes:
self.assertFalse(p.is_alive())

def test_success_first_then_exception(self):
exitcode = 123
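For context, here is a minimal usage sketch of the grace-period-aware join exercised by the test above; the worker function, its exit code, and the 5-second grace period are illustrative assumptions, not part of this diff:

    import torch.multiprocessing as mp

    def _worker(i):
        # Illustrative worker (assumption): process 0 exits with a failure code,
        # the other process finishes normally.
        if i == 0:
            raise SystemExit(7)

    if __name__ == "__main__":
        # Start workers without joining, then poll join() with a grace period.
        ctx = mp.start_processes(_worker, nprocs=2, join=False, start_method="spawn")
        try:
            while not ctx.join(grace_period=5):
                pass
        except Exception as e:
            # The failing worker's exit is re-raised here; any surviving worker
            # was given 5 seconds to shut down before being terminated.
            print(e)
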
42 changes: 27 additions & 15 deletions torch/multiprocessing/spawn.py
@@ -110,19 +110,31 @@ def __init__(self, processes, error_files):
def pids(self):
return [int(process.pid) for process in self.processes]

def join(self, timeout=None):
def _join_procs_with_timeout(self, timeout: float):
"""Attempt to join all processes with a shared timeout."""
end = time.monotonic() + timeout
for process in self.processes:
time_to_wait = max(0, end - time.monotonic())
process.join(time_to_wait)

def join(
self, timeout: Optional[float] = None, grace_period: Optional[float] = None
):
r"""Join one or more processes within spawn context.

Attempt to join one or more processes in this spawn context.
If one of them exited with a non-zero exit status, this function
kills the remaining processes and raises an exception with the cause
of the first process exiting.
kills the remaining processes (optionally with a grace period)
and raises an exception with the cause of the first process exiting.

Returns ``True`` if all processes have been joined successfully,
``False`` if there are more processes that need to be joined.

Args:
timeout (float): Wait this long before giving up on waiting.
timeout (float): Wait this long (in seconds) before giving up on waiting.
grace_period (float): When any process fails, wait this long (in seconds)
for the other processes to shut down gracefully before terminating them.
If they still don't exit, wait another grace period before killing them.
"""
# Ensure this function can be called even when we're done.
if len(self.sentinels) == 0:
@@ -147,22 +159,22 @@ def join(self, timeout=None):
if error_index is None:
# Return whether or not all processes have been joined.
return len(self.sentinels) == 0
# An error occurred. Clean up all processes before returning.
# First, allow a grace period for processes to shut down on their own.
if grace_period is not None:
self._join_procs_with_timeout(grace_period)
# Then, terminate processes that are still alive. Try SIGTERM first.
for process in self.processes:
if process.is_alive():
log.warning("Terminating process %s via signal SIGTERM", process.pid)
process.terminate()

# Assume failure. Terminate processes that are still alive.
# Try SIGTERM then SIGKILL if the process isn't going down.
# Try SIGKILL if the process isn't going down after another grace_period.
# The reason is that Python signal handling is limited to the main
# thread, and if that thread is stuck in C/C++ land it won't handle
# the signal. We have seen processes get stuck and fail to handle
# SIGTERM for the above reason.
timeout: int = 30
for process in self.processes:
if process.is_alive():
log.warning("Terminating process %s via signal SIGTERM", process.pid)
process.terminate()
end = time.monotonic() + timeout
for process in self.processes:
time_to_wait = max(0, end - time.monotonic())
process.join(time_to_wait)
self._join_procs_with_timeout(30 if grace_period is None else grace_period)
for process in self.processes:
if process.is_alive():
log.warning(
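The shutdown escalation implemented above (optional grace period, then SIGTERM, then SIGKILL for anything still alive) can be summarized with this simplified, standalone sketch; it is an approximation of the ProcessContext logic rather than the code itself, and the final SIGKILL step is inferred from the hunk that is cut off above:

    import time
    from multiprocessing import Process
    from typing import List, Optional

    def shut_down_workers(procs: List[Process], grace_period: Optional[float] = None) -> None:
        """Simplified sketch of the post-failure shutdown escalation."""
        def join_all(timeout: float) -> None:
            # Share one deadline across all joins, like _join_procs_with_timeout.
            end = time.monotonic() + timeout
            for p in procs:
                p.join(max(0, end - time.monotonic()))

        # 1) Optionally let surviving workers exit on their own.
        if grace_period is not None:
            join_all(grace_period)
        # 2) SIGTERM whatever is still alive.
        for p in procs:
            if p.is_alive():
                p.terminate()
        # 3) Wait again (30s by default, or another grace period), then SIGKILL stragglers.
        join_all(30 if grace_period is None else grace_period)
        for p in procs:
            if p.is_alive():
                p.kill()
            p.join()
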