Skip to content

Conversation

yifuwang
Copy link
Collaborator

@yifuwang yifuwang commented May 29, 2024

Stack from ghstack (oldest at bottom):

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In micro_pipeline_tp passes, we enforce the lhs operand to have such stride order via inductor_prims.force_stride_order. This way if the lhs operand has a flexible layout, the copy is avoided.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

…tter's operands to avoid a copy due to layout transformation

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
Copy link

pytorch-bot bot commented May 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127454

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9069c7d with merge base d3b8230 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Yifu Wang added 2 commits May 29, 2024 14:43
…_reduce_scatter's operands to avoid a copy due to layout transformation"

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
…_reduce_scatter's operands to avoid a copy due to layout transformation"

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
Yifu Wang added 2 commits May 30, 2024 11:52
…_reduce_scatter's operands to avoid a copy due to layout transformation"

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
…_reduce_scatter's operands to avoid a copy due to layout transformation"

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
@yifuwang yifuwang mentioned this pull request Jun 1, 2024
Yifu Wang added 2 commits June 3, 2024 11:08
…_reduce_scatter's operands to avoid a copy due to layout transformation"

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
…_reduce_scatter's operands to avoid a copy due to layout transformation"

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
@yifuwang
Copy link
Collaborator Author

yifuwang commented Jun 7, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 7, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_8-cuda11_8-test / test

Details for Dev Infra team Raised by workflow job

@yifuwang
Copy link
Collaborator Author

yifuwang commented Jun 7, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_8-cuda11_8-test / test

Details for Dev Infra team Raised by workflow job

…_reduce_scatter's operands to avoid a copy due to layout transformation"

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k voznesenskym EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire

[ghstack-poisoned]
@yifuwang
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@yifuwang
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Jun 13, 2024
In fused_all_gather_matmul, each rank copies their shard into their
local p2p buffer, performs a barrier, then performs (copy -> matmul) for
each remote shard. The (copy -> matmul)s for remote shards run on two
streams without synchronization. This not only allows for
computation/communication overlapping, but also computation/computation
overlapping which alleviates the wave quantization effect caused by
computation decomposition.

However, the synchronization-free approach doesn't work well with
fused_matmul_reduce_scatter, in which there's a barrier in every step.
Without synchronization between the two streams, a matmul in one stream
can delay a barrier in the other stream, further delaying the copy
waiting for the barrier.

This PR addresss the issue by adding synchronization between the two
streams such that the matmul of step i can only start after the barrier
of step i-1 completes. With this approach, we lose the
computation/computation overlapping, but avoid slowdown due to delayed
barrier.

Pull Request resolved: #127455
Approved by: https://github.com/Chillee
ghstack dependencies: #127454
pytorchmergebot pushed a commit that referenced this pull request Jun 13, 2024
…reduce_scatter (#127556)

Pull Request resolved: #127556
Approved by: https://github.com/awgu
ghstack dependencies: #127454, #127455
TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
…tter's operands to avoid a copy due to layout transformation (pytorch#127454)

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

Pull Request resolved: pytorch#127454
Approved by: https://github.com/Chillee
TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
In fused_all_gather_matmul, each rank copies their shard into their
local p2p buffer, performs a barrier, then performs (copy -> matmul) for
each remote shard. The (copy -> matmul)s for remote shards run on two
streams without synchronization. This not only allows for
computation/communication overlapping, but also computation/computation
overlapping which alleviates the wave quantization effect caused by
computation decomposition.

However, the synchronization-free approach doesn't work well with
fused_matmul_reduce_scatter, in which there's a barrier in every step.
Without synchronization between the two streams, a matmul in one stream
can delay a barrier in the other stream, further delaying the copy
waiting for the barrier.

This PR addresss the issue by adding synchronization between the two
streams such that the matmul of step i can only start after the barrier
of step i-1 completes. With this approach, we lose the
computation/computation overlapping, but avoid slowdown due to delayed
barrier.

Pull Request resolved: pytorch#127455
Approved by: https://github.com/Chillee
ghstack dependencies: pytorch#127454
TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this pull request Jun 14, 2024
…tter's operands to avoid a copy due to layout transformation (pytorch#127454)

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

Pull Request resolved: pytorch#127454
Approved by: https://github.com/Chillee
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this pull request Jun 14, 2024
In fused_all_gather_matmul, each rank copies their shard into their
local p2p buffer, performs a barrier, then performs (copy -> matmul) for
each remote shard. The (copy -> matmul)s for remote shards run on two
streams without synchronization. This not only allows for
computation/communication overlapping, but also computation/computation
overlapping which alleviates the wave quantization effect caused by
computation decomposition.

However, the synchronization-free approach doesn't work well with
fused_matmul_reduce_scatter, in which there's a barrier in every step.
Without synchronization between the two streams, a matmul in one stream
can delay a barrier in the other stream, further delaying the copy
waiting for the barrier.

This PR addresss the issue by adding synchronization between the two
streams such that the matmul of step i can only start after the barrier
of step i-1 completes. With this approach, we lose the
computation/computation overlapping, but avoid slowdown due to delayed
barrier.

Pull Request resolved: pytorch#127455
Approved by: https://github.com/Chillee
ghstack dependencies: pytorch#127454
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this pull request Jun 14, 2024
@github-actions github-actions bot deleted the gh/yifuwang/87/head branch July 14, 2024 02:03
francograndegmailcom pushed a commit to francograndegmailcom/pytorch-pytorch that referenced this pull request Jul 23, 2024
…tter's operands to avoid a copy due to layout transformation

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

    lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

In `micro_pipeline_tp` passes, we enforce the lhs operand to have such stride order via `inductor_prims.force_stride_order`. This way if the lhs operand has a flexible layout, the copy is avoided.

ghstack-source-id: b0123c1
Pull Request resolved: pytorch/pytorch#127454
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants