Delete DDP hooks in Reducer destructor #21591

Closed · wants to merge 7 commits
Changes from 6 commits
65 changes: 65 additions & 0 deletions test/test_c10d.py
@@ -2490,6 +2490,71 @@ def forward(self, x):
            loss = criterion(output, target)
            loss.backward()

    @skip_if_not_nccl
    @skip_if_not_multigpu
    def test_failure_recovery(self):
        store = c10d.FileStore(self.file.name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

        # need to create a separate file for the recovered FileStore, because
        # the original one will be deleted when destructing the first FileStore.
        recovery_filename = self.file.name + "_recovery"

        if self.rank == 0:
            # the file will be deleted by the recovered FileStore
            open(recovery_filename, "w").close()

        # not necessary to run barrier here, as DDP will synchronize

        class TestModel(nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        model = TestModel().float().to(device_id)
        ddp = DistributedDataParallel(
            model,
            device_ids=[device_id],
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)

        for _ in range(6):
            output = ddp(input)
            loss = criterion(output, target)
            loss.backward()

        del ddp
        del process_group
        del store  # this will delete self.file

        store = c10d.FileStore(recovery_filename, self.world_size)
        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
        ddp = DistributedDataParallel(
            model,
            device_ids=[device_id],
            process_group=process_group,
        )

        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
        for _ in range(6):
            output = ddp(input)
            loss = criterion(output, target)
            loss.backward()


class ReducerModule(nn.Module):
    def __init__(self):
16 changes: 15 additions & 1 deletion torch/csrc/autograd/function.h
@@ -247,15 +247,29 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   // Hook API
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
+  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
     post_hooks_.push_back(std::move(post_hook));
+    // Use the raw pointer as the unique key to identify this hook. This key
+    // can then be used in del_post_hook(key) to remove this hook.
+    return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
   }
 
   const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
       noexcept {
     return post_hooks_;
   }
 
+  // delete a post hook matching the key
+  bool del_post_hook(const uintptr_t key) {
+    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
+      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
+        post_hooks_.erase(it);
+        return true;
+      }
+    }
+    return false;
+  }
+
   std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
     return post_hooks_;
   }
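For reference, the keyed hook API above can be exercised in isolation. Below is a minimal standalone sketch of the pattern — registration returns the hook object's raw pointer value as an opaque key, and removal erases the entry whose address matches. The PostHook, PrintHook, and HookOwner names and the main() driver are illustrative only, not the actual torch::autograd types.

// Minimal standalone sketch of the keyed post-hook pattern.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct PostHook {
  virtual ~PostHook() = default;
  virtual void operator()() = 0;
};

struct PrintHook : PostHook {
  void operator()() override { std::cout << "hook fired\n"; }
};

struct HookOwner {
  // Registration hands back the hook's heap address as an opaque key.
  std::uintptr_t add_post_hook(std::unique_ptr<PostHook> hook) {
    hooks_.push_back(std::move(hook));
    return reinterpret_cast<std::uintptr_t>(hooks_.back().get());
  }

  // Removal walks the list and erases the entry whose address matches.
  bool del_post_hook(std::uintptr_t key) {
    for (auto it = hooks_.begin(); it != hooks_.end(); ++it) {
      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
        hooks_.erase(it);
        return true;
      }
    }
    return false;
  }

  void run_hooks() {
    for (auto& h : hooks_) (*h)();
  }

 private:
  std::vector<std::unique_ptr<PostHook>> hooks_;
};

int main() {
  HookOwner owner;
  std::uintptr_t key = owner.add_post_hook(std::make_unique<PrintHook>());
  owner.run_hooks();                              // prints "hook fired"
  std::cout << owner.del_post_hook(key) << "\n";  // 1: removed by key
  owner.run_hooks();                              // prints nothing
}

The raw-pointer key stays valid as a lookup token even if the vector reallocates, because reallocation moves the owning unique_ptrs, not the heap objects they point to.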
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/init.cpp
@@ -62,7 +62,8 @@ PyObject* c10d_init(PyObject* _unused) {
           [](::c10d::Reducer& reducer, const torch::autograd::Variable& output)
               -> void { reducer.prepare_for_backward({output}); },
           py::call_guard<py::gil_scoped_release>())
-      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats);
+      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
+      .def("remove_all_hooks", &::c10d::Reducer::remove_all_hooks);
Contributor:
The Python binding can be removed I think.

Contributor (Author):
Kutta asked for this API to be exposed, just in case Python GC does not kick in deterministically. @kuttas, do you still need this?

Comment:
@pietern Even if we call __del__ on the DistributedDataParallel object, I am not totally sure whether the C++ destructor (Reducer::~Reducer) will be called immediately or lazily (at GC time). If it is 100% deterministic, then we don't need this call; otherwise, we need to implement __del__ and explicitly clean up hooks in DDP. I searched the pybind documentation and can't figure out what the expected behavior is when there are no referential cycles: is the refcount update immediate in Python, and does it immediately trigger the C++ destructor of the py-bound class?

Contributor:
It doesn't matter if it is lazily deleted. The hooks may still be called, but only as long as the reducer object is still alive (per the destructor). If the prepare_for_backward function isn't called by the DDP class that wraps the reducer, all the hooks are no-ops.

Regarding referential cycles -- IIRC this is where the GC comes in. Objects are destructed immediately when their refcount drops to zero, and a separate GC pass handles the cycles.


   py::enum_<::c10d::ReduceOp>(module, "ReduceOp", R"(
 An enum-like class of available reduce operations: ``SUM``, ``PRODUCT``,
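The "hooks become no-ops" argument in the review thread above can be sketched as follows. This is an illustrative stand-in with an expect_autograd_hooks_-style flag guarding the hook body, not the real Reducer: stale hooks fired on an un-armed (but still alive) reducer simply return early, so lazy destruction is harmless as long as prepare_for_backward is no longer called.

// Sketch: a reducer-like owner whose hooks do nothing unless it was armed
// for a backward pass. Simplified stand-in, not the real c10d::Reducer.
#include <functional>
#include <iostream>
#include <vector>

struct MiniReducer {
  bool expect_autograd_hooks_ = false;

  void prepare_for_backward() { expect_autograd_hooks_ = true; }

  // What each registered autograd hook would do when it fires.
  void autograd_hook(int variable_index) {
    if (!expect_autograd_hooks_) {
      return;  // no-op: nobody armed this reducer for a backward pass
    }
    std::cout << "marking variable " << variable_index << " ready\n";
    // ... bucket/allreduce bookkeeping would go here ...
  }
};

int main() {
  MiniReducer reducer;
  std::vector<std::function<void()>> hooks;
  for (int i = 0; i < 3; ++i) {
    hooks.push_back([&reducer, i] { reducer.autograd_hook(i); });
  }

  // Hooks fire without prepare_for_backward(): all of them are no-ops.
  for (auto& h : hooks) h();

  // Once the wrapping DDP module arms the reducer, the hooks do real work.
  reducer.prepare_for_backward();
  for (auto& h : hooks) h();
}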
26 changes: 21 additions & 5 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -105,11 +105,14 @@ Reducer::Reducer(
         auto grad_accumulator = variable.grad_accumulator();
 
         // Hook to execute after the gradient accumulator has executed.
-        grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
-          std::lock_guard<std::mutex> lock(this->mutex_);
-          this->mark_variable_ready(
-              replica_index, variable_index, /* called_from_autograd= */ true);
-        }));
+        hooks_[grad_accumulator->add_post_hook(
+            torch::make_unique<LambdaPostHook>([=] {
+              std::lock_guard<std::mutex> lock(this->mutex_);
+              this->mark_variable_ready(
+                  replica_index,
+                  variable_index,
+                  /* called_from_autograd= */ true);
+            }))] = grad_accumulator;
 
         // Map raw function pointer to replica index and parameter index.
         // This is used later on when the autograd graph is traversed
@@ -138,6 +141,19 @@
   }
 }
 
+void Reducer::remove_all_hooks() const {
+  for (auto& hook : hooks_) {
+    auto& key = hook.first;
+    auto& grad_accumulator = hook.second;
+    AT_ASSERTM(grad_accumulator->del_post_hook(key),
+        "Reducer attempts to delete a non-existing hook.");
+  }
+}
+
+Reducer::~Reducer() noexcept(false) {
+  remove_all_hooks();
+}
+
 // Called when the gradient for the specified variable is ready.
 // It can be called from two places:
 // - By an autograd thread after executing a gradient accumulator function.
9 changes: 9 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
@@ -24,6 +24,8 @@ class Reducer {
       std::vector<std::vector<size_t>> bucket_indices,
       std::shared_ptr<c10d::ProcessGroup> process_group);
 
+  ~Reducer() noexcept(false);
+
   // To (re-)initialize bucket assignment, pass a list of buckets, each
   // of which is specified by a list of indices in the variables list.
   // This function performs validation that the variables within a bucket
@@ -44,6 +46,11 @@
     return backward_stats_;
   }
 
+  // Remove all hooks on variables registered by this Reducer. This is necessary
+  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
+  // (from recoveries) will append their hooks to the original model.
+  void remove_all_hooks() const;
+
  protected:
   std::mutex mutex_;
   std::vector<std::vector<torch::autograd::Variable>> replicas_;
@@ -52,6 +59,8 @@
   std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
       grad_accumulators_;
   std::unordered_map<torch::autograd::Function*, std::tuple<int, int>> func_;
+  std::unordered_map<uintptr_t, std::shared_ptr<torch::autograd::Function>>
+      hooks_;
 
   bool expect_autograd_hooks_;
   bool require_finalize_;
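To make the comment on remove_all_hooks concrete, here is a sketch of the recovery scenario it guards against: a reducer-like object records {key -> accumulator} at construction and deregisters everything in its destructor, so a second reducer built over the same accumulators after a failure does not stack a second set of hooks. Accumulator and MiniReducer are illustrative stand-ins, not the real c10d classes.

// Sketch: destructor-driven hook cleanup so recovery does not duplicate hooks.
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>

struct Accumulator {
  std::vector<std::unique_ptr<std::function<void()>>> hooks;

  std::uintptr_t add_post_hook(std::function<void()> fn) {
    hooks.push_back(std::make_unique<std::function<void()>>(std::move(fn)));
    return reinterpret_cast<std::uintptr_t>(hooks.back().get());
  }

  bool del_post_hook(std::uintptr_t key) {
    for (auto it = hooks.begin(); it != hooks.end(); ++it) {
      if (key == reinterpret_cast<std::uintptr_t>(it->get())) {
        hooks.erase(it);
        return true;
      }
    }
    return false;
  }
};

class MiniReducer {
 public:
  explicit MiniReducer(std::vector<std::shared_ptr<Accumulator>> accs) {
    // Remember which hook was installed on which accumulator.
    for (auto& acc : accs) {
      hooks_[acc->add_post_hook([] { /* mark_variable_ready(...) */ })] = acc;
    }
  }

  ~MiniReducer() {
    // Mirror of Reducer::remove_all_hooks(): drop every hook we installed.
    for (auto& kv : hooks_) {
      kv.second->del_post_hook(kv.first);
    }
  }

 private:
  std::unordered_map<std::uintptr_t, std::shared_ptr<Accumulator>> hooks_;
};

int main() {
  std::vector<std::shared_ptr<Accumulator>> accs = {
      std::make_shared<Accumulator>(), std::make_shared<Accumulator>()};

  {
    MiniReducer first(accs);   // failure happens, first reducer goes away
  }                            // its destructor removes its hooks here
  MiniReducer second(accs);    // recovery: fresh reducer over the same model

  std::cout << accs[0]->hooks.size() << "\n";  // 1, not 2
}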