From 9ec6b6e5fad3657a0d9417a1b7cc68eb3621a87d Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Wed, 12 Nov 2025 18:58:58 -0800
Subject: [PATCH 1/3] Get remote tensors inside Helion kernel

---
 examples/all_reduce.py | 35 ++++++++++++++++-------------------
 1 file changed, 16 insertions(+), 19 deletions(-)

diff --git a/examples/all_reduce.py b/examples/all_reduce.py
index 7396e9b9c..8518d3e5a 100644
--- a/examples/all_reduce.py
+++ b/examples/all_reduce.py
@@ -91,8 +91,9 @@ def dev_array_to_tensor_short(
 def one_shot_all_reduce_kernel(
     signal_pad_addrs: torch.Tensor,
     local_signal_pad: torch.Tensor,
-    a_shared_tuple: tuple[torch.Tensor, ...],
+    a_shared: torch.Tensor,
     my_rank: hl.constexpr,
+    group_name: hl.constexpr,
 ) -> torch.Tensor:
     """
     Helion JIT-compiled kernel for one-shot all-reduce operation.
@@ -112,8 +113,9 @@ def one_shot_all_reduce_kernel(
     """
     _, world_size = local_signal_pad.size()
     world_size = hl.specialize(world_size)
-    out = torch.empty_like(a_shared_tuple[0])
+    out = torch.empty_like(a_shared)
     N = out.size(0)
+    a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
 
     for tile_n in hl.tile(N):
         # Sync all devices through signal_pad to make sure
@@ -138,9 +140,7 @@
             scope="sys",
         )
 
-        acc = hl.zeros(
-            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
-        )
+        acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
 
         for a in a_shared_tuple:
             acc += a[tile_n]
@@ -183,15 +183,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
         Tensor containing the all-reduced result (sum across all devices)
     """
     assert dist.group.WORLD is not None
-
-    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
-
-    a_shared_tuple = tuple(
-        [
-            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
-            for i in range(symm_mem_hdl.world_size)
-        ]
-    )
+    group_name = dist.group.WORLD.group_name
+    symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)
 
     local_signal_pad = symm_mem_hdl.get_signal_pad(
         symm_mem_hdl.rank, dtype=torch.int32
@@ -207,8 +200,9 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
     return one_shot_all_reduce_kernel(
         signal_pad_addrs,
         local_signal_pad,
-        a_shared_tuple,
+        a_shared,
         my_rank=symm_mem_hdl.rank,
+        group_name=group_name,
     )
 
 
@@ -254,15 +248,16 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     rank = dist.get_rank()
 
     # Create symmetric memory tensor for Helion implementation
-    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
-
-    print(f"[Rank {rank}] Running Helion all-reduce...")
-    result_helion = helion_one_shot_all_reduce(a_shared)
+    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
+    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).fill_(1)
 
     # Create symmetric memory tensor for reference implementation
     a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
     a_shared_ref.copy_(a_shared)
 
+    print(f"[Rank {rank}] Running Helion all-reduce...")
+    result_helion = helion_one_shot_all_reduce(a_shared)
+
     print(f"[Rank {rank}] Running reference all-reduce...")
     result_ref = reference_one_shot_all_reduce(a_shared_ref)
 
@@ -279,6 +274,8 @@ def main() -> None:
     Sets up the distributed environment, initializes CUDA devices,
     and runs the all-reduce test, and then clean up.
     """
+    # Only NVSHMEM backend implements `get_remote_tensor` for now.
+    symm_mem.set_backend("NVSHMEM")
     rank = int(os.environ["LOCAL_RANK"])
     torch.manual_seed(42 + rank)
     device = torch.device(f"cuda:{rank}")

From 3e47ded7dfad0fb9e2a0e42e205d8381ed4c4984 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 12 Nov 2025 22:20:54 -0800
Subject: [PATCH 2/3] fix distributed CI error reporting

---
 .github/workflows/test.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index d3652a265..d1bbdc223 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -161,8 +161,8 @@ jobs:
           # --timeout: max allowed time for each test
           TEST_PATH=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "test/test_examples_dist.py" || echo ".")
           EXTRA_FLAGS=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "-rs" || echo "--ignore=test/test_examples_dist.py")
-          # For distributed tests, fail if any test is skipped
-          SKIP_CHECK=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "! grep -q SKIPPED" || echo "cat")
+          # For distributed tests, fail if any test is skipped, failed, or has an error
+          SKIP_CHECK=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "! grep -qE '(SKIPPED|FAILED|ERROR)'" || echo "cat")
          pytest -rf --timeout=60 $EXTRA_FLAGS $TEST_PATH | tee >(eval $SKIP_CHECK)
 
   test-notebooks:

From f7801220c2915487237aba4e48e334fde31e05ec Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 12 Nov 2025 22:44:28 -0800
Subject: [PATCH 3/3] Revert "Get remote tensors inside Helion kernel"

This reverts commit 9ec6b6e5fad3657a0d9417a1b7cc68eb3621a87d.
---
 examples/all_reduce.py | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/examples/all_reduce.py b/examples/all_reduce.py
index 8518d3e5a..7396e9b9c 100644
--- a/examples/all_reduce.py
+++ b/examples/all_reduce.py
@@ -91,9 +91,8 @@ def dev_array_to_tensor_short(
 def one_shot_all_reduce_kernel(
     signal_pad_addrs: torch.Tensor,
     local_signal_pad: torch.Tensor,
-    a_shared: torch.Tensor,
+    a_shared_tuple: tuple[torch.Tensor, ...],
     my_rank: hl.constexpr,
-    group_name: hl.constexpr,
 ) -> torch.Tensor:
     """
     Helion JIT-compiled kernel for one-shot all-reduce operation.
@@ -113,9 +112,8 @@ def one_shot_all_reduce_kernel(
     """
     _, world_size = local_signal_pad.size()
     world_size = hl.specialize(world_size)
-    out = torch.empty_like(a_shared)
+    out = torch.empty_like(a_shared_tuple[0])
     N = out.size(0)
-    a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
 
     for tile_n in hl.tile(N):
         # Sync all devices through signal_pad to make sure
@@ -140,7 +138,9 @@
             scope="sys",
         )
 
-        acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
+        acc = hl.zeros(
+            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
+        )
 
         for a in a_shared_tuple:
             acc += a[tile_n]
@@ -183,8 +183,15 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
         Tensor containing the all-reduced result (sum across all devices)
     """
     assert dist.group.WORLD is not None
-    group_name = dist.group.WORLD.group_name
-    symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)
+
+    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
+
+    a_shared_tuple = tuple(
+        [
+            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
+            for i in range(symm_mem_hdl.world_size)
+        ]
+    )
 
     local_signal_pad = symm_mem_hdl.get_signal_pad(
         symm_mem_hdl.rank, dtype=torch.int32
@@ -200,9 +207,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
     return one_shot_all_reduce_kernel(
         signal_pad_addrs,
         local_signal_pad,
-        a_shared,
+        a_shared_tuple,
         my_rank=symm_mem_hdl.rank,
-        group_name=group_name,
     )
 
 
@@ -248,16 +254,15 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     rank = dist.get_rank()
 
     # Create symmetric memory tensor for Helion implementation
-    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
-    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).fill_(1)
+    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
+
+    print(f"[Rank {rank}] Running Helion all-reduce...")
+    result_helion = helion_one_shot_all_reduce(a_shared)
 
     # Create symmetric memory tensor for reference implementation
     a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
     a_shared_ref.copy_(a_shared)
 
-    print(f"[Rank {rank}] Running Helion all-reduce...")
-    result_helion = helion_one_shot_all_reduce(a_shared)
-
     print(f"[Rank {rank}] Running reference all-reduce...")
     result_ref = reference_one_shot_all_reduce(a_shared_ref)
 
@@ -279,8 +274,6 @@ def main() -> None:
     Sets up the distributed environment, initializes CUDA devices,
     and runs the all-reduce test, and then clean up.
     """
-    # Only NVSHMEM backend implements `get_remote_tensor` for now.
-    symm_mem.set_backend("NVSHMEM")
     rank = int(os.environ["LOCAL_RANK"])
     torch.manual_seed(42 + rank)
     device = torch.device(f"cuda:{rank}")
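Note (not part of the patches): a minimal standalone sketch contrasting the two ways the diffs above obtain every rank's symmetric-memory buffer, for readers comparing PATCH 1 with its revert in PATCH 3. It assumes the example's `symm_mem` alias refers to `torch.distributed._symmetric_memory` and that the default process group is already initialized; both code paths follow the diffs, and `torch.ops.symm_mem.get_remote_tensors` is the op PATCH 1's comment describes as NVSHMEM-only.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # assumed module for the example's `symm_mem` alias


def peer_buffers_via_handle(a_shared: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # Approach restored by PATCH 3: rendezvous on the symmetric-memory tensor,
    # then view each rank's buffer through the returned handle on the host side.
    hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
    return tuple(
        hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
        for i in range(hdl.world_size)
    )


def peer_buffers_via_op(a_shared: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # Approach from PATCH 1 (reverted): rendezvous by group name and let the
    # symm_mem op return the remote tensors directly, which is what allowed the
    # tuple to be built inside the Helion kernel.
    group_name = dist.group.WORLD.group_name
    symm_mem.rendezvous(a_shared, group_name)
    return torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)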