-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[SymmMem] Remove redundant dist.barrier in Triton NVSHMEM tests & add device‐side signal_op support #156684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SymmMem] Remove redundant dist.barrier in Triton NVSHMEM tests & add device‐side signal_op support #156684
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -611,6 +611,16 @@ def put_kernel( | |
): | ||
nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) | ||
|
||
# A Triton kernel that calls nvshmem device side API for SIGNAL_OP | ||
@triton.jit | ||
def signal_op_kernel( | ||
sig_addr, | ||
signal: tl.constexpr, | ||
sig_op: tl.constexpr, | ||
peer: tl.constexpr, | ||
): | ||
nvshmem.signal_op(sig_addr, signal, sig_op, peer) | ||
|
||
# A Triton kernel that calls nvshmem device side API for WAIT_UNTIL | ||
@triton.jit | ||
def wait_until_kernel( | ||
|
@@ -637,10 +647,10 @@ def wait_until_kernel( | |
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) | ||
inp_hdl = symm_mem.rendezvous(inp, group=group_name) | ||
out_hdl = symm_mem.rendezvous(out, group=group_name) | ||
dist.barrier() | ||
|
||
peer = 1 - rank | ||
NVSHMEM_CMP_EQ = 0 # from nvshmem.h | ||
NVSHMEM_SIGNAL_SET = 0 # atomic set operation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see we can also pass in Do we see any scenario where it could be useful? @kwen2501 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, when the sender wants to monotonically update a counter as versioning info to the receiver. |
||
|
||
if rank == 0: | ||
# Rank 0 waits for the flag to be set by Rank 1, then checks the data | ||
|
@@ -666,17 +676,13 @@ def wait_until_kernel( | |
peer=peer, | ||
extern_libs=nvshmem_lib, | ||
) | ||
# Rank 1 sets the flag on Rank 0 | ||
# We use a temporary tensor for the value to put. | ||
flag_update_val = torch.tensor( | ||
[flag_val], dtype=torch.int64, device=self.device | ||
) | ||
dst_ptr = out_hdl.signal_pad_ptrs[rank] | ||
src_ptr = flag_update_val.data_ptr() | ||
put_kernel[(1, 1, 1)]( | ||
dst_ptr, | ||
src_ptr, | ||
numel=1, | ||
|
||
# Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op | ||
sig_addr = out_hdl.signal_pad_ptrs[rank] | ||
signal_op_kernel[(1, 1, 1)]( | ||
sig_addr, | ||
signal=flag_val, | ||
sig_op=NVSHMEM_SIGNAL_SET, | ||
peer=peer, | ||
extern_libs=nvshmem_lib, | ||
) | ||
|
@@ -736,8 +742,6 @@ def put_and_signal_kernel( | |
# Use the signal pad for synchronization, as in previous tests | ||
flag_dtype = torch.int64 | ||
flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0) | ||
# Ensure setup is complete on all ranks before proceeding | ||
dist.barrier() | ||
|
||
if rank == 0: | ||
# Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag | ||
|
@@ -773,8 +777,6 @@ def put_and_signal_kernel( | |
[COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device | ||
), | ||
) | ||
# Final barrier to ensure the test does not exit before assertions complete | ||
dist.barrier() | ||
|
||
@skipIfRocm | ||
@requires_triton() | ||
|
@@ -851,7 +853,6 @@ def wait_until_kernel( | |
[flag_val], dtype=torch.int64, device=self.device | ||
) | ||
NVSHMEM_CMP_EQ = 0 # compare equal | ||
dist.barrier() | ||
|
||
if rank == 0: | ||
dst_ptr1 = out1_hdl.buffer_ptrs[rank] | ||
|
@@ -892,7 +893,6 @@ def wait_until_kernel( | |
torch.testing.assert_close( | ||
flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device) | ||
) | ||
dist.barrier() | ||
|
||
@skipIfRocm | ||
@requires_triton() | ||
|
@@ -944,7 +944,6 @@ def wait_until_kernel( | |
): | ||
nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val) | ||
|
||
dist.barrier() | ||
if rank == 0: | ||
# Rank 0 waits for flag from Rank 1 | ||
ivar_ptr = out_hdl.signal_pad_ptrs[rank] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should these three operands be constexpr? Can they be regular variables? Just to avoid re-compilation.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
talked offline; these make sense to be regular variables
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm going to update this on top PR of this stack.