-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[SymmMem] Add NVSHMEM signal_wait_until support to Triton #156473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SymmMem] Add NVSHMEM signal_wait_until support to Triton #156473
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156473
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 3ba0a76 with merge base a67eb1a ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
flag_dtype = torch.int64 | ||
flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0) | ||
# Ensure setup is complete on all ranks before proceeding | ||
dist.barrier() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A note for us to replace the barrier with other APIs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
talked offline; will take care of this w/ a PR on top
@pytorchbot rebase |
@pytorchbot merge |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
Successfully rebased |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR introduces device-side NVSHMEM memory ordering via the fence API in Triton, enabling GPU kernels to enforce completion and ordering of remote memory operations before subsequent operations proceed. Changes: - Added a new `core.extern` wrapper for `nvshmem_fence` in `nvshmem_triton.py` - Implemented `test_triton_fence` in `test/distributed/test_nvshmem.py`, including: - A Triton kernel that performs two ordered `putmem_block` operations separated by `fence()` calls - Final fence before flag update to ensure all data transfers complete before signaling - Consumer validation that both buffers contain expected values when flag arrives, proving ordering guarantees Tests: `$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_fence` Pull Request resolved: #156474 Approved by: https://github.com/mandroid6, https://github.com/kwen2501 ghstack dependencies: #156472, #156473
This PR introduces device-side NVSHMEM completion guarantees via the quiet API in Triton, enabling GPU kernels to ensure all pending remote memory operations are fully complete before proceeding with subsequent operations. Changes: - Added a new `core.extern` wrapper for `nvshmem_quiet` in `nvshmem_triton.py` - Implemented `test_triton_quiet` in `test/distributed/test_nvshmem.py`, including: - A Triton kernel that performs `putmem_block` followed by `quiet()` to ensure completion - Flag-based signaling only after `quiet()` completes, guaranteeing data delivery - Consumer validation that when the completion flag arrives, all data transfers are guaranteed complete Tests: `$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_quiet` Pull Request resolved: #156475 Approved by: https://github.com/kwen2501 ghstack dependencies: #156472, #156473, #156474
This PR introduces device-side NVSHMEM signal synchronization via the signal_wait_until API in Triton, enabling GPU kernels to block until a signal variable meets a specified condition. This replaces previous barrier-based synchronization patterns with more efficient signal-based coordination between PEs.
Changes:
core.extern
wrapper fornvshmem_signal_wait_until
innvshmem_triton.py
test_triton_put_signal
andtest_triton_put_signal_add
tests to usesignal_wait_until
instead ofdist.barrier()
for proper device-side synchronization (per feedback)test_triton_signal_wait_until
with:putmem_signal_block
signal_wait_until
to block until the signal variable reaches the expected valueTests:
$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_signal_wait_until
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k