Commit
[Pipelining] Clean up function names in 1f1b schedule (#126582)
Pull Request resolved: #126582
Approved by: https://github.com/kwen2501
ghstack dependencies: #126539
wconstab authored and pytorchmergebot committed May 21, 2024
1 parent 8c9d332 commit 4b23c4f
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions torch/distributed/pipelining/PipelineSchedule.py
@@ -392,9 +392,6 @@ def _step_microbatches(
         """
         arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
 
-        # forward for num_microbatches + backward for num_microbatches
-        total_ops = self._n_microbatches * 2
-
         # Example, 4 GPUs, 8 microbatches
         # Stage 0: 6 warmup, 2 1f1b, 6 cooldown
         # Stage 1: 4 warmup, 4 1f1b, 4 cooldown
@@ -408,7 +405,9 @@ def _step_microbatches(
         # fwd + bwd
         main_1f1b_steps = self._n_microbatches - warmup_steps
         # bwd only
-        cooldown_steps = total_ops - (warmup_steps + (2 * main_1f1b_steps))
+        cooldown_steps = (2 * self._n_microbatches) - (
+            warmup_steps + (2 * main_1f1b_steps)
+        )
         total_steps = warmup_steps + main_1f1b_steps + cooldown_steps
         logger.debug(
             f"Stage {self._stage.stage_index}: "  # noqa: G004
@@ -422,43 +421,39 @@ def _step_microbatches(
         fwd_sends_to_wait: List[dist.Work] = []
         bwd_sends_to_wait: List[dist.Work] = []
 
-        def is_forward_step(i):
+        def step_has_forward(i):
             assert i >= 0, i
             return i < self._n_microbatches
 
-        def is_backward_step(i):
+        def step_has_backward(i):
             assert i < total_steps, i
             return i >= warmup_steps and self._has_backward
 
         def is_1f1b_step(i):
-            return is_forward_step(i) and is_backward_step(i)
+            return step_has_forward(i) and step_has_backward(i)
 
         def is_warmup_step(i):
-            return is_forward_step(i) and not is_backward_step(i)
+            return step_has_forward(i) and not step_has_backward(i)
 
         def is_cooldown_step(i):
-            return not is_forward_step(i) and is_backward_step(i)
+            return not step_has_forward(i) and step_has_backward(i)
 
-        def should_coalesce_fwd_send_bwd_recv(fwd_send_i):
+        def should_coalesce_fwd_send_bwd_recv(step):
             return (
-                is_1f1b_step(fwd_send_i)
-                or (is_warmup_step(fwd_send_i) and is_cooldown_step(fwd_send_i + 1))
-                or (
-                    fwd_send_i >= 1
-                    and is_warmup_step(fwd_send_i - 1)
-                    and is_cooldown_step(fwd_send_i)
-                )
+                is_1f1b_step(step)
+                or (is_warmup_step(step) and is_cooldown_step(step + 1))
+                or (step >= 1 and is_warmup_step(step - 1) and is_cooldown_step(step))
             )
 
-        def should_coalesce_bwd_send_fwd_recv(bwd_send_i):
+        def should_coalesce_bwd_send_fwd_recv(bwd_send_step):
             # The backward send to prev stage should be coalesced with the fwd recv from the previous stage
-            return bwd_send_i >= warmup_steps and is_1f1b_step(bwd_send_i + 1)
+            return bwd_send_step >= warmup_steps and is_1f1b_step(bwd_send_step + 1)
 
         # bwd chunk counter
         bwd_mb_index = 0
         self._stage._configure_data_parallel_mode(last_backward=False)
         for i in range(total_steps):
-            if is_forward_step(i):
+            if step_has_forward(i):
                 with record_function(f"Forward {i}"):
                     ops = self._stage.get_fwd_recv_ops()
                     desc = "fwd_recv"
@@ -479,7 +474,7 @@ def should_coalesce_bwd_send_fwd_recv(bwd_send_i):
 
                 self._maybe_compute_loss(self._stage, output, target_mbs, i)
 
-            if is_backward_step(i):
+            if step_has_backward(i):
                 self._stage._configure_data_parallel_mode(
                     last_backward=(i == total_steps - 1)
                 )
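For reference, here is a small standalone sketch (not part of the commit) of the step bookkeeping that the renamed helpers describe, using the "4 GPUs, 8 microbatches" example from the comments in the diff. The per-stage warmup_steps values (6 and 4) are taken directly from those comments rather than derived, and has_backward is assumed True (training); everything else follows the arithmetic and predicates shown above.

# Standalone illustration of the 1F1B step phases (not the PipelineSchedule API).
# warmup_steps values come from the "4 GPUs, 8 microbatches" comment in the diff;
# has_backward = True is an assumption (training run).

n_microbatches = 8
has_backward = True


def describe_stage(stage_index, warmup_steps):
    # fwd + bwd: steps that interleave one forward and one backward
    main_1f1b_steps = n_microbatches - warmup_steps
    # bwd only: same arithmetic as the new cooldown_steps expression
    cooldown_steps = (2 * n_microbatches) - (warmup_steps + (2 * main_1f1b_steps))
    total_steps = warmup_steps + main_1f1b_steps + cooldown_steps

    # The renamed predicates: a step may run a forward, a backward, or both.
    def step_has_forward(i):
        return i < n_microbatches

    def step_has_backward(i):
        return i >= warmup_steps and has_backward

    def is_1f1b_step(i):
        return step_has_forward(i) and step_has_backward(i)

    def is_warmup_step(i):
        return step_has_forward(i) and not step_has_backward(i)

    phases = [
        "1f1b" if is_1f1b_step(i) else "warmup" if is_warmup_step(i) else "cooldown"
        for i in range(total_steps)
    ]
    print(
        f"stage {stage_index}: {warmup_steps} warmup, {main_1f1b_steps} 1f1b, "
        f"{cooldown_steps} cooldown ({total_steps} steps total)"
    )
    print("  phases:", phases)


describe_stage(0, warmup_steps=6)  # -> 6 warmup, 2 1f1b, 6 cooldown
describe_stage(1, warmup_steps=4)  # -> 4 warmup, 4 1f1b, 4 cooldown

Running it reproduces the counts in the example comment (stage 0: 6 warmup, 2 1f1b, 6 cooldown; stage 1: 4 of each), with every step classified by whether it runs a forward, a backward, or both, which is the distinction the step_has_forward / step_has_backward names are meant to convey.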
