
Conversation

codingwithsurya
Contributor

@codingwithsurya codingwithsurya commented Jul 26, 2025

While playing around with this test, I noticed some flakiness across sessions.

After debugging, it turns out the heavy sync primitives I was calling from inside Triton kernels (like `nvshmem_quiet()` or `nvshmem_fence()`) were causing deadlocks. The original test tried to guarantee ordering: `put(data) -> fence/quiet -> put(flag)`. But the GPU thread got stuck in `quiet()` waiting for network confirmation while holding the SM, creating a deadlock.

The fix was realizing `wait_until` already provides all the sync you need. Just do:

  • PE A: `nvshmem_wait_until(&ivar, ...)`
  • PE B: `nvshmem_put(&ivar_on_PE_A, ...)`
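
Below is a minimal Triton-side sketch of that pattern. The extern wrapper names `nvshmem_wait_until` and `nvshmem_put` are assumptions for illustration (the real wrapper names and signatures in `test_nvshmem_triton.py` may differ); the kernels would be launched with `extern_libs=nvshmem_lib` like the other NVSHMEM kernels in this stack.

```python
import triton
import triton.language as tl


@triton.jit
def wait_until_kernel(ivar_ptr, cmp_op: tl.constexpr, expected: tl.constexpr):
    # PE A: block until the peer's put lands in the local symmetric ivar.
    # wait_until itself guarantees visibility of the delivered value,
    # so no extra fence/quiet is needed on either side.
    nvshmem_wait_until(ivar_ptr, cmp_op, expected)  # assumed extern wrapper


@triton.jit
def put_kernel(ivar_ptr_on_peer, src_ptr, peer: tl.constexpr):
    # PE B: one-sided put of the flag/data into PE A's symmetric memory.
    nvshmem_put(ivar_ptr_on_peer, src_ptr, 1, peer)  # assumed extern wrapper
```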

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta


pytorch-bot bot commented Jul 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159215

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit 7e1d91f with merge base 3daef4d:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/h100-symm-mem oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category labels Jul 26, 2025
codingwithsurya added a commit that referenced this pull request Jul 26, 2025
@codingwithsurya codingwithsurya self-assigned this Jul 26, 2025
@codingwithsurya codingwithsurya changed the title [SymmMem] Fix flaky ring broadcast and wait_until test [wip] [SymmMem] Fix flaky ring broadcast and wait_until test Jul 26, 2025
@codingwithsurya codingwithsurya changed the title [wip] [SymmMem] Fix flaky ring broadcast and wait_until test [SymmMem] Fix flaky ring broadcast and wait_until test Jul 31, 2025
@codingwithsurya codingwithsurya changed the title [SymmMem] Fix flaky ring broadcast and wait_until test [wip] [SymmMem] Fix flaky ring broadcast and wait_until test Aug 2, 2025
When playing around with these tests, I noticed some flakiness across sessions.

In the ring broadcast test, I kept getting illegal memory access errors. After some debugging and checking the NVSHMEM docs, it turns out `signal_op` wasn't the right call here. `putmem_signal_block` gives you an atomic guarantee: the signal is delivered only after the data transfer is done. `signal_op`, by contrast, is standalone signaling with no connection to any data transfer. Without that guarantee you get race conditions where the signal shows up before the data, so PEs try to access data that isn't there yet. I still need to figure out how to get `signal_op` working with `signal_wait_until`; I've been running into some issues with it.
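
Here is a hedged sketch of the fixed approach, with the data put and the signal fused so the signal can never be observed before the data. The wrapper names (`nvshmem_putmem_signal_block`, `nvshmem_signal_wait_until`) and the launch details are assumptions and may not match the actual Triton externs in the test:

```python
import triton
import triton.language as tl


@triton.jit
def bcast_send_kernel(dst_ptr, src_ptr, nbytes, sig_ptr, sig_val, sig_op, peer):
    # Put the payload into the peer's buffer and set the peer's signal in one fused,
    # ordered operation: the signal only becomes visible after the data has landed.
    nvshmem_putmem_signal_block(dst_ptr, src_ptr, nbytes, sig_ptr, sig_val, sig_op, peer)  # assumed wrapper


@triton.jit
def bcast_wait_kernel(sig_ptr, cmp_op, sig_val):
    # The receiving PE spins on its local signal; once it observes sig_val,
    # the data in its buffer is guaranteed to be valid.
    nvshmem_signal_wait_until(sig_ptr, cmp_op, sig_val)  # assumed wrapper
```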
 
For the wait_until test, I got NCCL hangs. After debugging, it turns out the heavy sync primitives I was calling from inside Triton kernels (like `nvshmem_quiet()` or `nvshmem_fence()`) were causing deadlocks. The original test tried to guarantee ordering: `put(data) -> fence/quiet -> put(flag)`. But the GPU thread got stuck in `quiet()` waiting for network confirmation while holding the SM, creating a deadlock.

The fix was realizing `wait_until` already provides all the sync you need. Just do:
- PE A: `nvshmem_wait_until(&ivar, ...)`  
- PE B: `nvshmem_put(&ivar_on_PE_A, ...)`

Fixes #158423



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
**Update: I've seen some test hangs when running in CI on 8 ranks; will fix and update this PR.**

@codingwithsurya codingwithsurya changed the title [wip] [SymmMem] Fix flaky ring broadcast and wait_until test [wip] [SymmMem] Fix flakywait_until test Aug 3, 2025
@codingwithsurya codingwithsurya changed the title [wip] [SymmMem] Fix flakywait_until test [SymmMem] Fix flaky wait_until test Aug 3, 2025
@codingwithsurya codingwithsurya marked this pull request as ready for review August 3, 2025 03:30
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #159788


inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
barrier_all_kernel[(1,)](extern_libs=nvshmem_lib)


why is this needed?

pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…(make device library discoverable + fix peer calculation bug) (#159701)

This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

**Triton 3.4 Compatibility**
- The JIT post-compile hook has been updated from the legacy JITFunction.compiled_hook to the new API path at triton.knobs.runtime.jit_post_compile_hook.
- The internal parameter for kernel semantics in extern function definitions has been updated from _semantic to _builder to align with API changes.

**Fix CI Errors**
- The new logic inspects the RPATH of libtorch_nvshmem.so to find the NVSHMEM device library, preventing CI tests from being skipped.
- Added a decorator to run NVSHMEM tests only on H100s (compatible hardware)

**Peer Rank Calculation Fix**
- The peer calculation in `test_nvshmem_triton.py` was changed from `peer = (world_size - 1) - rank` to `peer = 1 - rank`.
Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup.
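
As a quick illustration of why the old formula breaks on 8 ranks (plain arithmetic, not the actual test code):

```python
world_size = 8
for rank in (0, 1):
    old_peer = (world_size - 1) - rank  # rank 0 -> 7, rank 1 -> 6: wrong pairing on 8 ranks
    new_peer = 1 - rank                 # rank 0 -> 1, rank 1 -> 0: the 0<->1 pattern the tests assume
    print(rank, old_peer, new_peer)
```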

Pull Request resolved: #159701
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…m in their name (#159734)

Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change calls `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. The names of all of our NVSHMEM kernels were also updated to align with this.
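
A hedged sketch of the gating idea, assuming the hook receives the compiled kernel (the exact Triton 3.4 hook payload and the `_nvshmemx_cumodule_init` binding are assumptions here, not the actual implementation):

```python
import triton


def nvshmem_post_compile_hook(*args, **kwargs):
    # Assumed payload shape: the compiled kernel arrives either positionally or as `fn`.
    kernel = kwargs.get("fn") or (args[0] if args else None)
    name = getattr(kernel, "name", "") or getattr(kernel, "__name__", "")
    if "nvshmem" in name:
        # Only NVSHMEM kernels need their CUmodule registered with the NVSHMEM runtime.
        _nvshmemx_cumodule_init(kernel.module)  # assumed to be in scope


triton.knobs.runtime.jit_post_compile_hook = nvshmem_post_compile_hook
```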

Pull Request resolved: #159734
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…tomatic dtype‐based dispatch (#159755)

This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (`torch.dtype` or `tl.dtype`) directly in the Triton kernel launch. Internally, we normalize `dtype_id` (handling `tl.dtype`, `torch.dtype`, `str`, or `constexpr`) into the canonical NVSHMEM typename and assemble the proper function name, e.g. `nvshmem_float_sum_reduce` or `nvshmem_bfloat16_prod_reduce`.
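
A hedged usage sketch: the import path for the `nvshmem` extern module and the string form of the operation argument are assumptions, while the call signature follows the description above.

```python
import triton
import triton.language as tl

# Assumed import location of the Triton extern wrappers; the real module path may differ.
from torch.distributed._symmetric_memory import _nvshmem_triton as nvshmem


@triton.jit
def team_sum_kernel(team, dest_ptr, src_ptr, nreduce: tl.constexpr):
    # Dispatches internally to e.g. nvshmem_float_sum_reduce based on the dtype argument.
    nvshmem.reduce(team, dest_ptr, src_ptr, nreduce, "sum", tl.float32)
```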

Pull Request resolved: #159755
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
Fed Claude Code NVSHMEM Documentation and asked it to generate helpful docstrings. Verified for correctness.

Pull Request resolved: #159756
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2025
…on kernels (#159788)

This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which will also be useful if you want to do for instance some local math on the data.

-----

**TODO:**
This is almost complete. One pending item is a tensor-aware implementation of `nvshmem.putmem_signal_block` and `nvshmem.signal_wait_until`.

From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer:

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next steps: we need a mechanism to resolve a local tensor to the remote PE's address, equivalent to the `handle.buffer_ptrs[peer]` lookup, as sketched below.
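
A hedged sketch of that lookup, grounded only in the `handle.buffer_ptrs[peer]` mention above; the base-pointer indexing by rank and the offset arithmetic are assumptions for illustration:

```python
import torch


def remote_addr(handle, local_tensor: torch.Tensor, my_rank: int, peer: int) -> int:
    # Offset of this tensor inside my own symmetric buffer...
    offset = local_tensor.data_ptr() - handle.buffer_ptrs[my_rank]
    # ...applied to the peer's base pointer gives the remote destination address
    # that putmem_signal_block should actually target.
    return handle.buffer_ptrs[peer] + offset
```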

Pull Request resolved: #159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
When playing around with it, I noticed some flakiness in this test across sessions.

After debugging, it turns out the heavy sync primitives I was calling from inside Triton kernels (like `nvshmem_quiet()` or `nvshmem_fence()`) were causing deadlocks. The original test tried to guarantee ordering: `put(data) -> fence/quiet -> put(flag)`. But the GPU thread got stuck in `quiet()` waiting for network confirmation while holding the SM, creating a deadlock.

The fix was realizing `wait_until` already provides all the sync you need. Just do:
- PE A: `nvshmem_wait_until(&ivar, ...)`
- PE B: `nvshmem_put(&ivar_on_PE_A, ...)`

Pull Request resolved: pytorch#159215
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: pytorch#158515, pytorch#158718, pytorch#159136
@github-actions github-actions bot deleted the gh/codingwithsurya/16/head branch September 8, 2025 02:14
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
