diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 9658ed087ab05..75e903807ff9b 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -465,6 +465,30 @@ def test_function_raise(self): self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped()) + def test_wait_for_all_child_procs_to_exit(self): + """ + Tests that MultiprocessContext actually waits for + the child processes to exit (not just that the entrypoint fn has + finished running). + """ + + mpc = MultiprocessContext( + name="echo", + entrypoint=echo0, + args={}, + envs={}, + start_method="spawn", + logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), + ) + + with mock.patch.object( + mpc, "_is_done", return_value=True + ), mock.patch.object(mpc, "_pc"), mock.patch.object( + mpc._pc, "join", side_effect=[True, False, False, True] + ) as mock_join: + mpc._poll() + self.assertEqual(4, mock_join.call_count) + ######################################## # start_processes as binary tests ######################################## diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 72c3955e7d1e5..acfba81899c04 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -670,9 +670,13 @@ def _poll(self) -> Optional[RunProcsResult]: if self._is_done(): # we should ALWAYS have ALL the return values when all the processes are done self._worker_finished_event.set() - # Wait untill all processes are finished. At this point workers finished executing - # user function - self._pc.join() + + # At this point the workers have finished running the user function, + # but the child processes might still not have exited. Wait for them. + # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exit. 
+ while not self._pc.join(): + logger.debug("entrypoint fn finished, waiting for all child procs to exit...") + _validate_full_rank( self._return_values, self.nprocs, "return_value queue" )