[torchelastic] Make sure torchelastic mp waits for the queue to be drained before finishing the process (#55412)

Summary:
Pull Request resolved: #55412

The diff resolves a bug where worker processes could exit before the torchelastic process had read their return values. This is a rare event, but it can still happen, e.g. https://fb.workplace.com/groups/319878845696681/permalink/512409069776990/

When users return a torch.Tensor object from a worker process, torchelastic multiprocessing fails. Currently, a worker process finishes its job as soon as it writes its output to the IPC queue, without waiting for confirmation from the receiver process. When this happens, the underlying channel between the worker and the torchelastic process can already be closed (in the case of mp.SimpleQueue the channel is backed by file descriptors, which is why we see FileNotFoundError: since the worker process has finished execution, the file descriptor is deleted and the torchelastic process cannot find it).
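
As a minimal illustration of the handshake this change introduces (the names _worker and run_example below are hypothetical, not part of the torchelastic API), the worker writes its result to the queue and then blocks on an Event until the parent confirms it has drained the queue:

import torch
import torch.multiprocessing as mp

def _worker(ret_queue, finished_reading_event):
    # The worker writes its return value to the IPC queue ...
    ret_queue.put(torch.rand(100, 100))
    # ... and then blocks until the parent confirms it has read the value, so
    # the queue's underlying file descriptors stay alive until the read completes.
    finished_reading_event.wait()

def run_example():
    ctx = mp.get_context("spawn")
    ret_queue = ctx.SimpleQueue()
    finished_reading_event = ctx.Event()

    proc = ctx.Process(target=_worker, args=(ret_queue, finished_reading_event))
    proc.start()

    result = ret_queue.get()        # the parent drains the queue first ...
    finished_reading_event.set()    # ... and only then releases the worker
    proc.join()
    print(result.shape)             # torch.Size([100, 100])

if __name__ == "__main__":
    run_example()

Without the final wait(), the worker could exit as soon as put() returns, tearing down the shared file descriptors before the parent calls get().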

Test Plan:
buck test mode/dev-nosan //caffe2/test/distributed/elastic/agent/server/test:local_agent_test

User workflow: f263531643

Reviewed By: cbalioglu

Differential Revision: D27602838

fbshipit-source-id: 29871178232e3af4ad3dec406c234aba9c5faba1
aivanou authored and facebook-github-bot committed Apr 7, 2021
1 parent 3bb1f59 commit f5675f8
Showing 3 changed files with 66 additions and 5 deletions.
@@ -71,6 +71,13 @@ def _sad_function():
raise RuntimeError("sad because i throw")


def dummy_compute() -> torch.Tensor:
"""
returns a predefined size random Tensor
"""
return torch.rand(100, 100)


def _fatal_signal_function(expected_error_index: int, sig: int):
rank = int(os.environ["RANK"])
if rank == expected_error_index:
@@ -306,6 +313,16 @@ def run_job(
results.setdefault(role, []).append(run_result)
return results

@unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
)
def test_dummy_compute(self):
res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2))
self.assertFalse(res.is_failed())
for return_value in res.return_values.values():
self.assertIsInstance(return_value, torch.Tensor)
self.assertEqual((100, 100), return_value.shape)

@unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
)
29 changes: 29 additions & 0 deletions test/distributed/elastic/multiprocessing/api_test.py
@@ -18,6 +18,7 @@
from typing import Dict, List
from unittest import mock

import torch
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
from torch.distributed.elastic.multiprocessing.api import (
@@ -143,6 +144,11 @@ def echo_large(size: int) -> Dict[int, str]:
out[idx] = f"test{idx}"
return out

def dummy_compute() -> torch.Tensor:
"""
returns a predefined size random Tensor
"""
return torch.rand(100, 100)

def redirects() -> List[Std]:
return [
@@ -205,6 +211,7 @@ def test_wrap(self):

for stdout_redir, stderr_redir in redirs:
queue = multiprocessing.SimpleQueue()
worker_finished_event_mock = mock.Mock()
_wrap(
local_rank=0,
fn=echo1,
@@ -213,12 +220,14 @@
stdout_redirects={0: stdout_redir},
stderr_redirects={0: stderr_redir},
ret_vals={0: queue},
queue_finished_reading_event=worker_finished_event_mock,
)
self.assertEqual("hello_0", queue.get())
if stdout_redir:
self.assert_in_file(["hello stdout from 0"], stdout_log)
if stderr_redir:
self.assert_in_file(["hello stderr from 0"], stderr_log)
worker_finished_event_mock.wait.assert_called_once()

def test_invalid_log_dir(self):
with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
@@ -339,6 +348,26 @@ def test_function(self):
[f"hello stderr from {i}"], results.stderrs[i]
)

@unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
)
def test_function_with_tensor(self):
for start_method in self._start_methods:
pc = start_processes(
name="dummy_compute",
entrypoint=dummy_compute,
args={},
envs={},
log_dir=self.log_dir(),
start_method=start_method,
)

results = pc.wait()
self.assert_pids_noexist(pc.pids())
for return_value in results.return_values.values():
self.assertIsInstance(return_value, torch.Tensor)
self.assertEqual((100, 100), return_value.shape)

@unittest.skipIf(
TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan"
)
25 changes: 20 additions & 5 deletions torch/distributed/elastic/multiprocessing/api.py
@@ -17,6 +17,7 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

import torch.multiprocessing as mp
@@ -271,6 +272,7 @@ def _wrap(
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
ret_vals: Dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
args_ = args[local_rank]
@@ -289,6 +291,7 @@
with stdout_cm, stderr_cm:
ret = record(fn)(*args_)
ret_val_.put(ret)
queue_finished_reading_event.wait()


class MultiprocessContext(PContext):
@@ -331,6 +334,9 @@ def __init__(
# see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None
# Note: the set() method should ONLY be invoked when all processes have finished
# successfully. If any process died on event.wait(), calling set() will deadlock.
self._worker_finished_event = mp.get_context(self.start_method).Event()

def _start(self):
if self._pc:
@@ -347,22 +353,27 @@ def _start(self):
self.stdouts,
self.stderrs,
self._ret_vals,
self._worker_finished_event,
),
nprocs=self.nprocs,
join=False,
daemon=False,
start_method=self.start_method,
)

def _is_done(self) -> bool:
return len(self._return_values) == self.nprocs

def _poll(self) -> Optional[RunProcsResult]:
assert self._pc is not None # assertion for mypy type checker

try:
# torch.mp.ProcessContext returns True if all the workers have
# successfully finished, False if some/all are still running
# and throws an Exception if some/all of them failed
# torch.mp.ProcessContext throws an Exception if some/all of the
# worker processes failed
# timeout < 0 checks worker status and returns immediately
done = self._pc.join(-1)
# Join will never return success since we use synchronize.Event to wait
# for all processes to finish.
self._pc.join(-1)

# IMPORTANT: we use multiprocessing.Queue to carry worker return values
# back to the parent, the worker process will wait before terminating
@@ -376,8 +387,12 @@
# save the return values temporarily into a member var
self._return_values[local_rank] = return_queue.get()

if done:
if self._is_done():
# we should ALWAYS have ALL the return values when all the processes are done
self._worker_finished_event.set()
# Wait until all processes are finished. At this point the workers have
# finished executing the user function
self._pc.join()
_validate_full_rank(
self._return_values, self.nprocs, "return_value queue"
)
