
Conversation

@kwen2501 (Contributor) commented May 16, 2025

Stack from ghstack (oldest at bottom):

Use MultiProcContinousTest to avoid re-creating the ProcessGroup in each test instance.
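
For context, a minimal sketch of the pattern this PR moves the tests to, assuming the usual `MultiProcContinousTest` hooks (`backend_str`, `self.rank`, `self.world_size`); this is an illustration, not the exact diff in this PR:

```python
# Sketch (an assumption, not this PR's diff): a test class opting into
# MultiProcContinousTest so one ProcessGroup is created per class and reused
# by every test_* method, instead of being re-created per test instance.
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests


class SymmMemSmokeTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Hook assumed from the MultiProcContinousTest pattern: which c10d
        # backend the shared ProcessGroup should be initialized with.
        return "nccl"

    def test_all_reduce(self) -> None:
        # The group is already up; tests just use self.rank / self.world_size.
        device = torch.device("cuda", self.rank)
        t = torch.ones(8, device=device)
        dist.all_reduce(t)
        self.assertTrue(torch.equal(t, torch.full_like(t, self.world_size)))


if __name__ == "__main__":
    # Assumed entry point; the exact harness wiring may differ.
    run_tests()
```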

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k


pytorch-bot bot commented May 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153677

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c5d9c3e with merge base fa85434:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the oncall: distributed and topic: not user facing labels May 16, 2025
kwen2501 added a commit that referenced this pull request May 16, 2025
ghstack-source-id: 9b6f66e
Pull-Request-resolved: #153677
@kwen2501 kwen2501 requested review from ngimel and fegin May 16, 2025 01:49

@fegin (Contributor) left a comment


LGTM, the failing tests look like they come from the previous PR.


@ngimel (Collaborator) left a comment


Very cool! How long do the symm mem tests run now?

[4, 8192, 8196],
[4, 8, 16],
[
8

Interesting, do you know why memory usage changed? All these tests use very little memory.

@kwen2501 (Contributor, Author) replied:

It is not a problem of the alignment, but of the number of tests we run continuously. It seems we either failed to release tensors or there is some flaw in the allocation logic (e.g. we allocated more than needed).

kwen2501 added a commit that referenced this pull request May 17, 2025
ghstack-source-id: 56d9fdf
Pull-Request-resolved: #153677
kwen2501 added a commit that referenced this pull request May 19, 2025
ghstack-source-id: 8c73378
Pull-Request-resolved: #153677
@kwen2501

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label May 19, 2025
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

kwen2501 added a commit that referenced this pull request May 24, 2025
ghstack-source-id: f0a32f6
Pull-Request-resolved: #153677
@kwen2501

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@malfet (Contributor) commented May 26, 2025

@pytorchbot revert -m "I don't know how, but you PRs keep escaping TD and breaking trunk oops I wrong" -c nosignal

@malfet (Contributor) commented May 26, 2025

Sorry, looks like infra is just unhappy

@pytorchmergebot

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot

Don't want to revert based on edited command

pytorchmergebot pushed a commit that referenced this pull request Jun 6, 2025
A 2D AllToAllv shuffle is illustrated below:
(`world_size` = 2, `ne` = 2, where `ne` is number of experts per rank)
```
        Source: |       Rank 0      |       Rank 1      |
                | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

        Dest  : |       Rank 0      |       Rank 1      |
                | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
```
where each `c_i` / `d_i` is a slice of the `input` tensor targeting expert `i`, with its length given by the input splits (in `in_out_splits[0]`).

That is, the 2D AllToAllv shuffle achieves a transpose from rank-major order at input to expert-major order at output.
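
To make the layout concrete, here is a reference sketch of the same data movement expressed with plain `torch.distributed.all_to_all_single` (not the symm-mem kernel introduced here); the flat split layout `in_splits[r * ne + e]` and the function name are assumptions for illustration:

```python
import torch
import torch.distributed as dist

# Reference sketch of the 2D AllToAllv shuffle using ordinary c10d collectives.
# Assumes torch.distributed is already initialized. Layout assumption:
# `inp` is a 1D token buffer and `in_splits[r * ne + e]` is the number of
# tokens this rank sends to (local) expert `e` on destination rank `r`.
def reference_all_to_all_v_2d(inp, in_splits, world_size, ne):
    device = inp.device
    send = torch.tensor(in_splits, dtype=torch.int64, device=device)  # [world_size * ne]

    # Step 1: exchange per-(rank, expert) counts so receivers know chunk sizes.
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send)                                # recv[src * ne + e]

    # Step 2: bulk exchange, one contiguous region per peer; data arrives
    # rank-major (grouped by source rank, experts in order within each group).
    out = torch.empty(int(recv.sum()), dtype=inp.dtype, device=device)
    dist.all_to_all_single(
        out, inp,
        output_split_sizes=recv.view(world_size, ne).sum(dim=1).tolist(),
        input_split_sizes=send.view(world_size, ne).sum(dim=1).tolist(),
    )

    # Step 3: local regroup from rank-major (src, expert) order to expert-major.
    chunks = list(out.split(recv.tolist()))                           # indexed [src * ne + e]
    return torch.cat([chunks[src * ne + e] for e in range(ne) for src in range(world_size)])
```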

Pull Request resolved: #155058
Approved by: https://github.com/ngimel
ghstack dependencies: #153653, #153677
pytorchmergebot pushed a commit that referenced this pull request Jun 6, 2025
A downstream consumer of the 2D all-to-all-v is often a grouped GEMM.
Today the GEMM often has an alignment requirement on the chunk sizes within the grouped sequence, where each chunk carries the tokens headed for an expert. For example, `torch._group_mm` requires an alignment of 8.

This PR adds that alignment capability: when the user passes in a `major_align` argument, no extra padding step is needed.

The key to supporting that is making the output offsets aligned to this value. (Output offsets are returned to the user in the 3rd row of `in_out_splits`, on device. The 2nd row, output splits, is unaffected by this alignment value, i.e. it reflects the true number of tokens per expert.)

The algorithm is as follows.

![502413288_678786854922438_530852083153996358_n](https://github.com/user-attachments/assets/557624a3-150e-4ab6-ba8b-1dbaa5ac01ac)

In the detailed implementation, we use a warp scan to calculate the prefix sum over the "block" illustrated above. As a result, the "block" size, i.e. `npes`, is currently limited to the warp size of 32.
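
As a plain host-side illustration of the offset alignment (not the warp-scan CUDA kernel; the helper name `aligned_offsets` is hypothetical):

```python
import torch

# Each expert's chunk is made to start on a multiple of `major_align`, while
# the split values themselves are left as the true token counts.
def aligned_offsets(out_splits: torch.Tensor, major_align: int) -> torch.Tensor:
    offsets = torch.zeros_like(out_splits)
    cursor = 0
    for i, n in enumerate(out_splits.tolist()):
        offsets[i] = cursor
        # Round the running end up to the next multiple of major_align so the
        # following chunk starts on an aligned boundary.
        cursor = ((cursor + n + major_align - 1) // major_align) * major_align
    return offsets

# Example: splits [3, 5, 2] with major_align=8 -> offsets [0, 8, 16].
print(aligned_offsets(torch.tensor([3, 5, 2]), 8))
```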

Pull Request resolved: #155172
Approved by: https://github.com/ngimel
ghstack dependencies: #153653, #153677, #155058
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/pytorch that referenced this pull request Jun 22, 2025
Pull Request resolved: pytorch#155058
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/pytorch that referenced this pull request Jun 22, 2025
Pull Request resolved: pytorch#155172
@github-actions github-actions bot deleted the gh/kwen2501/154/head branch June 27, 2025 02:20
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/pytorch that referenced this pull request Jul 14, 2025
Pull Request resolved: pytorch#155058
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/pytorch that referenced this pull request Jul 14, 2025
Pull Request resolved: pytorch#155172
Labels: ciflow/trunk, Merged, oncall: distributed, topic: not user facing
6 participants