tom-pollak
Contributor

@tom-pollak tom-pollak commented May 6, 2025

Fixes #132644

_batch_p2p incorrectly assumes that dist.batch_isend_irecv returns a single-element list of dist.Work, likely due to NCCL's coalescing behaviour.

For non-NCCL backends like Gloo, multiple dist.Work objects are returned, causing the code to discard some operations via .pop(). This leads to deadlocks during pipeline parallelism.

Changes:

  • Modified _batch_p2p to return list[dist.Work] instead of popping a single element.
  • Added _wait_batch_p2p to call wait() on multiple dist.Work objects, consuming the result of _batch_p2p.
  • Updated references from dist.Work to list[dist.Work].
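
To make the failure mode concrete, here is a minimal pure-Python sketch of the behaviour described above. It does not use torch: `FakeWork` and `fake_batch_isend_irecv` are hypothetical stand-ins for `dist.Work` and `dist.batch_isend_irecv`, and the `coalesce` flag is an assumption used only to mimic the two return shapes (one handle per batch vs. one handle per op). It is an illustration of the idea, not the code in this PR.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeWork:
    """Hypothetical stand-in for dist.Work; records whether wait() ran."""

    name: str
    waited: bool = False

    def wait(self) -> None:
        self.waited = True


def fake_batch_isend_irecv(ops: list[str], coalesce: bool) -> list[FakeWork]:
    """Mimic the two return shapes (this is not the real API).

    coalesce=True  -> a single handle for the whole batch (NCCL-like).
    coalesce=False -> one handle per op (Gloo-like).
    """
    if coalesce:
        return [FakeWork("+".join(ops))]
    return [FakeWork(op) for op in ops]


def wait_batch_p2p(works: list[FakeWork]) -> None:
    """Sketch of the added helper: wait on every handle in the batch."""
    for work in works:
        work.wait()


def unwaited_after_pop(ops: list[str], coalesce: bool) -> list[str]:
    """Old pattern: keep only one handle via .pop() and wait on it."""
    works = fake_batch_isend_irecv(ops, coalesce)
    all_works = list(works)  # remember every handle before pop()
    works.pop().wait()
    return [w.name for w in all_works if not w.waited]


def unwaited_after_wait_all(ops: list[str], coalesce: bool) -> list[str]:
    """New pattern: keep the whole list and wait on each handle."""
    works = fake_batch_isend_irecv(ops, coalesce)
    wait_batch_p2p(works)
    return [w.name for w in works if not w.waited]


if __name__ == "__main__":
    ops = ["send", "recv"]
    print(unwaited_after_pop(ops, coalesce=True))        # nothing dropped
    print(unwaited_after_pop(ops, coalesce=False))       # "send" is never waited on
    print(unwaited_after_wait_all(ops, coalesce=False))  # nothing dropped
```

With a single coalesced handle, `.pop()` happens to be harmless, which is presumably why the bug went unnoticed on NCCL. With one handle per op, everything except the last handle is dropped.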

Testing:

  • pippy_bert.py from #132644 now works with gloo.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k


pytorch-bot bot commented May 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152938

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit a8a1884 with merge base f2ea636 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 6, 2025
@tom-pollak
Contributor Author

tom-pollak commented May 6, 2025

@pytorchbot label "module: pipelining"


pytorch-bot bot commented May 6, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'label:' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick', 'close')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

Try @pytorchbot --help for more info.

@tom-pollak
Contributor Author

@pytorchbot label "module: pipelining"

@pytorch-bot pytorch-bot bot added the module: pipelining Pipeline Parallelism label May 6, 2025
@tom-pollak
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label May 6, 2025
@tom-pollak
Contributor Author

@pytorchbot rebase main


pytorch-bot bot commented May 7, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: main

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

Try @pytorchbot --help for more info.

@tom-pollak
Contributor Author

@pytorchbot rebase -b main


pytorch-bot bot commented May 7, 2025

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@janeyx99 janeyx99 requested a review from kwen2501 May 7, 2025 19:41
@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 7, 2025
@kwen2501 kwen2501 requested a review from H-Huang May 7, 2025 22:16
Contributor

@kwen2501 kwen2501 left a comment

Thanks. LGTM.
Maybe @H-Huang wants to have a second look?

@H-Huang
Member

H-Huang commented May 8, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch push -f https://github.com/graphcore/pytorch-fork.git pull/152938/head:pp-deadlock returned non-zero exit code 128

remote: Permission to graphcore/pytorch-fork.git denied to pytorchmergebot.
fatal: unable to access 'https://github.com/graphcore/pytorch-fork.git/': The requested URL returned error: 403

This is likely because the author did not allow edits from maintainers on the PR or because the repo has additional permissions settings that mergebot does not qualify.
Raised by https://github.com/pytorch/pytorch/actions/runs/14907026043

@tom-pollak
Contributor Author

@pytorchbot rebase


pytorch-bot bot commented May 8, 2025

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@tom-pollak
Contributor Author

Seems to be a problem with cross-org "allow edits from maintainers"?
https://github.com/orgs/community/discussions/5634

Member

@H-Huang H-Huang left a comment

Overall LGTM. I suspect there are still some issues with hangs on non-NCCL backends because of the dependencies between multiple ops across ranks, versus just one op for NCCL. Probably would see this issue arise in 1f1b or interleaved schedules.

Will you be testing other schedules on gloo? Also, feel free to rebase on your fork and update the PR so we can get the testing signal.

@tom-pollak
Contributor Author

> Probably would see this issue arise in 1f1b or interleaved schedules.

I'll have a look at this!

@tom-pollak
Contributor Author

@pytorchbot drci

tom-pollak added 2 commits May 9, 2025 16:29
Fixes pytorch#132644

`_batch_p2p` incorrectly assumes that `dist.batch_isend_irecv` returns a
single-element list of `dist.Work`, likely due to NCCL's coalescing
behaviour.

For non-NCCL backends like Gloo, multiple `dist.Work` objects are
returned, causing the code to discard some operations via `.pop()`. This
leads to deadlocks during pipeline parallelism.

* Modified `_batch_p2p` to return `list[dist.Work]` instead of popping a
  single element.
* Added `_wait_batch_p2p` to call `wait()` on multiple `dist.Work`
  objects, consuming the result of `_batch_p2p`.
* Updated references from `dist.Work` to `list[dist.Work]`.

* `pippy_bert.py` from pytorch#132644 now works with gloo.
@H-Huang
Member

H-Huang commented May 10, 2025

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 10, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3-clang12-executorch / test (executorch, 1, 1, ephemeral.linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Labels

  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: pipelining (Pipeline Parallelism)
  • oncall: distributed (Add this issue/PR to distributed oncall triage queue)
  • open source
  • topic: not user facing (topic category)
  • triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)

Development

Successfully merging this pull request may close these issues.

torch.distributed.pipelining hang and timeout in CPU gloo backend

6 participants