force_stride_order on fused_all_gather_matmul/fused_matmul_reduce_scatter's operands to avoid a copy due to layout transformation #127454
Conversation
…tter's operands to avoid a copy due to layout transformation

When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation. This copy can be avoided if the lhs operand already has the following stride order:

`lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()`

In `micro_pipeline_tp` passes, we enforce this stride order on the lhs operand via `inductor_prims.force_stride_order`. This way, if the lhs operand has a flexible layout, the copy is avoided.

[ghstack-poisoned]
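As a reference, here is a minimal eager-mode sketch of the stride order in question, using only plain PyTorch; the shape and `gather_dim` below are made up for illustration:

```python
import torch

def required_stride(lhs: torch.Tensor, gather_dim: int):
    # Strides of a tensor that is contiguous once gather_dim is moved to dim 0,
    # i.e. the layout the fused ops can consume without an extra copy.
    return lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()

A_shard = torch.randn(4, 8, 16)  # hypothetical lhs operand
gather_dim = 1

print(A_shard.stride())                      # (128, 16, 1): plain row-major layout
print(required_stride(A_shard, gather_dim))  # (16, 64, 1): gather_dim-major layout
# The strides differ, so a row-major A_shard would force the fused op to copy.
```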
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127454
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 9069c7d with merge base d3b8230:
BROKEN TRUNK - The following job failed but was also present on the merge base.
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed
Reason: 1 job has failed: linux-binary-manywheel / manywheel-py3_8-cuda11_8-test / test
Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
In fused_all_gather_matmul, each rank copies its shard into its local p2p buffer, performs a barrier, then performs a (copy -> matmul) for each remote shard. The (copy -> matmul)s for remote shards run on two streams without synchronization. This allows not only computation/communication overlap but also computation/computation overlap, which alleviates the wave quantization effect caused by decomposing the computation.

However, the synchronization-free approach doesn't work well with fused_matmul_reduce_scatter, in which there's a barrier in every step. Without synchronization between the two streams, a matmul in one stream can delay a barrier in the other stream, further delaying the copy waiting for that barrier.

This PR addresses the issue by adding synchronization between the two streams such that the matmul of step i can only start after the barrier of step i-1 completes. With this approach we lose the computation/computation overlap, but avoid the slowdown due to a delayed barrier.

Pull Request resolved: #127455
Approved by: https://github.com/Chillee
ghstack dependencies: #127454
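Below is a rough sketch of the cross-stream ordering described above, written with plain `torch.cuda` streams and events rather than the actual SymmetricMemory kernels; `do_barrier`, `do_copy`, and `do_matmul` are hypothetical stand-ins for the per-step barrier, p2p-buffer copy, and partial matmul:

```python
import torch

def pipelined_matmul_reduce_scatter_steps(num_steps, do_barrier, do_copy, do_matmul):
    comm_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    barrier_done = [torch.cuda.Event() for _ in range(num_steps)]

    for i in range(num_steps):
        with torch.cuda.stream(comm_stream):
            do_barrier(i)             # per-step barrier on the comm stream
            barrier_done[i].record()  # records on comm_stream (the current stream)
            do_copy(i)                # copy ordered after the barrier by the stream
        with torch.cuda.stream(compute_stream):
            if i > 0:
                # The matmul of step i may only start once the barrier of step i-1
                # has completed, so it cannot delay the comm stream's barriers.
                compute_stream.wait_event(barrier_done[i - 1])
            do_matmul(i)
```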
…reduce_scatter (#127556)
Pull Request resolved: #127556
Approved by: https://github.com/awgu
ghstack dependencies: #127454, #127455
Stack from ghstack (oldest at bottom):
When performing fused_all_gather_matmul/fused_matmul_reduce_scatter and gather_dim/scatter_dim != 0, a copy of the lhs operand (A_shard/A) is needed for layout transformation.
This copy can be avoided if the lhs operand already has the following stride order:

`lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim).stride()`

In `micro_pipeline_tp` passes, we enforce this stride order on the lhs operand via `inductor_prims.force_stride_order`. This way, if the lhs operand has a flexible layout, the copy is avoided.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire
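To make the enforcement step concrete, here is a rough eager-mode analogue of what the pass arranges at compile time. The helper name is made up, and the real `inductor_prims.force_stride_order` constrains the layout of a node in the traced graph rather than restriding an eager tensor:

```python
import torch

def as_gather_dim_major(lhs: torch.Tensor, gather_dim: int) -> torch.Tensor:
    """Return lhs in the stride order the fused ops expect for their lhs operand.

    If lhs already has that stride order, movedim/contiguous return views and no
    data is copied; otherwise the layout is materialized once here, up front,
    instead of inside fused_all_gather_matmul / fused_matmul_reduce_scatter.
    """
    return lhs.movedim(gather_dim, 0).contiguous().movedim(0, gather_dim)

# Example: with gather_dim=1, a (4, 8, 16) operand ends up with strides (16, 64, 1).
A_shard = torch.randn(4, 8, 16)
print(as_gather_dim_major(A_shard, gather_dim=1).stride())  # (16, 64, 1)
```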