[a2av] Improve tuning for 4 GPUs #154580
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154580
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit d181c33 with merge base 241f8dc:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
What about our Triton-based all2allv? Does it need this tuning too?
It probably needs it too. On the other hand, I am not 100% sure about its maintenance or move-into-core plan.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
### Problem

Running `nvshmem_all_to_all_vdev` on 4 x H100s (fully connected with NVSwitch).

Before:

```
Bytes: MiB, Time: us, BusBw: GB/s
   0     32.29    16.23
   1     33.01    31.76
   2     33.01    63.54
   4     33.83   123.97
   8     49.83   168.34
  16     80.82   207.59
  32    178.66   187.82
  64    335.79   199.86
 128    646.72   207.54
 256   1268.77   211.57
 512   2511.14   213.80
1024   4998.31   214.82
2048   9964.49   215.51
4096  19892.34   215.91
```

215 GB/s does not reach the SOL of NV18 (350-400 GB/s).

### Change

If the number of peers decreases (say from 8 to 4), we do not reduce the number of CTAs; instead, we shift more CTAs towards the data-parallel dimension.

After:

```
Bytes: MiB, Time: us, BusBw: GB/s
   0     25.01    20.96
   1     25.70    40.80
   2     25.76    81.42
   4     28.87   145.26
   8     40.79   205.64
  16     61.46   272.97
  32    111.82   300.06
  64    202.40   331.57
 128    382.56   350.84
 256    739.11   363.19
 512   1450.79   370.05
1024   2873.13   373.72
2048   5719.50   375.47
4096  11395.65   376.90
```

If we look at the MoE-related region, say 32 MB, we can see a 187 -> 300 GB/s improvement.

Pull Request resolved: pytorch#154580
Approved by: https://github.com/ngimel
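To make the tuning idea above concrete, here is a minimal, hypothetical sketch of such a CTA-layout heuristic. It is not the actual `nvshmem_all_to_all_vdev` launch code; `kMaxCtas`, `CtaLayout`, and `choose_cta_layout` are illustrative names only.

```cpp
// Hypothetical sketch of the CTA-layout heuristic (not PyTorch's actual code).
#include <algorithm>

constexpr int kMaxCtas = 16;  // assumed overall CTA budget for the kernel launch

struct CtaLayout {
  int ctas_per_peer;  // CTAs cooperating on one peer's payload (data-parallel dim)
  int total_ctas;     // CTAs actually launched
};

// With 8 peers this gives 2 CTAs per peer (16 total); with 4 peers it gives
// 4 CTAs per peer (still 16 total). The budget is not reduced when the peer
// count drops; the extra CTAs shift to the data-parallel dimension instead.
inline CtaLayout choose_cta_layout(int npeers) {
  int ctas_per_peer = std::max(1, kMaxCtas / npeers);
  return {ctas_per_peer, ctas_per_peer * npeers};
}
```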
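As a quick sanity check on how the BusBw column relates to the Bytes and Time columns, the numbers are consistent with BusBw being payload bytes divided by elapsed time; this interpretation is an assumption inferred from the tables, not taken from the benchmark source.

```cpp
// Hedged sanity check (not the benchmark's code): assuming the BusBw column is
// simply payload bytes divided by elapsed time, the 4096 MiB row of the
// "After" table (11395.65 us) reproduces the reported ~376.9 GB/s.
#include <cstdio>

int main() {
  double mib      = 4096.0;    // payload size from the table, in MiB
  double time_us  = 11395.65;  // measured time from the table, in microseconds
  double bytes    = mib * 1024.0 * 1024.0;
  double gb_per_s = bytes / (time_us * 1e-6) / 1e9;  // decimal GB/s
  std::printf("BusBw ~= %.1f GB/s\n", gb_per_s);     // prints ~376.9
  return 0;
}
```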
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k