Skip to content

Commit

Permalink
Update on "[pipelining] Add manual pipeline stage"
Browse files Browse the repository at this point in the history
Add `ManualPipelineStage` under `_PipelineStage.py`

Fix some type hints, since `args_recv_info` can contain more than one RecvInfo. Previously the hint was `Tuple[InputInfo]`, which denotes a tuple of exactly one element. This differs from `List[InputInfo]`, which can contain any number of items. I updated the hint to `Tuple[InputInfo, ...]` so that the tuple can hold a variable number of items.




cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
  • Loading branch information
H-Huang committed May 14, 2024
1 parent 17da510 commit c14c901
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions torch/distributed/pipelining/_PipelineStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ def validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]):
# perform all gathers between all stages
for virtual_id, stage in enumerate(pipeline_stages):
world_size = stage.group_size
stage_id = stage.stage_index
stage_id: int = stage.stage_index
rank = stage.group_rank
# check that world_size and num_stages are consistent across all stages
if stage.group_size != world_size:
Expand All @@ -1157,8 +1157,8 @@ def validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]):
which does not match num stages ({num_stages}) of other stages."
)

# TODO: once we pass in pg to stage, check the pg rank is same as stage rank
if rank != (pg_rank := dist.get_rank()):
pg_rank = dist.get_rank(stage.group)
if rank != pg_rank:
raise ValueError(
f"Rank {rank} is not equal to process group rank {pg_rank}"
)
Expand Down Expand Up @@ -1203,16 +1203,16 @@ def validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]):
all_inputs.extend(stage_input_shapes)
all_outputs.extend(stage_output_shapes)

# log only rank 0's view, they will all be equivalent
if pg_rank == 0:
logger.info(
f"all stage inputs: {all_inputs}" # noqa: G004
f"all stage outputs: {all_outputs}"
)
# log only rank 0's view, they will all be equivalent
if pg_rank == 0:
logger.info(
f"all stage inputs: {all_inputs}" # noqa: G004
f"all stage outputs: {all_outputs}"
)

# Check if the output for stage 0 matches the input at stage 1, and so forth
for i in range(virtual_pipeline_size * world_size - 1):
if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
raise ValueError(
f"Stage_id {stage_id} output shape {out} at does not match stage_id {i + 1} input shape {inp}."
f"Stage_id {i} output shape {out} at does not match stage_id {i + 1} input shape {inp}."
)

0 comments on commit c14c901

Please sign in to comment.