Skip to content
Closed
40 changes: 33 additions & 7 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3562,11 +3562,16 @@ def test_checkpointing(self):
mean_combined = torch.stack(feat_combined).mean()
mean_combined.backward()

def test_reentrant_with_callbacks(self):
counter = [0]
def _test_reentrant_with_callbacks(self, install_callbacks_in_depths):
counter = {}
counter["inner"] = 0
counter["outer"] = 0

def inc_counter():
counter[0] += 1
def inc_inner_counter():
counter["inner"] += 1

def inc_outer_counter():
counter["outer"] += 1

class MyFunc(Function):
@staticmethod
Expand All @@ -3576,8 +3581,9 @@ def forward(ctx, input):
@staticmethod
@once_differentiable
def backward(ctx, input):
# Add a callback to execute.
Variable._execution_engine.queue_callback(inc_counter)
if 1 in install_callbacks_in_depths:
# Add a callback to execute.
Variable._execution_engine.queue_callback(inc_inner_counter)

return input

Expand All @@ -3589,6 +3595,9 @@ def forward(ctx, input):
@staticmethod
@once_differentiable
def backward(ctx, input):
if 0 in install_callbacks_in_depths:
# Add a callback to execute.
Variable._execution_engine.queue_callback(inc_outer_counter)
# Reentrant backward call.
tmp_inp = input.detach().requires_grad_()
with torch.enable_grad():
Expand All @@ -3601,8 +3610,25 @@ def backward(ctx, input):
t3 = t2.sum()
torch.autograd.backward([t3])

return counter

def test_reentrant_with_callbacks_depth_0(self):
# Verify callback is called only once.
ret = self._test_reentrant_with_callbacks([0])
self.assertEqual(1, ret["outer"])
self.assertEqual(0, ret["inner"])

def test_reentrant_with_callbacks_depth_1(self):
# Verify callback is called only once.
self.assertEqual(1, counter[0])
ret = self._test_reentrant_with_callbacks([1])
self.assertEqual(0, ret["outer"])
self.assertEqual(1, ret["inner"])

def test_reentrant_with_callbacks_both_depths(self):
# Verify callback is called twice.
ret = self._test_reentrant_with_callbacks([0, 1])
self.assertEqual(1, ret["outer"])
self.assertEqual(1, ret["inner"])

def test_autograd_views_codegen(self):
# This is not necessarily the absolute correct behavior, but this is the current
Expand Down
62 changes: 37 additions & 25 deletions torch/csrc/autograd/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ static thread_local int current_depth = 0;
// Total nested reentrant backward calls over all threads for worker_device
static thread_local int total_depth = 0;

// The current GraphTask being executed by this thread. This helps
// queue_callback() to find the target GraphTask to append final callbacks.
static thread_local std::shared_ptr<GraphTask> current_graph_task = nullptr;

// Returns true when t2 should be (weakly) BEFORE t1 in the queue.
// Shutdown tasks are first and then empty NodeTask are next.
struct CompareNodeTaskTime {
Expand Down Expand Up @@ -309,6 +313,21 @@ auto Engine::thread_init(int device) -> void {
}
}

// The guard that sets and restores current_graph_task.
// RAII guard that installs `graph_task` as this thread's
// `current_graph_task` and restores the previous value on destruction.
// Saving/restoring (rather than resetting to null) keeps the variable
// correct across reentrant backward calls, so queue_callback() always
// targets the GraphTask actually executing on this thread.
struct GraphTaskGuard {
  // `explicit`: a guard should never be created by implicit conversion
  // from a shared_ptr<GraphTask>.
  explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task)
      : last_graph_task_(std::move(current_graph_task)) {
    current_graph_task = std::move(graph_task);
  }
  ~GraphTaskGuard() { restore_current_graph_task(); }

  // A scope guard must not be copied; copying would restore the saved
  // task twice.
  GraphTaskGuard(const GraphTaskGuard&) = delete;
  GraphTaskGuard& operator=(const GraphTaskGuard&) = delete;

  void restore_current_graph_task() {
    current_graph_task = std::move(last_graph_task_);
  }

  // The task that was current when this guard was constructed.
  std::shared_ptr<GraphTask> last_graph_task_;
};

