diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py
index e903dec89cf9..2a241917b893 100644
--- a/test/distributed/test_nvshmem.py
+++ b/test/distributed/test_nvshmem.py
@@ -611,6 +611,16 @@ def put_kernel(
         ):
             nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
 
+        # A Triton kernel that calls nvshmem device side API for SIGNAL_OP
+        @triton.jit
+        def signal_op_kernel(
+            sig_addr,
+            signal: tl.constexpr,
+            sig_op: tl.constexpr,
+            peer: tl.constexpr,
+        ):
+            nvshmem.signal_op(sig_addr, signal, sig_op, peer)
+
         # A Triton kernel that calls nvshmem device side API for WAIT_UNTIL
         @triton.jit
         def wait_until_kernel(
@@ -637,10 +647,10 @@ def wait_until_kernel(
         out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
         inp_hdl = symm_mem.rendezvous(inp, group=group_name)
         out_hdl = symm_mem.rendezvous(out, group=group_name)
-        dist.barrier()
 
         peer = 1 - rank
         NVSHMEM_CMP_EQ = 0  # from nvshmem.h
+        NVSHMEM_SIGNAL_SET = 0  # atomic set operation
 
         if rank == 0:
             # Rank 0 waits for the flag to be set by Rank 1, then checks the data
@@ -666,17 +676,13 @@ def wait_until_kernel(
                 peer=peer,
                 extern_libs=nvshmem_lib,
             )
-            # Rank 1 sets the flag on Rank 0
-            # We use a temporary tensor for the value to put.
-            flag_update_val = torch.tensor(
-                [flag_val], dtype=torch.int64, device=self.device
-            )
-            dst_ptr = out_hdl.signal_pad_ptrs[rank]
-            src_ptr = flag_update_val.data_ptr()
-            put_kernel[(1, 1, 1)](
-                dst_ptr,
-                src_ptr,
-                numel=1,
+
+            # Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op
+            sig_addr = out_hdl.signal_pad_ptrs[rank]
+            signal_op_kernel[(1, 1, 1)](
+                sig_addr,
+                signal=flag_val,
+                sig_op=NVSHMEM_SIGNAL_SET,
                 peer=peer,
                 extern_libs=nvshmem_lib,
             )
@@ -736,8 +742,6 @@ def put_and_signal_kernel(
         # Use the signal pad for synchronization, as in previous tests
         flag_dtype = torch.int64
         flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0)
-        # Ensure setup is complete on all ranks before proceeding
-        dist.barrier()
 
         if rank == 0:
             # Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag
@@ -773,8 +777,6 @@ def put_and_signal_kernel(
                     [COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device
                 ),
             )
-        # Final barrier to ensure the test does not exit before assertions complete
-        dist.barrier()
 
     @skipIfRocm
     @requires_triton()
@@ -851,7 +853,6 @@ def wait_until_kernel(
             [flag_val], dtype=torch.int64, device=self.device
         )
         NVSHMEM_CMP_EQ = 0  # compare equal
-        dist.barrier()
 
         if rank == 0:
             dst_ptr1 = out1_hdl.buffer_ptrs[rank]
@@ -892,7 +893,6 @@ def wait_until_kernel(
             torch.testing.assert_close(
                 flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device)
             )
-        dist.barrier()
 
     @skipIfRocm
     @requires_triton()
@@ -944,7 +944,6 @@ def wait_until_kernel(
         ):
             nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)
 
-        dist.barrier()
         if rank == 0:
             # Rank 0 waits for flag from Rank 1
             ivar_ptr = out_hdl.signal_pad_ptrs[rank]
diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py
index aeded6d76df9..75abae38c755 100644
--- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py
+++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py
@@ -154,6 +154,24 @@ def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None):  # type: ignore[no
             _builder=_builder,
         )
 
+    @core.extern
+    def signal_op(sig_addr, signal, sig_op, pe, _builder=None):  # type: ignore[no-untyped-def]
+        return core.extern_elementwise(
+            "",
+            "",
+            [sig_addr, signal, sig_op, pe],
+            {
+                (
+                    core.dtype("int64"),
+                    core.dtype("int64"),
+                    core.dtype("int64"),
+                    core.dtype("int64"),
+                ): ("nvshmemx_signal_op", core.dtype("int32"))
+            },
+            is_pure=False,
+            _builder=_builder,
+        )
+
     @core.extern
     def fence(_builder=None):  # type: ignore[no-untyped-def]
         return core.extern_elementwise(