[torch/distributed] Bugfix: wait for all child procs to exit before c… #125969

Closed
wants to merge 1 commit into from
24 changes: 24 additions & 0 deletions test/distributed/elastic/multiprocessing/api_test.py
@@ -465,6 +465,30 @@ def test_function_raise(self):
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())

+    def test_wait_for_all_child_procs_to_exit(self):
+        """
+        Tests that MultiprocessingContext actually waits for
+        the child process to exit (not just that the entrypoint fn has
+        finished running).
+        """
+
+        mpc = MultiprocessContext(
+            name="echo",
+            entrypoint=echo0,
+            args={},
+            envs={},
+            start_method="spawn",
+            logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
+        )
+
+        with mock.patch.object(
+            mpc, "_is_done", return_value=True
+        ), mock.patch.object(mpc, "_pc"), mock.patch.object(
+            mpc._pc, "join", side_effect=[True, False, False, True]
+        ) as mock_join:
+            mpc._poll()
+            self.assertEqual(4, mock_join.call_count)

########################################
# start_processes as binary tests
########################################
10 changes: 7 additions & 3 deletions torch/distributed/elastic/multiprocessing/api.py
@@ -670,9 +670,13 @@ def _poll(self) -> Optional[RunProcsResult]:
         if self._is_done():
             # we should ALWAYS have ALL the return values when all the processes are done
             self._worker_finished_event.set()
-            # Wait untill all processes are finished. At this point workers finished executing
-            # user function
-            self._pc.join()
+
+            # At this point workers finished running the user function
+            # But the child process might still have not exited. Wait for them.
+            # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits.
+            while not self._pc.join():
Collaborator

should we have a timeout on this? wondering what happens if we have a dead/hung worker process?
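
For illustration only, a bounded wait could look roughly like the sketch below. Nothing here is part of this PR: `join_all_with_deadline`, `HungContext`, and the `deadline_s`/`poll_s` parameters are hypothetical names, and the stub only assumes a `join(timeout)` that returns `True` once every child has exited.

```python
import time


class HungContext:
    """Hypothetical stub: join() never reports completion, like a hung child."""

    def join(self, timeout=None):
        # Pretend to block for (at most) the requested timeout.
        time.sleep(min(timeout or 0.0, 0.01))
        return False


def join_all_with_deadline(pc, deadline_s: float, poll_s: float = 0.01) -> bool:
    """Return True if join() reported all children done before the deadline."""
    end = time.monotonic() + deadline_s
    while time.monotonic() < end:
        if pc.join(timeout=poll_s):
            return True
    # Deadline passed: the caller decides whether to kill or report an error.
    return False


finished = join_all_with_deadline(HungContext(), deadline_s=0.05)
```

With the hung stub, `finished` ends up `False` instead of the loop spinning forever; what to do in that case (kill the children, raise) is left open here.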

Collaborator Author

When you reach this line, we've already validated that either:

  1. the entrypoint function actually ran and returned a result
  2. -- or -- at least one of the child procs has failed (and a SIGTERM was sent to the rest)

We're waiting for the spawned child proc to exit after the user-provided function has already returned.
This potentially could hang but we were waiting for _pc.join() indefinitely before this change as well.
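
A minimal sketch of the contract the loop relies on: a `join()` that returns `False` while some children are still running and `True` once all of them have been joined. `StubProcessContext` and `wait_for_all` are hypothetical stand-ins written for this example, not the torch classes, and the "one child reaped per call" behaviour is a simplification.

```python
from typing import List, Optional


class StubProcessContext:
    """Hypothetical stand-in for a process context (not the torch class).

    Each join() call reaps at most one pending child and returns True
    only once none remain.
    """

    def __init__(self, pids: List[int]) -> None:
        self.pending = list(pids)
        self.join_calls = 0

    def join(self, timeout: Optional[float] = None) -> bool:
        self.join_calls += 1
        if self.pending:
            self.pending.pop(0)  # pretend one child exited
        return not self.pending


def wait_for_all(pc: StubProcessContext) -> int:
    """Loop until join() reports that every child has exited."""
    rounds = 0
    while not pc.join():
        rounds += 1  # the PR logs a debug message at this point
    return rounds


pc = StubProcessContext([101, 102, 103])
rounds = wait_for_all(pc)
```

A single `join()` call on this stub would return with two children still pending; the loop keeps calling until the list is empty.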

Collaborator

got it -- sgtm

+                logger.debug("entrypoint fn finished, waiting for all child procs to exit...")
kiukchung marked this conversation as resolved.

             _validate_full_rank(
                 self._return_values, self.nprocs, "return_value queue"
             )