Skip to content

Commit

Permalink
[torch/distributed] Bugfix: wait for all child procs to exit before c…
Browse files Browse the repository at this point in the history
…losing torch.distributed.elastic.multiprocessing.api.ProcessContext
  • Loading branch information
kiukchung committed May 13, 2024
1 parent 7f1d5ab commit 0bd79bd
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
24 changes: 24 additions & 0 deletions test/distributed/elastic/multiprocessing/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,30 @@ def test_function_raise(self):
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())

def test_wait_for_all_child_procs_to_exit(self):
    """
    Check that ``MultiprocessContext._poll`` keeps calling ``join`` on its
    process context instead of stopping as soon as the entrypoint
    function is reported as done.

    ``_is_done`` is patched to always return ``True`` and ``_pc.join`` is
    patched to yield ``True, False, False, True`` on successive calls; the
    test expects a single ``_poll`` call to consume all four results.
    """
    # Successive return values handed out by the mocked ``join``.
    join_results = [True, False, False, True]

    context = MultiprocessContext(
        name="echo",
        entrypoint=echo0,
        args={},
        envs={},
        start_method="spawn",
        logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
    )

    # Pretend the entrypoint has already finished on every worker.
    with mock.patch.object(context, "_is_done", return_value=True):
        # Swap out the process context first so that ``join`` below is
        # patched on the replacement object, not on a real one.
        with mock.patch.object(context, "_pc"):
            with mock.patch.object(
                context._pc, "join", side_effect=join_results
            ) as join_mock:
                context._poll()
                self.assertEqual(4, join_mock.call_count)

########################################
# start_processes as binary tests
########################################
Expand Down
10 changes: 7 additions & 3 deletions torch/distributed/elastic/multiprocessing/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,13 @@ def _poll(self) -> Optional[RunProcsResult]:
if self._is_done():
# we should ALWAYS have ALL the return values when all the processes are done
self._worker_finished_event.set()
# Wait until all processes are finished. At this point the workers have finished
# executing the user function
self._pc.join()

# At this point the workers have finished running the user function,
# but the child processes might not have exited yet. Wait for them.
# pc.join() blocks [forever] until "a" proc exits. Loop until all of them exit.
while not self._pc.join():
logger.debug("entrypoint fn finished, waiting for all child procs to exit...")

_validate_full_rank(
self._return_values, self.nprocs, "return_value queue"
)
Expand Down

0 comments on commit 0bd79bd

Please sign in to comment.