diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 1494bb8fc863..a1f4d22ca283 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -670,25 +670,14 @@ def test_ddp_checkpointing_weight_sharing(self, use_reentrant):
             l2 = nn.Linear(20, 20)
             l1.weight = l2.weight
             model = nn.Sequential(l1, l2)
-            # TODO: non-reentrant based checkpointing of DDP module with
-            # static_graph runs into the below issue, see
-            # https://github.com/pytorch/pytorch/issues/70865 and
-            # https://github.com/pytorch/pytorch/issues/58111 for details.
-            err_ctx = (
-                self.assertRaisesRegex(
-                    RuntimeError,
-                    "Your training graph has changed in this iteration"
-                ) if static_graph and not use_reentrant else nullcontext()
+            self._test_ddp_checkpointing(
+                model,
+                process_group=process_group,
+                use_bucket_view=use_bucket_view,
+                static_graph=static_graph,
+                run_checkpoint=True,
+                use_reentrant=use_reentrant,
             )
-            with err_ctx:
-                self._test_ddp_checkpointing(
-                    model,
-                    process_group=process_group,
-                    use_bucket_view=use_bucket_view,
-                    static_graph=static_graph,
-                    run_checkpoint=True,
-                    use_reentrant=use_reentrant,
-                )

     @skip_if_lt_x_gpu(2)
     def test_ddp_checkpointing_twice_weight_sharing(self):
diff --git a/test/distributed/test_distributed_spawn.py b/test/distributed/test_distributed_spawn.py
index 8499f167c6c9..b33bfcb34635 100644
--- a/test/distributed/test_distributed_spawn.py
+++ b/test/distributed/test_distributed_spawn.py
@@ -34,6 +34,8 @@ def setUp(self):
         super().setUp()
         self._spawn_processes()
         torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()
+else:
+    print(f"Invalid backend {BACKEND}. Tests will not be run!")


 if __name__ == "__main__":
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index a0e46e48ac5d..70f8bcd2b47e 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -109,6 +109,8 @@ Reducer::Reducer(
       gradient_as_bucket_view_(gradient_as_bucket_view),
       local_used_map_reduced_(false),
       num_iterations_(0),
+      num_bwd_calls_(0),
+      first_autograd_hook_called_(false),
       num_buckets_ready_(0),
       has_rebuilt_bucket_(false),
       bucket_bytes_cap_(bucket_bytes_cap),
@@ -267,11 +269,11 @@ bool Reducer::dynamic_graph_find_unused() {
 }

 bool Reducer::static_graph_first_iteration() {
-  return static_graph_ && num_iterations_ == 1;
+  return static_graph_ && num_bwd_calls_ == 1;
 }

 bool Reducer::static_graph_after_first_iteration() {
-  return static_graph_ && num_iterations_ > 1;
+  return static_graph_ && num_bwd_calls_ > 1;
 }

 bool Reducer::ddp_graph_static() {
@@ -613,6 +615,10 @@ void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) {
 // This function is only to be called from the autograd thread.
 void Reducer::autograd_hook(size_t index) {
   std::lock_guard<std::mutex> lock(this->mutex_);
+  if (!first_autograd_hook_called_) {
+    first_autograd_hook_called_ = true;
+    num_bwd_calls_++;
+  }
   // Ignore if we don't expect to be called.
   // This may be the case if the user wants to accumulate gradients
   // for number of iterations before reducing them.
@@ -1520,6 +1526,8 @@ void Reducer::finalize_backward() {
   // No longer expect autograd hooks to fire after this function returns.
   TORCH_INTERNAL_ASSERT(expect_autograd_hooks_);
   expect_autograd_hooks_ = false;
+  // reset for the next iteration
+  first_autograd_hook_called_ = false;
   // No longer require call to finalize after this function returns.
   TORCH_INTERNAL_ASSERT(require_finalize_);
   require_finalize_ = false;
diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp
index 7194c6443d19..3b90309e0f31 100644
--- a/torch/csrc/distributed/c10d/reducer.hpp
+++ b/torch/csrc/distributed/c10d/reducer.hpp
@@ -404,6 +404,11 @@ class TORCH_API Reducer {
   // track the number of iterations to synchronize grads in training so far.
   long num_iterations_;
+  // track distinct iteration of backward call. This is distinct from num_iterations_,
+  // for example in the case of multiple forward before backward.
+  long num_bwd_calls_;
+  // whether the first autograd hook for a distinct backward pass has been called.
+  bool first_autograd_hook_called_;
   // track the number of buckets that have been ready for
   // communication calls like allReduce or communication hooks.
   int num_buckets_ready_;
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index def38a76c0e3..d0487bb39ab9 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -239,12 +239,11 @@ class _BufferCommHook:
 # is completed.
 class _DDPSink(Function):
     @staticmethod
-    def forward(ctx, reducer, state_dict, *inputs):
+    def forward(ctx, ddp_weakref, *inputs):
         # set_materialize_grads(False) will ensure that None gradients stay as
         # None and are not filled with zeros.
         ctx.set_materialize_grads(False)
-        ctx.reducer = reducer
-        ctx.state_dict = state_dict
+        ctx.ddp_weakref = ddp_weakref
         ret = tuple(
             inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
         )
@@ -254,12 +253,19 @@ def forward(ctx, reducer, state_dict, *inputs):
     def backward(ctx, *grad_outputs):
         # Enqueue delay allreduce for static graph training on the first
         # iteration.
-        if ctx.state_dict["static_graph"] and ctx.state_dict["num_iterations"] == 1:
+        ddp_weakref = ctx.ddp_weakref()
+        reducer = ddp_weakref.reducer
+        static_graph = ddp_weakref.static_graph
+        delay_ar_enqueued = (
+            static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
+        )
+        if static_graph and not delay_ar_enqueued:
             Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
-                ctx.reducer._delay_all_reduce
+                reducer._delay_all_reduce
             )
+            ddp_weakref._static_graph_delay_allreduce_enqueued = True

-        return (None, None, *grad_outputs)
+        return (None, *grad_outputs)


 class _DDPJoinHook(JoinHook):
@@ -1047,7 +1053,6 @@ def _ddp_init_helper(
         (4) Logging construction-time DDP logging data
         (5) passing a handle of DDP to SyncBatchNorm Layer
         """
-        self.num_iterations = 0
         # Notice, the parameters order is not in the order in which they are used,
         # especially in models with control flow.
         #
@@ -1381,7 +1386,6 @@ def _pre_forward(self, *inputs, **kwargs):
         if torch.is_grad_enabled() and self.require_backward_grad_sync:
             assert self.logger is not None
             self.logger.set_runtime_stats_and_log()
-            self.num_iterations += 1
             self.reducer.prepare_for_forward()

         # Notify the join context that this process has not joined, if
@@ -1466,13 +1470,8 @@ def _post_forward(self, output):
         # TODO: DDPSink is currently enabled for unused parameter detection and
         # static graph training for first iteration.
         if (self.find_unused_parameters and not self.static_graph) or (
-            self.static_graph and self.num_iterations == 1
+            self.static_graph and not self._static_graph_delay_allreduce_enqueued
         ):
-            state_dict = {
-                "static_graph": self.static_graph,
-                "num_iterations": self.num_iterations,
-            }
-
             (
                 output_tensor_list,
                 treespec,
@@ -1491,8 +1490,7 @@ def _post_forward(self, output):
             # undefined gradient which the reducer then handles to ensure
             # param.grad field is not touched and we don't error out.
             passthrough_tensor_list = _DDPSink.apply(
-                self.reducer,
-                state_dict,
+                weakref.ref(self),
                 *output_tensor_list,
             )
             for i in range(len(output_placeholders)):
@@ -2204,6 +2202,7 @@ def _set_static_graph(self):
             )
             return
         self.static_graph = True
+        self._static_graph_delay_allreduce_enqueued = False
         self.reducer._set_static_graph()
         assert self.logger is not None
         self.logger._set_static_graph()
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 1cb1182452a0..9228e0ab53d5 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -9810,6 +9810,62 @@ def forward(self, x):
             for buf in bufs[1:]:
                 self.assertEqual(rank_0_buf, buf)

+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_static_graph_multi_forward(self):
+            class Net(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.lin = nn.Linear(10, 10)
+                    self.relu = nn.ReLU()
+
+                def forward(self, x):
+                    return self.relu(self.lin(x))
+
+            torch.cuda.set_device(self.rank)
+            torch.manual_seed(42 << 1337 % (self.rank + 1))
+            model = Net().cuda(self.rank)
+            local_model = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], static_graph=True
+            )
+            inp = torch.ones(2, 10, device="cuda")
+            for _ in range(3):
+                model.zero_grad()
+                local_model.zero_grad()
+                a = model(inp)
+                b = model(inp)
+                loss = a.sum() + b.sum()
+                loss.backward()
+                # Grads should be equal to a local model that ran through inp twice and averaged grads
+                if self.rank == 0:
+                    inp_clone = inp.clone()
+                    for _ in range(2):
+                        a = local_model(inp_clone)
+                        b = local_model(inp_clone)
+                        loss = a.sum() + b.sum()
+                        loss.backward()
+
+                    ws = dist.get_world_size()
+                    for p in local_model.parameters():
+                        p.grad.data = p.grad / dist.get_world_size()
+
+                    for p_ddp, p_local in zip(
+                        model.parameters(),
+                        local_model.parameters()
+                    ):
+                        self.assertTrue(
+                            torch.allclose(
+                                p_ddp.grad, p_local.grad
+                            ),
+                            f"{p_ddp.grad} vs {p_local.grad}"
+                        )
+
+                dist.barrier()
+
         @skip_if_lt_x_gpu(2)
         @skip_but_pass_in_sandcastle_if(
             BACKEND != "nccl" and BACKEND != "gloo",
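The sketch below is not part of the patch; it only illustrates the training pattern the change enables. With static_graph=True, the reducer now keys its "first iteration" checks on distinct backward calls (num_bwd_calls_ / first_autograd_hook_called_) instead of forward iterations, and _DDPSink enqueues the delayed allreduce exactly once via a weakref to the DDP module, so several forward passes may precede a single backward. The process-group setup (env:// rendezvous, NCCL backend, LOCAL_RANK environment variable) is assumed launcher boilerplate and is not taken from the diff.

# Illustrative sketch, not part of the patch: multiple forwards before one
# backward under static_graph DDP, mirroring test_static_graph_multi_forward.
# Assumes a torchrun-style launch (RANK/WORLD_SIZE/LOCAL_RANK set) and NCCL.
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def main():
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU()).cuda(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], static_graph=True
    )

    inp = torch.ones(2, 10, device="cuda")
    for _ in range(3):
        model.zero_grad()
        # Two forward passes feed a single backward. Previously, static-graph
        # state was tied to self.num_iterations (bumped once per forward),
        # which the reducer.cpp/distributed.py hunks above replace with
        # per-backward tracking.
        a = model(inp)
        b = model(inp)
        (a.sum() + b.sum()).backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()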