37 changes: 18 additions & 19 deletions test/distributed/test_nvshmem.py
@@ -611,6 +611,16 @@ def put_kernel(
):
nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)

# A Triton kernel that calls nvshmem device side API for SIGNAL_OP
@triton.jit
def signal_op_kernel(
sig_addr,
signal: tl.constexpr,
sig_op: tl.constexpr,
peer: tl.constexpr,
Comment on lines +618 to +620

Contributor:
Should these three operands be constexpr? Can they be regular variables? Just to avoid re-compilation.

@codingwithsurya (Contributor, Author), Jun 24, 2025:
Talked offline; these make sense to be regular variables.

@codingwithsurya (Contributor, Author):
I'm going to update this in the PR on top of this stack.

(An illustrative non-constexpr variant is sketched after the kernel definition below.)

):
nvshmem.signal_op(sig_addr, signal, sig_op, peer)
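For reference, an illustrative sketch (not part of this PR) of the variant discussed in the thread above, with the operands as regular runtime arguments instead of tl.constexpr so that changing signal/sig_op/peer does not trigger a Triton re-compile. The explicit int64 casts and the `nvshmem` alias are assumptions mirroring this test file.

import triton
import triton.language as tl

@triton.jit
def signal_op_kernel_dynamic(sig_addr, signal, sig_op, peer):
    # Runtime scalars rather than tl.constexpr; cast to int64 to line up with
    # the nvshmemx_signal_op extern signature (whether the explicit casts are
    # required is an assumption).
    nvshmem.signal_op(
        sig_addr,
        signal.to(tl.int64),
        sig_op.to(tl.int64),
        peer.to(tl.int64),
    )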

# A Triton kernel that calls nvshmem device side API for WAIT_UNTIL
@triton.jit
def wait_until_kernel(
@@ -637,10 +647,10 @@ def wait_until_kernel(
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
dist.barrier()

peer = 1 - rank
NVSHMEM_CMP_EQ = 0 # from nvshmem.h
NVSHMEM_SIGNAL_SET = 0 # atomic set operation

Comment:
I see we can also pass in NVSHMEM_SIGNAL_ADD here? (https://docs.nvidia.com/nvshmem/api/gen/api/signal.html#available-signal-operators) Do we see any scenario where it could be useful? @kwen2501

Contributor:
Yes, when the sender wants to monotonically update a counter as versioning info to the receiver.
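For illustration, a hedged host-side sketch (not part of this PR) of the versioning pattern described in the reply above, reusing the signal_op_kernel and wait_until_kernel defined in this test. The numeric value of NVSHMEM_SIGNAL_ADD and the exact keyword names of wait_until_kernel are assumptions; confirm them against nvshmem.h and the kernel definition before use.

NVSHMEM_SIGNAL_ADD = 1  # assumed value; NVSHMEM_SIGNAL_SET is 0 in this test
NUM_VERSIONS = 3

if rank == 1:
    # Sender: after producing each new version of the data, bump the
    # receiver's counter by 1 instead of overwriting it.
    sig_addr = out_hdl.signal_pad_ptrs[rank]
    for _ in range(NUM_VERSIONS):
        signal_op_kernel[(1, 1, 1)](
            sig_addr,
            signal=1,  # increment amount
            sig_op=NVSHMEM_SIGNAL_ADD,
            peer=peer,
            extern_libs=nvshmem_lib,
        )
else:
    # Receiver: wait until the counter reaches the expected version number.
    ivar_ptr = out_hdl.signal_pad_ptrs[rank]
    wait_until_kernel[(1, 1, 1)](
        ivar_ptr,
        cmp_op=NVSHMEM_CMP_EQ,  # 0, as defined in this test
        cmp_val=NUM_VERSIONS,
        peer=peer,
        extern_libs=nvshmem_lib,
    )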


if rank == 0:
# Rank 0 waits for the flag to be set by Rank 1, then checks the data
@@ -666,17 +676,13 @@ def wait_until_kernel(
peer=peer,
extern_libs=nvshmem_lib,
)
# Rank 1 sets the flag on Rank 0
# We use a temporary tensor for the value to put.
flag_update_val = torch.tensor(
[flag_val], dtype=torch.int64, device=self.device
)
dst_ptr = out_hdl.signal_pad_ptrs[rank]
src_ptr = flag_update_val.data_ptr()
put_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=1,

# Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op
sig_addr = out_hdl.signal_pad_ptrs[rank]
signal_op_kernel[(1, 1, 1)](
sig_addr,
signal=flag_val,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
extern_libs=nvshmem_lib,
)
@@ -736,8 +742,6 @@ def put_and_signal_kernel(
# Use the signal pad for synchronization, as in previous tests
flag_dtype = torch.int64
flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0)
# Ensure setup is complete on all ranks before proceeding
dist.barrier()

if rank == 0:
# Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag
@@ -773,8 +777,6 @@ def put_and_signal_kernel(
[COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device
),
)
# Final barrier to ensure the test does not exit before assertions complete
dist.barrier()

@skipIfRocm
@requires_triton()
@@ -851,7 +853,6 @@ def wait_until_kernel(
[flag_val], dtype=torch.int64, device=self.device
)
NVSHMEM_CMP_EQ = 0 # compare equal
dist.barrier()

if rank == 0:
dst_ptr1 = out1_hdl.buffer_ptrs[rank]
@@ -892,7 +893,6 @@ def wait_until_kernel(
torch.testing.assert_close(
flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device)
)
dist.barrier()

@skipIfRocm
@requires_triton()
@@ -944,7 +944,6 @@ def wait_until_kernel(
):
nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)

dist.barrier()
if rank == 0:
# Rank 0 waits for flag from Rank 1
ivar_ptr = out_hdl.signal_pad_ptrs[rank]
18 changes: 18 additions & 0 deletions torch/distributed/_symmetric_memory/_nvshmem_triton.py
@@ -154,6 +154,24 @@ def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None):  # type: ignore[no-untyped-def]
_builder=_builder,
)

@core.extern
def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-untyped-def]
return core.extern_elementwise(
"",
"",
[sig_addr, signal, sig_op, pe],
{
(
core.dtype("int64"),
core.dtype("int64"),
core.dtype("int64"),
core.dtype("int64"),
): ("nvshmemx_signal_op", core.dtype("int32"))
},
is_pure=False,
_builder=_builder,
)
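For context, an illustrative device-side pairing (not part of this diff) of the new signal_op extern with the existing signal_wait_until extern defined above. The integer constants mirror the values used in test_nvshmem.py (NVSHMEM_SIGNAL_SET == 0, NVSHMEM_CMP_EQ == 0); treat anything else here as an assumption.

import triton
import triton.language as tl
# `nvshmem` below stands for this module's externs, imported/aliased as in the tests.

@triton.jit
def notify_kernel(sig_addr, peer: tl.constexpr):
    # Atomically set the peer's signal word to 1 (sig_op 0 == NVSHMEM_SIGNAL_SET).
    nvshmem.signal_op(sig_addr, 1, 0, peer)

@triton.jit
def wait_kernel(sig_addr):
    # Block until the local signal word equals 1 (cmp 0 == NVSHMEM_CMP_EQ).
    nvshmem.signal_wait_until(sig_addr, 0, 1)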

@core.extern
def fence(_builder=None): # type: ignore[no-untyped-def]
return core.extern_elementwise(