// NOTE: graph_tasks do not necessarily form a stack. Imagine this
// case:
//
Expand Down Expand Up @@ -359,6 +378,11 @@ auto Engine::thread_main(
if (task.fn_ && !local_graph_task->has_error_.load()) {
AutoGradMode grad_mode(local_graph_task->grad_mode_);
try {
// The guard sets the thread_local current_graph_task on construction
// and restores it on exit. The current_graph_task variable helps
// queue_callback() to find the target GraphTask to append final
// callbacks.
GraphTaskGuard guard(local_graph_task);
evaluate_function(local_graph_task, task.fn_.get(), task.inputs_);
} catch (std::exception& e) {
thread_on_exception(local_graph_task, task.fn_, e);
Expand Down Expand Up @@ -717,22 +741,6 @@ auto Engine::compute_dependencies(Node* root, GraphTask& task) -> void {
}
}

// RAII helper that empties the given callback list (while holding its
// lock) both on construction and on destruction, guaranteeing that
// callbacks registered during one run never leak into the next.
struct ClearCallbacks {
  ClearCallbacks(
      std::vector<std::function<void()>>& callbacks,
      std::mutex& callbacks_lock)
      : callbacks_(callbacks), callbacks_lock_(callbacks_lock) {
    clear();
  }
  ~ClearCallbacks() { clear(); }

  void clear() {
    // Take the lock: callbacks may be appended from other threads.
    std::lock_guard<std::mutex> guard(callbacks_lock_);
    callbacks_.clear();
  }

  std::vector<std::function<void()>>& callbacks_;
  std::mutex& callbacks_lock_;
};

auto Engine::execute(const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
Expand All @@ -743,10 +751,6 @@ auto Engine::execute(const edge_list& roots,
return msg;
});

// Callbacks are only valid for the duration of this run and should always be cleared
// Lock post_callbacks_lock_ before clearing final_callbacks_
ClearCallbacks _cb_guard(final_callbacks_, post_callbacks_lock_);

auto graph_task = std::make_shared<GraphTask>(
keep_graph,
create_graph,
Expand Down Expand Up @@ -851,17 +855,21 @@ void Engine::graph_task_exec_post_processing(
throw std::runtime_error("could not compute gradients for some functions");
}

// Set the thread_local current_graph_task, since final callbacks that are
// already queued may install additional callbacks while they run.
GraphTaskGuard guard(graph_task);
// Lock mutex during each iteration for accessing final_callbacks.size()
// Unlocking is necessary, because the callback can register
// more callbacks (or they can be registered from other threads
// while it's waiting).
std::unique_lock<std::mutex> cb_lock(post_callbacks_lock_);
std::unique_lock<std::mutex> cb_lock(graph_task->final_callbacks_lock_);
const auto& final_callbacks = graph_task->final_callbacks_;
// WARNING: Don't use a range-for loop here because more callbacks may be
// added in between callback calls, so iterators may become invalidated.
// NOLINTNEXTLINE(modernize-loop-convert)
for (size_t i = 0; i < final_callbacks_.size(); ++i) {
for (size_t i = 0; i < final_callbacks.size(); ++i) {
cb_lock.unlock();
final_callbacks_[i]();
final_callbacks[i]();
cb_lock.lock();
}

Expand Down Expand Up @@ -898,8 +906,12 @@ Engine& Engine::get_default_engine() {
}

// Appends a final callback to the GraphTask currently executing on this
// thread. Callbacks run after the backward pass for that task completes.
// Must be called from within a backward pass (i.e. while the thread_local
// current_graph_task is set); otherwise this raises.
// NOTE: the previous engine-wide `post_callbacks_lock_`/`final_callbacks_`
// path was removed along with those Engine members; keeping both bodies
// here double-appended the callback and referenced deleted members.
void Engine::queue_callback(std::function<void()> callback) {
  TORCH_CHECK(
      current_graph_task,
      "Final callbacks can only be installed during backward pass.");

  // final_callbacks_lock_ (not GraphTask::mutex_) guards the callback list.
  std::lock_guard<std::mutex> lock(current_graph_task->final_callbacks_lock_);
  current_graph_task->final_callbacks_.emplace_back(std::move(callback));
}

bool Engine::is_checkpoint_valid() {
Expand Down
9 changes: 6 additions & 3 deletions torch/csrc/autograd/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ struct GraphTask {
// tasks are done.
std::shared_ptr<FutureVariableList> future_result_;

// Final callbacks installed during execution of this GraphTask
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_. Intentionally no reusing
// mutex_ as the two are protecting different data structures.
std::mutex final_callbacks_lock_;

GraphTask(
bool keep_graph,
bool grad_mode,
Expand Down Expand Up @@ -239,9 +245,6 @@ struct TORCH_API Engine {
std::once_flag start_threads_flag_;
// Safe to read ready_queues_ without synchronization after intialization
std::vector<std::shared_ptr<ReadyQueue>> ready_queues_;
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_
std::mutex post_callbacks_lock_;
// How many nested reentrant calls are allowed until a new thread is used
int max_recursion_depth_;

Expand Down