Skip to content

Conversation

codingwithsurya
Copy link
Contributor

@codingwithsurya codingwithsurya commented Jun 20, 2025

This PR introduces device-side NVSHMEM memory ordering via the fence API in Triton, enabling GPU kernels to enforce completion and ordering of remote memory operations before subsequent operations proceed.

Changes:

  • Added a new core.extern wrapper for nvshmem_fence in nvshmem_triton.py
  • Implemented test_triton_fence in test/distributed/test_nvshmem.py, including:
    • A Triton kernel that performs two ordered putmem_block operations separated by fence() calls
    • Final fence before flag update to ensure all data transfers complete before signaling
    • Consumer validation that both buffers contain expected values when flag arrives, proving ordering guarantees

Tests:
$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_fence

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Jun 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156474

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ce21940 with merge base ef6d2ce (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
@codingwithsurya codingwithsurya changed the title adding nvshmem fence support + kernels [SymmMem] Add NVSHMEM Fence support to Triton Jun 20, 2025
@codingwithsurya codingwithsurya self-assigned this Jun 20, 2025
@codingwithsurya
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 21, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Jun 21, 2025
This PR introduces device-side NVSHMEM completion guarantees via the quiet API in Triton, enabling GPU kernels to ensure all pending remote memory operations are fully complete before proceeding with subsequent operations.

Changes:
- Added a new `core.extern` wrapper for `nvshmem_quiet` in `nvshmem_triton.py`
- Implemented `test_triton_quiet` in `test/distributed/test_nvshmem.py`, including:
  - A Triton kernel that performs `putmem_block` followed by `quiet()` to ensure completion
  - Flag-based signaling only after `quiet()` completes, guaranteeing data delivery
  - Consumer validation that when the completion flag arrives, all data transfers are guaranteed complete

Tests:
`$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_quiet`

Pull Request resolved: #156475
Approved by: https://github.com/kwen2501
ghstack dependencies: #156472, #156473, #156474
@github-actions github-actions bot deleted the gh/codingwithsurya/3/head branch July 23, 2025 02:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants