From fb09803112232b69fee26e7b76125ed8d46d6545 Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Mon, 10 Jun 2019 07:52:35 -0700
Subject: [PATCH 1/5] Delete DDP hooks in Reducer destructor

---
 test/test_c10d.py                       | 55 +++++++++++++++++++++++++
 torch/csrc/autograd/function.h          | 15 +++++++
 torch/csrc/distributed/c10d/reducer.cpp |  8 ++++
 torch/csrc/distributed/c10d/reducer.h   |  2 +
 4 files changed, 80 insertions(+)

diff --git a/test/test_c10d.py b/test/test_c10d.py
index 0af1979099b62..49407659052b8 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -2490,6 +2490,61 @@ def forward(self, x):
             loss = criterion(output, target)
             loss.backward()
 
+    @skip_if_not_nccl
+    @skip_if_not_multigpu
+    def test_failure_recovery(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        class TestModel(nn.Module):
+            def __init__(self):
+                super(TestModel, self).__init__()
+                self.fc1 = nn.Linear(2, 10, bias=False)
+                self.fc2 = nn.Linear(10, 4, bias=False)
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                x = self.relu(self.fc2(x))
+                return F.softmax(x, dim=1)
+
+        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        model = TestModel().float().to(device_id)
+        ddp = DistributedDataParallel(
+            model,
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+
+        batch_size = 4
+        criterion = nn.CrossEntropyLoss()
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
+
+        for _ in range(6):
+            output = ddp(input)
+            loss = criterion(output, target)
+            loss.backward()
+
+        del ddp
+        del process_group
+        del store
+
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+        ddp = DistributedDataParallel(
+            model,
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+
+        input = torch.rand([batch_size, 2], dtype=torch.float)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
+        for i in range(6):
+            output = ddp(input)
+            loss = criterion(output, target)
+            loss.backward()
+
 
 class ReducerModule(nn.Module):
     def __init__(self):
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index b28a1f1e93666..4100ff7902e8b 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -256,6 +256,20 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
     return post_hooks_;
   }
 
+  // delete all post hooks matching HookType
+  template <typename HookType>
+  void delete_post_hooks() {
+    std::lock_guard<std::mutex> lock(this->mutex_);
+    for (auto it = post_hooks_.begin(); it != post_hooks_.end();) {
+      HookType* ptr = dynamic_cast<HookType*>(it->get());
+      if (ptr) {
+        it = post_hooks_.erase(it);
+      } else {
+        ++it;
+      }
+    }
+  }
+
   std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
     return post_hooks_;
   }
@@ -325,6 +339,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
   at::SmallVector<InputMetadata, 2> input_metadata_;
+  std::mutex mutex_;
 };
 
 /// See Function::is_traceable() for definition.
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index d919705d1e576..39ced176be4ef 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -138,6 +138,14 @@ Reducer::Reducer(
   }
 }
 
+Reducer::~Reducer() {
+  for (auto& replica_grad_accumulators: grad_accumulators_) {
+    for (auto& grad_accumulator: replica_grad_accumulators) {
+      grad_accumulator->delete_post_hooks<LambdaPostHook>();
+    }
+  }
+}
+
 // Called when the gradient for the specified variable is ready.
 // It can be called from two places:
 // - By an autograd thread after executing a gradient accumulator function.
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
index bd399a6981275..dc0173d6addd6 100644
--- a/torch/csrc/distributed/c10d/reducer.h
+++ b/torch/csrc/distributed/c10d/reducer.h
@@ -24,6 +24,8 @@ class Reducer {
       std::vector<std::vector<size_t>> bucket_indices,
       std::shared_ptr<ProcessGroup> process_group);
 
+  ~Reducer();
+
   // To (re-)initialize bucket assignment, pass a list of buckets, each
   // of which is specified by a list of indices in the variables list.
   // This function performs validation that the variables within a bucket

From eb90ba25898f1347ca4a2c7915d9ed9bce928fb7 Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Mon, 10 Jun 2019 12:45:52 -0700
Subject: [PATCH 2/5] address comments

---
 torch/csrc/autograd/function.h          | 21 +++++++++----------
 torch/csrc/distributed/c10d/init.cpp    |  3 ++-
 torch/csrc/distributed/c10d/reducer.cpp | 28 ++++++++++++++++---------
 torch/csrc/distributed/c10d/reducer.h   |  9 +++++++-
 4 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 4100ff7902e8b..25210600a29c1 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -247,8 +247,11 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   // Hook API
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
+  uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
     post_hooks_.push_back(std::move(post_hook));
+    // Use the raw pointer as the unique key to identify this hook. This key
+    // can then be used in del_post_hook(key) to remove this hook.
+    return reinterpret_cast<uintptr_t>(post_hooks_.back().get());
   }
 
   const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
@@ -256,18 +259,15 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
     return post_hooks_;
   }
 
-  // delete all post hooks matching HookType
-  template <typename HookType>
-  void delete_post_hooks() {
-    std::lock_guard<std::mutex> lock(this->mutex_);
+  // delete a post hook matching the key
+  bool del_post_hook(const uintptr_t key) {
     for (auto it = post_hooks_.begin(); it != post_hooks_.end();) {
-      HookType* ptr = dynamic_cast<HookType*>(it->get());
-      if (ptr) {
-        it = post_hooks_.erase(it);
-      } else {
-        ++it;
+      if (key == reinterpret_cast<uintptr_t>(it->get())) {
+        post_hooks_.erase(it);
+        return true;
       }
     }
+    return false;
   }
 
   std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
     return post_hooks_;
   }
@@ -339,7 +339,6 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
   at::SmallVector<InputMetadata, 2> input_metadata_;
-  std::mutex mutex_;
 };
 
 /// See Function::is_traceable() for definition.
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 018ee65057bb0..0616d65f4280f 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -62,7 +62,8 @@ PyObject* c10d_init(PyObject* _unused) {
           [](::c10d::Reducer& reducer, const torch::autograd::Variable& output)
               -> void { reducer.prepare_for_backward({output}); },
           py::call_guard<py::gil_scoped_release>())
-      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats);
+      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
+      .def("remove_all_hooks", &::c10d::Reducer::remove_all_hooks);
 
   py::enum_<::c10d::ReduceOp>(module, "ReduceOp", R"(
 An enum-like class of available reduce operations: ``SUM``, ``PRODUCT``,
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 39ced176be4ef..616589d0d71b7 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -105,11 +105,14 @@ Reducer::Reducer(
       auto grad_accumulator = variable.grad_accumulator();
 
       // Hook to execute after the gradient accumulator has executed.
-      grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
-        std::lock_guard<std::mutex> lock(this->mutex_);
-        this->mark_variable_ready(
-            replica_index, variable_index, /* called_from_autograd= */ true);
-      }));
+      hooks_[grad_accumulator->add_post_hook(
+          torch::make_unique<LambdaPostHook>([=] {
+            std::lock_guard<std::mutex> lock(this->mutex_);
+            this->mark_variable_ready(
+                replica_index,
+                variable_index,
+                /* called_from_autograd= */ true);
+          }))] = grad_accumulator;
 
       // Map raw function pointer to replica index and parameter index.
       // This is used later on when the autograd graph is traversed
@@ -138,14 +141,19 @@ Reducer::Reducer(
   }
 }
 
-Reducer::~Reducer() {
-  for (auto& replica_grad_accumulators: grad_accumulators_) {
-    for (auto& grad_accumulator: replica_grad_accumulators) {
-      grad_accumulator->delete_post_hooks<LambdaPostHook>();
-    }
+void Reducer::remove_all_hooks() const {
+  for (auto& hook : hooks_) {
+    auto& key = hook.first;
+    auto& grad_accumulator = hook.second;
+    AT_ASSERTM(grad_accumulator->del_post_hook(key),
+        "Reducer attempts to delete a non-existing hook.");
   }
 }
 
+Reducer::~Reducer() noexcept(false) {
+  remove_all_hooks();
+}
+
 // Called when the gradient for the specified variable is ready.
 // It can be called from two places:
 // - By an autograd thread after executing a gradient accumulator function.
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
index dc0173d6addd6..fa4e216f27b0f 100644
--- a/torch/csrc/distributed/c10d/reducer.h
+++ b/torch/csrc/distributed/c10d/reducer.h
@@ -24,7 +24,7 @@ class Reducer {
       std::vector<std::vector<size_t>> bucket_indices,
       std::shared_ptr<ProcessGroup> process_group);
 
-  ~Reducer();
+  ~Reducer() noexcept(false);
 
   // To (re-)initialize bucket assignment, pass a list of buckets, each
   // of which is specified by a list of indices in the variables list.
   // This function performs validation that the variables within a bucket
@@ -46,6 +46,11 @@ class Reducer {
     return backward_stats_;
   }
 
+  // Remove all hooks on variables registered by this Reducer. This is necessary
+  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
+  // (from recoveries) will add their hooks to the original model.
+  void remove_all_hooks() const;
+
 protected:
   std::mutex mutex_;
   std::vector<std::vector<torch::autograd::Variable>> replicas_;
@@ -54,6 +59,8 @@ class Reducer {
   std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
       grad_accumulators_;
   std::unordered_map> func_;
+  std::unordered_map<uintptr_t, std::shared_ptr<torch::autograd::Function>>
+      hooks_;
 
   bool expect_autograd_hooks_;
   bool require_finalize_;

From 50645225909494a8ddee5a7e5052373f083a5bc4 Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Mon, 10 Jun 2019 14:23:22 -0700
Subject: [PATCH 3/5] enable c10d test in multi-gpu environment

---
 .jenkins/pytorch/multigpu-test.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh
index 3d082b55e04ac..740035ef0f987 100755
--- a/.jenkins/pytorch/multigpu-test.sh
+++ b/.jenkins/pytorch/multigpu-test.sh
@@ -28,4 +28,5 @@ if [ -n "${IN_CIRCLECI}" ]; then
 fi
 
 time python test/run_test.py --verbose -i distributed
+time python test/run_test.py --verbose -i c10d
 assert_git_not_dirty

From c8c41f063048be5e2583c07bec5da4478e961541 Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Tue, 11 Jun 2019 08:38:19 -0700
Subject: [PATCH 4/5] fix tests

---
 test/test_c10d.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/test/test_c10d.py b/test/test_c10d.py
index 49407659052b8..05859492c7582 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -2496,6 +2496,16 @@ def test_failure_recovery(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
 
+        # need to create a separate file for the recovered FileStore, because
+        # the original one will be deleted when destructing the first FileStore.
+        recovery_filename = self.file.name + "_recovery"
+
+        if self.rank == 0:
+            # the file will be deleted by the recovered FileStore
+            open(recovery_filename, "w").close()
+
+        # not necessary to run barrier here, as DDP will synchronize
+
         class TestModel(nn.Module):
             def __init__(self):
                 super(TestModel, self).__init__()
@@ -2528,9 +2538,9 @@ def forward(self, x):
 
         del ddp
         del process_group
-        del store
+        del store # this will delete self.file
 
-        store = c10d.FileStore(self.file.name, self.world_size)
+        store = c10d.FileStore(recovery_filename, self.world_size)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
         ddp = DistributedDataParallel(
             model,
@@ -2540,7 +2550,7 @@ def forward(self, x):
 
         input = torch.rand([batch_size, 2], dtype=torch.float)
         target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
-        for i in range(6):
+        for _ in range(6):
             output = ddp(input)
             loss = criterion(output, target)
             loss.backward()

From 08c0c8b53b6561c330f701c964d5c53a73a459bb Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Tue, 11 Jun 2019 14:54:05 -0700
Subject: [PATCH 5/5] Get rid of remove_all_hooks and do it in Reducer's destructor

---
 test/test_c10d.py                       |  2 +-
 torch/csrc/autograd/function.h          |  2 +-
 torch/csrc/distributed/c10d/init.cpp    |  3 +--
 torch/csrc/distributed/c10d/reducer.cpp | 10 +++++-----
 torch/csrc/distributed/c10d/reducer.h   |  5 -----
 5 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/test/test_c10d.py b/test/test_c10d.py
index 05859492c7582..712aff9c7c4d8 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -2538,7 +2538,7 @@ def forward(self, x):
 
         del ddp
         del process_group
-        del store # this will delete self.file
+        del store  # this will delete self.file
 
         store = c10d.FileStore(recovery_filename, self.world_size)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 25210600a29c1..f273e8a1b57a2 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -260,7 +260,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   }
 
   // delete a post hook matching the key
-  bool del_post_hook(const uintptr_t key) {
+  bool del_post_hook(const uintptr_t& key) {
     for (auto it = post_hooks_.begin(); it != post_hooks_.end();) {
       if (key == reinterpret_cast<uintptr_t>(it->get())) {
         post_hooks_.erase(it);
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 0616d65f4280f..018ee65057bb0 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -62,8 +62,7 @@ PyObject* c10d_init(PyObject* _unused) {
           [](::c10d::Reducer& reducer, const torch::autograd::Variable& output)
               -> void { reducer.prepare_for_backward({output}); },
           py::call_guard<py::gil_scoped_release>())
-      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
-      .def("remove_all_hooks", &::c10d::Reducer::remove_all_hooks);
+      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats);
 
   py::enum_<::c10d::ReduceOp>(module, "ReduceOp", R"(
 An enum-like class of available reduce operations: ``SUM``, ``PRODUCT``,
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 616589d0d71b7..4e1b0fa7a08a3 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -141,7 +141,11 @@ Reducer::Reducer(
   }
 }
 
-void Reducer::remove_all_hooks() const {
+Reducer::~Reducer() noexcept(false) {
+  // Remove all hooks on variables registered by this Reducer. This is necessary
+  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
+  // (from recoveries) will add their hooks to the original model, and those
+  // hooks will try to invoke methods on a deleted Reducer object.
   for (auto& hook : hooks_) {
     auto& key = hook.first;
     auto& grad_accumulator = hook.second;
@@ -150,10 +154,6 @@ void Reducer::remove_all_hooks() const {
   }
 }
 
-Reducer::~Reducer() noexcept(false) {
-  remove_all_hooks();
-}
-
 // Called when the gradient for the specified variable is ready.
 // It can be called from two places:
 // - By an autograd thread after executing a gradient accumulator function.
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
index fa4e216f27b0f..509317971be05 100644
--- a/torch/csrc/distributed/c10d/reducer.h
+++ b/torch/csrc/distributed/c10d/reducer.h
@@ -46,11 +46,6 @@ class Reducer {
     return backward_stats_;
  }
 
-  // Remove all hooks on variables registered by this Reducer. This is necessary
-  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
-  // (from recoveries) will add their hooks to the original model.
-  void remove_all_hooks() const;
-
 protected:
   std::mutex mutex_;
   std::vector<std::vector<torch::autograd::Variable>> replicas_;
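
Note: the following is an illustrative sketch (not part of the patch series) of the recovery pattern that test_failure_recovery exercises and that this change enables: tear down the DistributedDataParallel wrapper, its process group, and its store, then rebuild them around the same module. Names such as rebuild_ddp, store_path, and recovery_store_path are hypothetical; the sketch assumes one GPU per rank, an NCCL-enabled build, and that the caller has already moved `model` to `cuda:rank`.

```python
# Hypothetical helper, for illustration only -- not part of the patches above.
# Assumes one GPU per rank, NCCL support, and caller-provided store file paths;
# `model` is expected to already live on torch.device("cuda", rank).
import torch.distributed as c10d
from torch.nn.parallel import DistributedDataParallel


def rebuild_ddp(model, rank, world_size, store_path, recovery_store_path):
    # First lifetime: store -> process group -> DDP wrapper around `model`.
    store = c10d.FileStore(store_path, world_size)
    process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
    ddp = DistributedDataParallel(
        model, device_ids=[rank], process_group=process_group)

    # ... training runs here until a failure is detected ...

    # Tear down in reverse order. With this patch, destroying the DDP wrapper
    # (and with it the Reducer) also removes the autograd post-hooks the
    # Reducer installed on the model's gradient accumulators, so the bare
    # module can be wrapped again without accumulating stale hooks.
    del ddp
    del process_group
    del store

    # Second lifetime: a fresh store (backed by a different file) and a fresh
    # process group around the same, unmodified module.
    store = c10d.FileStore(recovery_store_path, world_size)
    process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
    return DistributedDataParallel(
        model, device_ids=[rank], process_group=process_group)
```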