diff --git a/test/test_c10d.py b/test/test_c10d.py
index 0af1979099b62..712aff9c7c4d8 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -2490,6 +2490,71 @@ def forward(self, x):
             loss = criterion(output, target)
             loss.backward()
 
+    @skip_if_not_nccl
+    @skip_if_not_multigpu
+    def test_failure_recovery(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        # need to create a separate file for the recovered FileStore, because
+        # the original one will be deleted when destructing the first FileStore.
+        recovery_filename = self.file.name + "_recovery"
+
+        if self.rank == 0:
+            # the file will be deleted by the recovered FileStore
+            open(recovery_filename, "w").close()
+
+        # not necessary to run barrier here, as DDP will synchronize
+
+        class TestModel(nn.Module):
+            def __init__(self):
+                super(TestModel, self).__init__()
+                self.fc1 = nn.Linear(2, 10, bias=False)
+                self.fc2 = nn.Linear(10, 4, bias=False)
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                x = self.relu(self.fc2(x))
+                return F.softmax(x, dim=1)
+
+        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        model = TestModel().float().to(device_id)
+        ddp = DistributedDataParallel(
+            model,
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+
+        batch_size = 4
+        criterion = nn.CrossEntropyLoss()
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
+
+        for _ in range(6):
+            output = ddp(input)
+            loss = criterion(output, target)
+            loss.backward()
+
+        del ddp
+        del process_group
+        del store  # this will delete self.file
+
+        store = c10d.FileStore(recovery_filename, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+        ddp = DistributedDataParallel(
+            model,
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
+        for _ in range(6):
+            output = ddp(input)
+            loss = criterion(output, target)
+            loss.backward()
+
 
 class ReducerModule(nn.Module):
     def __init__(self):
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index b28a1f1e93666..f273e8a1b57a2 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -247,8 +247,11 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   // Hook API
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
+  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
     post_hooks_.push_back(std::move(post_hook));
+    // Use the raw pointer as the unique key to identify this hook. This key
+    // can then be used in del_post_hook(key) to remove this hook.
+    return reinterpret_cast<uintptr_t>(post_hooks_.back().get());
   }
 
   const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
@@ -256,6 +259,17 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
     return post_hooks_;
   }
 
+  // delete a post hook matching the key
+  bool del_post_hook(const uintptr_t& key) {
+    for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) {
+      if (key == reinterpret_cast<uintptr_t>(it->get())) {
+        post_hooks_.erase(it);
+        return true;
+      }
+    }
+    return false;
+  }
+
   std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
     return post_hooks_;
   }
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index d919705d1e576..4e1b0fa7a08a3 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -105,11 +105,14 @@ Reducer::Reducer(
         auto grad_accumulator = variable.grad_accumulator();
 
         // Hook to execute after the gradient accumulator has executed.
-        grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
-          std::lock_guard<std::mutex> lock(this->mutex_);
-          this->mark_variable_ready(
-              replica_index, variable_index, /* called_from_autograd= */ true);
-        }));
+        hooks_[grad_accumulator->add_post_hook(
+            torch::make_unique<LambdaPostHook>([=] {
+              std::lock_guard<std::mutex> lock(this->mutex_);
+              this->mark_variable_ready(
+                  replica_index,
+                  variable_index,
+                  /* called_from_autograd= */ true);
+            }))] = grad_accumulator;
 
         // Map raw function pointer to replica index and parameter index.
         // This is used later on when the autograd graph is traversed
@@ -138,6 +141,19 @@ Reducer::Reducer(
   }
 }
 
+Reducer::~Reducer() noexcept(false) {
+  // Remove all hooks on variables registered by this Reducer. This is necessary
+  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
+  // (from recoveries) will add their hooks to the original model, and those
+  // hooks will try to invoke methods on a deleted Reducer object.
+  for (auto& hook : hooks_) {
+    auto& key = hook.first;
+    auto& grad_accumulator = hook.second;
+    AT_ASSERTM(grad_accumulator->del_post_hook(key),
+        "Reducer attempts to delete a non-existing hook.");
+  }
+}
+
 // Called when the gradient for the specified variable is ready.
 // It can be called from two places:
 // - By an autograd thread after executing a gradient accumulator function.
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
index bd399a6981275..509317971be05 100644
--- a/torch/csrc/distributed/c10d/reducer.h
+++ b/torch/csrc/distributed/c10d/reducer.h
@@ -24,6 +24,8 @@ class Reducer {
       std::vector<std::vector<size_t>> bucket_indices,
      std::shared_ptr<c10d::ProcessGroup> process_group);
 
+  ~Reducer() noexcept(false);
+
   // To (re-)initialize bucket assignment, pass a list of buckets, each
   // of which is specified by a list of indices in the variables list.
   // This function performs validation that the variables within a bucket
@@ -52,6 +54,8 @@ class Reducer {
   std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
       grad_accumulators_;
   std::unordered_map<torch::autograd::Function*, std::pair<size_t, size_t>>
       func_;
+  std::unordered_map<uintptr_t, std::shared_ptr<torch::autograd::Function>>
+      hooks_;
   bool expect_autograd_hooks_;
   bool require_finalize_;
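For readers following the hook bookkeeping above, here is a minimal, self-contained C++ sketch of the lifecycle this patch introduces. It is not PyTorch code: HookRegistry and Owner are hypothetical stand-ins for torch::autograd::Function and Reducer, and a plain std::function<void()> plays the role of FunctionPostHook. The point is the pattern: add_post_hook hands back the stored hook's raw pointer as an opaque key, the owner remembers that key, and the owner's destructor calls del_post_hook so no callback capturing a dead `this` is left behind when a new owner (e.g. a recovered DDP Reducer) is created against the same autograd graph.

// Standalone sketch only; compiles on its own and mirrors the patch's pattern.
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

// Stand-in for FunctionPostHook: any callable to run after "autograd" work.
using PostHook = std::function<void()>;

// Stand-in for torch::autograd::Function's post-hook storage.
class HookRegistry {
 public:
  // Store the hook and return its raw pointer as a unique removal key,
  // analogous to the new uintptr_t-returning Function::add_post_hook().
  uintptr_t add_post_hook(std::unique_ptr<PostHook> hook) {
    hooks_.push_back(std::move(hook));
    return reinterpret_cast<uintptr_t>(hooks_.back().get());
  }

  // Erase the hook whose address matches the key and report whether it
  // existed, analogous to Function::del_post_hook().
  bool del_post_hook(uintptr_t key) {
    for (auto it = hooks_.begin(); it != hooks_.end(); ++it) {
      if (reinterpret_cast<uintptr_t>(it->get()) == key) {
        hooks_.erase(it);
        return true;
      }
    }
    return false;
  }

  // Fire all registered hooks, like the autograd engine running post hooks.
  void run_all() {
    for (auto& hook : hooks_) {
      (*hook)();
    }
  }

 private:
  std::vector<std::unique_ptr<PostHook>> hooks_;
};

// Stand-in for Reducer: registers a hook that captures `this` and removes it
// again on destruction, so a later run cannot touch a destroyed object.
class Owner {
 public:
  explicit Owner(HookRegistry& registry) : registry_(registry) {
    key_ = registry_.add_post_hook(
        std::make_unique<PostHook>([this] { ++calls_; }));
  }

  ~Owner() {
    // Mirrors Reducer::~Reducer(): the hook we registered must still exist.
    bool removed = registry_.del_post_hook(key_);
    assert(removed);
    (void)removed;
  }

  int calls() const { return calls_; }

 private:
  HookRegistry& registry_;
  uintptr_t key_ = 0;
  int calls_ = 0;
};

int main() {
  HookRegistry registry;
  {
    Owner first(registry);
    registry.run_all();       // fires the hook owned by `first`
    assert(first.calls() == 1);
  }                           // `first` unregisters its hook here
  registry.run_all();         // nothing left that points at `first`
  Owner second(registry);     // a "recovered" owner re-registers cleanly
  registry.run_all();
  assert(second.calls() == 1);
  return 0;
}

The key is just the hook's address, so it stays valid exactly as long as the hook it names; the patch's Reducer additionally keeps a map from key to grad accumulator (hooks_) so its destructor knows both which hook to drop and which Function still owns it.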