This repository was archived by the owner on Aug 5, 2025. It is now read-only.

[BUG] num_stages incorrect and some assertions  #1143

@jq-wei

Hi,

First of all, thank you for the great work.

I am trying the llama example script with llama2-7b-hf and the following key packages:

torch                    2.5.0
torchpippy               0.2.0
torchtext                0.6.0
torchview                0.2.6

When I run torchrun --nproc-per-node 4 pippy_llama.py, I get the following error on rank 0:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank0]:     stage = pipe.build_stage(rank,  device=device)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1150, in build_stage
[rank0]:     return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 799, in __init__
[rank0]:     _PipelineStageBase.__init__(
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 138, in __init__
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Pipeline group size 4 cannot be larger than number of stages 1

I traced this back to _number_and_count_forward_stages in _IR.py: num_stages is indeed 1, because only one node in the graph has node.op == "call_module", while all the other nodes have node.op == "call_function".
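To make the observation concrete, here is a minimal sketch of the counting behaviour described above. It is not the actual PyTorch code: `Node`, `count_stages`, and `check_group_size` are hypothetical stand-ins, and the node targets are made up for illustration. Only the rule "count nodes whose op is call_module" and the error message are taken from this report.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for a graph node; only the fields used here."""
    op: str
    target: str


def count_stages(nodes):
    # Count only nodes that call a submodule; call_function nodes are ignored.
    return sum(1 for node in nodes if node.op == "call_module")


def check_group_size(group_size, num_stages):
    # Mirrors the check whose message appears in the traceback above.
    if group_size > num_stages:
        raise RuntimeError(
            f"Pipeline group size {group_size} cannot be larger than "
            f"number of stages {num_stages}"
        )


# A graph shaped like the one described above: a single submodule call,
# with every other node being a plain function call (targets are invented).
observed = [
    Node("call_module", "submod_0"),
    Node("call_function", "getitem"),
    Node("call_function", "getitem_1"),
]
num_stages = count_stages(observed)
```

With a graph like this, `num_stages` comes out as 1, so `check_group_size(4, num_stages)` raises the same RuntimeError as in the traceback. This suggests the model was never split into four stages in the first place, rather than the count itself being wrong.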

To dig deeper, I hard-coded the return value of _number_and_count_forward_stages to 4. I then got the following errors:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank0]:     stage = pipe.build_stage(rank,  device=device)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1150, in build_stage
[rank0]:     return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 816, in __init__
[rank0]:     raise AssertionError(
[rank0]: AssertionError: Number of submodules in pipe graph 1 does not match number of stages 4
[rank2]: Traceback (most recent call last):
[rank2]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank2]:     stage = pipe.build_stage(rank,  device=device)
[rank2]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank2]:     stage_module = self.get_stage_module(stage_index)
[rank2]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank2]:     return getattr(self.split_gm, f"submod_{stage_idx}")
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank2]:     raise AttributeError(
[rank2]: AttributeError: 'GraphModule' object has no attribute 'submod_2'. Did you mean: 'submod_0'?
[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank1]:     stage = pipe.build_stage(rank,  device=device)
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank1]:     stage_module = self.get_stage_module(stage_index)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank1]:     return getattr(self.split_gm, f"submod_{stage_idx}")
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank1]:     raise AttributeError(
[rank1]: AttributeError: 'GraphModule' object has no attribute 'submod_1'. Did you mean: 'submod_0'?
[rank3]: Traceback (most recent call last):
[rank3]:   File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank3]:     stage = pipe.build_stage(rank,  device=device)
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank3]:     stage_module = self.get_stage_module(stage_index)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank3]:     return getattr(self.split_gm, f"submod_{stage_idx}")
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank3]:     raise AttributeError(
[rank3]: AttributeError: 'GraphModule' object has no attribute 'submod_3'. Did you mean: 'submod_0'?
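The per-rank failures above follow from the lookup by name shown in the traceback (`getattr(self.split_gm, f"submod_{stage_idx}")`). A minimal sketch, assuming the split module holds only `submod_0`; `SplitModule` is a hypothetical stand-in, not the real GraphModule:

```python
class SplitModule:
    """Hypothetical stand-in for the split module: it exposes submod_0,
    submod_1, ... as attributes, one per stage that actually exists."""

    def __init__(self, num_submodules):
        for i in range(num_submodules):
            setattr(self, f"submod_{i}", f"stage {i}")


def get_stage_module(split_gm, stage_idx):
    # Same lookup-by-name pattern as in the traceback above.
    return getattr(split_gm, f"submod_{stage_idx}")


# Only one submodule exists, even though four ranks ask for a stage each.
split_gm = SplitModule(num_submodules=1)
results = {}
for rank in range(4):
    try:
        get_stage_module(split_gm, rank)
        results[rank] = "ok"
    except AttributeError:
        results[rank] = "missing"
```

In this sketch only rank 0 finds its submodule, which matches the output above: ranks 1 to 3 fail with AttributeError, while rank 0 gets past the lookup and then hits the later assertion on the number of submodules. Overriding the stage count does not create the missing submodules.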

It seems the version-mismatch problem is still there. By the way, the same errors occur if I uninstall torchpippy.

Could you give me some hints?

Thank you very much!
