Better handling of Autograd+Fork errors. (pytorch#33885)
Summary:
Pull Request resolved: pytorch#33885

Fixes: pytorch#32835
Fixes: pytorch#5834

Cannot be combined with CUDA's implementation, as each requires its own `std::once_flag` as well as a different `forked_autograd_child` function. The CUDA version relies on the Python module, while autograd uses TORCH_CHECK to report the error to both Python and C++.
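
For illustration, the pattern here is: register a pthread_atfork child handler exactly once, set a flag in any child forked after the engine's thread pool was initialized, and have the next autograd call in that child fail with a clear error instead of deadlocking. Below is a minimal standalone sketch of that pattern (hypothetical, simplified names such as forked_after_init, track_forks_once and start_worker_threads; not the actual engine code, which appears in the diff further down):

#include <mutex>
#include <stdexcept>
#include <pthread.h>

namespace {
bool forked_after_init = false;  // set in the forked child only, after the handler is registered

void forked_child_handler() { forked_after_init = true; }

void track_forks_once() {
  static std::once_flag flag;
  std::call_once(flag, [] {
    // Only a child handler is needed; the prepare/parent handlers stay null.
    pthread_atfork(nullptr, nullptr, forked_child_handler);
  });
}
} // namespace

void initialize_thread_pool_sketch() {
  track_forks_once();
  if (forked_after_init) {
    throw std::runtime_error(
        "Unable to handle autograd's threading in combination with "
        "fork-based multiprocessing.");
  }
  // start_worker_threads();  // placeholder: spawn the worker threads exactly once
}

In the actual change, this logic lives in Engine::initialize_threads_pool(), which both enqueue_blocked_task_on_cpu and execute_with_graph_task now call before touching the ready queues.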

Test Plan: Imported from OSS

Differential Revision: D20144024

Pulled By: VitalyFedyunin

fbshipit-source-id: e7cf30568fff5110e9df7fe5b23f18ed992fa17f
VitalyFedyunin authored and ttumiel committed Mar 4, 2020
1 parent 45e707f commit dd2ab4b
Showing 3 changed files with 47 additions and 16 deletions.
32 changes: 18 additions & 14 deletions test/test_multiprocessing.py
@@ -80,11 +80,6 @@ def receive_and_send(queue, out_queue, event, count):
        event.wait()


def call_backward():
    x = torch.randn(3, 3, requires_grad=True)
    x.sum().backward()


def sum_tensors(inq, outq):
    with torch.cuda.device(1):
        tensors = inq.get()
@@ -156,6 +151,9 @@ def mixed_type_producer(queue, event):
        event.wait()
        event.clear()

def simple_autograd_function(a=1):
    torch.rand(3).requires_grad_(True).mean().backward()
    return a ** 2

@contextlib.contextmanager
def fs_sharing():
@@ -358,6 +356,21 @@ def test_inherit_tensor(self):
        p.join(1)
        self.assertEqual(t, torch.ones(5, 5) * 3, 0)

    @unittest.skipIf(IS_WINDOWS, "Test needs to use fork multiprocessing")
    def test_autograd_errors(self):
        ctx = mp.get_context('fork')
        simple_autograd_function()
        with self.assertRaisesRegex(RuntimeError, r'Unable to handle autograd'):
            with ctx.Pool(3) as pool:
                pool.map(simple_autograd_function, [1, 2, 3])

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Test needs to use spawn multiprocessing")
    def test_autograd_fine_with_spawn(self):
        ctx = mp.get_context('spawn')
        simple_autograd_function()
        with ctx.Pool(3) as pool:
            pool.map(simple_autograd_function, [1, 2, 3])

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
@@ -821,15 +834,6 @@ def test_is_shared_cuda(self):
        t = torch.randn(5, 5).cuda()
        self.assertTrue(t.is_shared())

    @unittest.skip('this test occasionally fails and deadlocks; see https://github.com/pytorch/pytorch/issues/5834')
    def test_backwards_fork(self):
        r"backwards() should succeed when called before and after a fork"
        call_backward()
        p = mp.Process(target=call_backward)
        p.start()
        p.join(1)
        self.assertFalse(p.is_alive())


if __name__ == '__main__':
run_tests()
30 changes: 28 additions & 2 deletions torch/csrc/autograd/engine.cpp
@@ -35,6 +35,24 @@

namespace torch { namespace autograd {

namespace {
static bool in_bad_autograd_fork =
    false; // True for children forked after engine's thread pool init

// Called in the forked child if engine's thread pool has already been
// initialized
static void forked_autograd_child() { in_bad_autograd_fork = true; }

// Should be called before unsafe for forks (thread pool) calls
static void track_bad_autograd_forks() {
#ifndef WIN32
  static std::once_flag flag;
  std::call_once(
      flag, [&] { pthread_atfork(nullptr, nullptr, forked_autograd_child); });
#endif
}
}

// Threads spawned by the engine are assigned a constant 'worker_device'
// specifying what device they process work for. This variable is initialized
// at thread creation time and is constant afterwards. This is used when
@@ -728,16 +746,24 @@ auto Engine::execute(const edge_list& roots,
  return execute_with_graph_task(graph_task, graph_root)->wait();
}

void Engine::enqueue_blocked_task_on_cpu(NodeTask task) {
void Engine::initialize_threads_pool() {
  track_bad_autograd_forks();
  TORCH_CHECK(!in_bad_autograd_fork,
      "Unable to handle autograd's threading in combination with fork-based multiprocessing. "
      "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork");
  std::call_once(start_threads_flag_, &Engine::start_threads, this);
}

void Engine::enqueue_blocked_task_on_cpu(NodeTask task) {
  initialize_threads_pool();
  ready_queue(at::kCPU).push(
      std::move(task), /* incrementOutstandingTasks */ false);
}

std::shared_ptr<FutureVariableList> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root) {
  std::call_once(start_threads_flag_, &Engine::start_threads, this);
  initialize_threads_pool();
  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

1 change: 1 addition & 0 deletions torch/csrc/autograd/engine.h
@@ -217,6 +217,7 @@ struct TORCH_API Engine {
  void reentrant_thread_init();
  void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);
  void set_device(int device);
  void initialize_threads_pool();

  // Ensures ready_queues_ are initialized only once
  std::once_flag start_threads_flag_;
