[DDP] Rename num_iterations -> num_forward_calls
This more accurately represents what we're counting. An iteration is a
forward + backward call, but here we're only counting forward calls. This makes
things less confusing in future diffs where we add support for multiple forwards
with DDP static graph.
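To illustrate why the two counts can diverge, here is a minimal, hypothetical sketch (the `model`, sizes, and micro-batch loop below are illustrative only, not DDP code): a single training iteration that runs forward more than once before its one backward call.

```python
import torch
import torch.nn as nn

# Hypothetical sketch: one training "iteration" (forward + backward + step)
# that calls forward twice, e.g. summing losses over two micro-batches.
# Under such a pattern, forward calls are counted twice per iteration, which
# is why the counter is named num_forward_calls rather than num_iterations.
model = nn.Linear(8, 4)  # stand-in for a DDP-wrapped module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

micro_batches = [torch.randn(2, 8), torch.randn(2, 8)]

optimizer.zero_grad()
loss = sum(model(x).sum() for x in micro_batches)  # two forward calls
loss.backward()                                     # one backward call
optimizer.step()                                    # one iteration overall
```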

Differential Revision: [D46580601](https://our.internmc.facebook.com/intern/diff/D46580601/)

ghstack-source-id: 191683465
Pull Request resolved: #103283
rohan-varma committed Jun 9, 2023
1 parent 734967a commit 2c5dfc3
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions torch/nn/parallel/distributed.py
@@ -254,7 +254,7 @@ def forward(ctx, reducer, ddp_state, *inputs):
     def backward(ctx, *grad_outputs):
         # Enqueue delay allreduce for static graph training on the first
         # iteration.
-        if ctx.ddp_state["static_graph"] and ctx.ddp_state["num_iterations"] == 1:
+        if ctx.ddp_state["static_graph"] and ctx.ddp_state["num_forward_calls"] == 1:
             Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
                 ctx.reducer._delay_all_reduce
             )
@@ -1047,7 +1047,7 @@ def _ddp_init_helper(
         (4) Logging construction-time DDP logging data
         (5) passing a handle of DDP to SyncBatchNorm Layer
         """
-        self.num_iterations = 0
+        self.num_forward_calls = 0
         # Notice, the parameters order is not in the order in which they are used,
         # especially in models with control flow.
         #
@@ -1381,7 +1381,7 @@ def _pre_forward(self, *inputs, **kwargs):
         if torch.is_grad_enabled() and self.require_backward_grad_sync:
             assert self.logger is not None
             self.logger.set_runtime_stats_and_log()
-            self.num_iterations += 1
+            self.num_forward_calls += 1
             self.reducer.prepare_for_forward()

             # Notify the join context that this process has not joined, if
@@ -1466,11 +1466,11 @@ def _post_forward(self, output):
         # TODO: DDPSink is currently enabled for unused parameter detection and
         # static graph training for first iteration.
         if (self.find_unused_parameters and not self.static_graph) or (
-            self.static_graph and self.num_iterations == 1
+            self.static_graph and self.num_forward_calls == 1
         ):
             ddp_state = {
                 "static_graph": self.static_graph,
-                "num_iterations": self.num_iterations,
+                "num_forward_calls": self.num_forward_calls,
             }

             (
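For context, here is a minimal standalone sketch of the bookkeeping pattern the hunks above rename (the class and method names below are illustrative stand-ins, not DDP's real implementation): the counter is bumped on every forward call, and the one-time static-graph setup fires only when it equals 1.

```python
class _ForwardCallCounter:
    """Illustrative stand-in for DDP's counter, not the actual implementation."""

    def __init__(self, static_graph: bool = True):
        self.static_graph = static_graph
        self.num_forward_calls = 0

    def pre_forward(self):
        # Mirrors _pre_forward: count every forward call, not every iteration.
        self.num_forward_calls += 1

    def post_forward(self):
        # Mirrors _post_forward: the one-time static-graph path is taken only
        # on the very first forward call.
        if self.static_graph and self.num_forward_calls == 1:
            print("first forward call: taking the static-graph setup path")


counter = _ForwardCallCounter()
for _ in range(3):
    counter.pre_forward()
    counter.post_forward()
```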
