Autograd engine, only enqueue task when it is fully initialized (#50164)
Summary:
This solves a race condition where the worker thread might
see a partially initialized graph_task

Fixes #49652

I don't know how to reliably trigger the race, so I didn't add a test. But the ROCm build flakiness (the race just happens to occur more often on ROCm builds) should disappear after this PR.

Pull Request resolved: #50164

Reviewed By: zou3519

Differential Revision: D25824954

Pulled By: albanD

fbshipit-source-id: 6a3391753cb2afd2ab415d3fb2071a837cc565bb
albanD authored and facebook-github-bot committed Jan 8, 2021
1 parent c215ffb commit fc2ead0
Showing 1 changed file with 11 additions and 3 deletions: torch/csrc/autograd/engine.cpp
@@ -916,7 +916,6 @@ std::shared_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
   std::unique_lock<std::mutex> lock(graph_task->mutex_);

   auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
-  queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));

   // worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the
   // autograd engine with corresponding GraphTask, and its NOT a re-entrant call
@@ -929,8 +928,12 @@ std::shared_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
     // set the graph_task owner to the current device
     graph_task->owner_ = worker_device;

-    // The owning thread start to drive the engine execution with the GraphTask
-    // that has already been pushed to the current CPU thread's ready_queue
+    // Now that all the non-thread safe fields of the graph_task have been populated,
+    // we can enqueue it.
+    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
+
+    // The owning thread starts to drive the engine execution for any CPU task that
+    // was just pushed or will be added later from other worker threads
     lock.unlock();
     thread_main(graph_task);
     TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
@@ -943,6 +946,11 @@ std::shared_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
     // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
     // backward call from that device.
     graph_task->owner_ = worker_device;
+
+    // Now that all the non-thread safe fields of the graph_task have been populated,
+    // we can enqueue it.
+    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
+
     if (current_depth >= max_recursion_depth_) {
       // See Note [Reentrant backwards]
       // If reached the max depth, switch to a different thread
