[SymmMem] Fix flaky wait_until test #159215
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159215
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit 7e1d91f with merge base 3daef4d.
CANCELLED JOB - The following job was cancelled. Please retry:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
When playing around with it, I noticed some flakiness in these tests across sessions.

In the ring broadcast test, I kept getting illegal memory access errors. After some debugging and checking the NVSHMEM docs, it turns out `signal_op` wasn't the right call here. `putmem_signal_block` gives you an atomic guarantee: the signal is only delivered after the data transfer is done. `signal_op`, on the other hand, is standalone signaling with no connection to any data transfer. Without that atomic guarantee, you get race conditions where the signal shows up before the data, so PEs try to access data that isn't there yet. I still need to figure out how to get `signal_op` working with `signal_wait_until`; I've been facing some issues with it.

For the wait_until test, I got NCCL hangs. After debugging, it turns out the heavy sync primitives I was calling from inside Triton kernels (like `nvshmem_quiet()` or `nvshmem_fence()`) were causing deadlocks. The original test tried to guarantee ordering: `put(data) -> fence/quiet -> put(flag)`. But the GPU thread got stuck in `quiet()` waiting for network confirmation while holding the SM, creating a deadlock. The fix was realizing `wait_until` already provides all the sync you need. Just do (sketched in the snippet below):

- PE A: `nvshmem_wait_until(&ivar, ...)`
- PE B: `nvshmem_put(&ivar_on_PE_A, ...)`

Fixes #158423

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta
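For illustration, a minimal Triton-side sketch of this pattern, assuming extern wrappers named `nvshmem.wait_until` and `nvshmem.putmem_block`, an import path for them, and an `NVSHMEM_CMP_EQ` value; none of these names are confirmed by this PR, so treat it as a sketch of the synchronization shape rather than the actual test code:

```python
import triton
import triton.language as tl

# Assumption: the Triton extern wrappers live in a module importable as `nvshmem`;
# the real module path in torch may differ.
from torch.distributed._symmetric_memory import _nvshmem_triton as nvshmem

NVSHMEM_CMP_EQ: tl.constexpr = 0  # assumed value of NVSHMEM_CMP_EQ from nvshmem.h


@triton.jit
def wait_until_kernel(ivar_ptr, expected):
    # PE A: block until the peer's put of the flag arrives. wait_until already
    # orders the put relative to this read, so no quiet()/fence() is needed here.
    nvshmem.wait_until(ivar_ptr, NVSHMEM_CMP_EQ, expected)


@triton.jit
def put_flag_kernel(dst_ivar_ptr, src_flag_ptr, nbytes, peer):
    # PE B: write the flag into PE A's symmetric memory; completion is observed
    # by PE A's wait_until, not by a quiet() on this side.
    nvshmem.putmem_block(dst_ivar_ptr, src_flag_ptr, nbytes, peer)
```

As in the existing tests, both kernels would be launched with `extern_libs=nvshmem_lib` so the NVSHMEM device library gets linked in.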
**update: have seen some test hangs when running tests in CI on 8 ranks, will fix and update this PR**
Starting merge as part of PR stack under #159788
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
barrier_all_kernel[(1,)](extern_libs=nvshmem_lib)
why is this needed?
…(make device library discoverable + fix peer calculation bug) (#159701)

This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

**Triton 3.4 Compatibility**
- The JIT post-compile hook has been updated from the legacy `JITFunction.compiled_hook` to the new API path at `triton.knobs.runtime.jit_post_compile_hook`.
- The internal parameter for kernel semantics in extern function definitions has been updated from `_semantic` to `_builder` to align with API changes.

**Fix CI Errors**
- The new logic inspects the RPATH of `libtorch_nvshmem.so` to find the NVSHMEM device library, preventing CI tests from being skipped.
- Added a decorator to run NVSHMEM tests only on H100s (compatible hardware).

**Peer Rank Calculation Fix**
- The peer calculation in `test_nvshmem_triton.py` was changed from `peer = (world_size - 1) - rank` to `peer = 1 - rank`.
- Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup (see the short illustration below).

Pull Request resolved: #159701
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215
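The peer-mapping difference, illustrated in plain Python (illustrative only, not code from the test file):

```python
world_size = 8

# Old mapping: pairs rank r with rank (world_size - 1) - r, which only gives the
# intended 0<->1 pairing when world_size == 2.
old_peers = [(world_size - 1) - rank for rank in range(world_size)]
# -> [7, 6, 5, 4, 3, 2, 1, 0]: rank 0 talks to rank 7 in the 8-rank CI job.

# New mapping used by the tests: ranks 0 and 1 always pair with each other.
new_peers = [1 - rank for rank in range(2)]
# -> [1, 0]
```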
…m in their name (#159734)

Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this.

Pull Request resolved: #159734
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701
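A hedged sketch of what such a conditional hook could look like; the hook signature, the attributes read off `kernel`, and where `_nvshmemx_cumodule_init` comes from are assumptions for illustration, not the actual PyTorch implementation:

```python
import triton

# Assumption: `_nvshmemx_cumodule_init` is the binding named in this PR's
# description, and the hook receives the compiled kernel object.
def nvshmem_post_compile_hook(kernel, *args, **kwargs):
    # Only NVSHMEM kernels (identified by name) get the cumodule init,
    # instead of running it globally for every compiled Triton kernel.
    if "nvshmem" in getattr(kernel, "name", ""):
        _nvshmemx_cumodule_init(kernel.module)

# New API path per #159701; older Triton used JITFunction.compiled_hook.
triton.knobs.runtime.jit_post_compile_hook = nvshmem_post_compile_hook
```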
…tomatic dtype-based dispatch (#159755)

This change introduces a single, generic Triton-extern wrapper for NVSHMEM team-based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64). It accepts real dtype objects (`torch.dtype` or `tl.dtype`) directly in the Triton kernel launch. Internally, we normalize `dtype_id` (handling `tl.dtype`, `torch.dtype`, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. `nvshmem_float_sum_reduce` or `nvshmem_bfloat16_prod_reduce`.

Pull Request resolved: #159755
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
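For illustration, a minimal sketch of the dtype normalization and name assembly described above; the mapping table and helper name are assumptions, not the actual wrapper code (only the two example function names are confirmed by the commit message):

```python
import torch

# Assumed torch.dtype -> NVSHMEM typename mapping (abbreviated).
_NVSHMEM_TYPENAMES = {
    torch.int8: "int8", torch.int32: "int32", torch.int64: "int64",
    torch.float16: "half", torch.bfloat16: "bfloat16",
    torch.float32: "float", torch.float64: "double",
}

def _reduce_fn_name(operation: str, dtype: torch.dtype) -> str:
    # Assemble e.g. "nvshmem_float_sum_reduce" or "nvshmem_bfloat16_prod_reduce".
    typename = _NVSHMEM_TYPENAMES[dtype]
    return f"nvshmem_{typename}_{operation}_reduce"

assert _reduce_fn_name("sum", torch.float32) == "nvshmem_float_sum_reduce"
assert _reduce_fn_name("prod", torch.bfloat16) == "nvshmem_bfloat16_prod_reduce"
```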
Fed Claude Code the NVSHMEM documentation and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: #159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions so users can pass tensors (rather than raw pointers) into their NVSHMEM Triton kernels. The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which is also useful if you want to do, for instance, some local math on the data.

-----

**TODO:** This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`. From my investigation, the root cause is that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:
Rank 0 → Rank 1:
  Local buffer:  0x430300a00  (src)
  Remote buffer: 0x2430300c00 (dst) ← Rank 1's memory
  Remote signal: 0x2430301600 (sig) ← Rank 1's signal
Rank 1 (waiting):
  Local signal:  0x430301600  (waits here)

Tensor-Based Version:
Rank 0 → Rank 1:
  Local buffer:  0x430300a00  (src)
  Local buffer:  0x430300c00  (dst) ← this is wrong
  Local signal:  0x430300e00  (sig) ← this is wrong
Rank 1 (waiting):
  Local signal:  0x430300e00  (waits here)
```

Next steps: need a mechanism to resolve a local tensor → remote PE address, equivalent to the `handle.buffer_ptrs[peer]` lookup (sketched below).

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
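For illustration, a hedged host-side sketch of the peer-address resolution the TODO refers to; `buffer_ptrs` comes from the text above, while `signal_pad_ptrs` and the placeholder variables (`tensor`, `group_name`) are assumptions, not the verified handle API:

```python
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes an initialized process group and a symmetric tensor; `tensor` and
# `group_name` are placeholders standing in for the test's setup.
rank = dist.get_rank()
peer = 1 - rank

hdl = symm_mem.rendezvous(tensor, group=group_name)  # as in the review snippet earlier

local_src = tensor.data_ptr()        # this rank's source buffer (local address)
remote_dst = hdl.buffer_ptrs[peer]   # peer's buffer: what the put must target
remote_sig = hdl.signal_pad_ptrs[peer]  # assumed attribute for the peer's signal pad

# The current tensor-based wrapper resolves only local addresses (the "wrong"
# rows in the dump above); it needs this peer lookup to target remote memory.
```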