From 716b62d178d8e70ffd50e8a3b546b394346f9494 Mon Sep 17 00:00:00 2001
From: Haoyang Li
Date: Mon, 29 Sep 2025 14:38:32 +0000
Subject: [PATCH 1/5] fix qr error when input shapes differ

Co-authored-by: ilmarkov
Signed-off-by: Haoyang Li
---
 csrc/quickreduce/quick_reduce_impl.cuh | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh
index 17816c552d25..8fe44d9df50c 100644
--- a/csrc/quickreduce/quick_reduce_impl.cuh
+++ b/csrc/quickreduce/quick_reduce_impl.cuh
@@ -592,8 +592,7 @@ struct AllReduceTwoshot {
         grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
 
     uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
-    uint32_t comm_flags1_offset =
-        grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
+    uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;
 
     for (int r = 0; r < kWorldSize; r++) {
       int32x4_t* send_buffer =

From 832ffa95142cd87c8e8439e5aad6065e11008077 Mon Sep 17 00:00:00 2001
From: Haoyang Li
Date: Fri, 10 Oct 2025 09:03:36 +0000
Subject: [PATCH 2/5] split data ptr

Co-authored-by: ilmarkov
Signed-off-by: Haoyang Li
---
 csrc/quickreduce/quick_reduce.h        | 11 ++++++-----
 csrc/quickreduce/quick_reduce_impl.cuh |  6 ++----
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h
index 4fe4c44be7eb..4cc35300bf87 100644
--- a/csrc/quickreduce/quick_reduce.h
+++ b/csrc/quickreduce/quick_reduce.h
@@ -22,13 +22,14 @@ template
 __global__ __quickreduce_launch_bounds_two_shot__ static void
 allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
                             int rank, uint8_t** dbuffer_list,
-                            uint32_t data_offset, uint32_t flag_color) {
+                            uint32_t data_offset, uint32_t flag_color,
+                            int64_t data_size_per_phase) {
   int block = blockIdx.x;
   int grid = gridDim.x;
 
   while (block < num_blocks) {
     AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
-                         flag_color);
+                         flag_color, data_size_per_phase);
     block += grid;
     flag_color++;
   }
@@ -41,21 +42,21 @@
     hipLaunchKernelGGL((allreduce_prototype_twoshot), \
                        dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
                        num_blocks, rank, dbuffer_list, data_offset, \
-                       flag_color); \
+                       flag_color, this->kMaxProblemSize); \
   } else if (world_size == 4) { \
     using LineCodec = __codec; \
     using AllReduceKernel = AllReduceTwoshot; \
     hipLaunchKernelGGL((allreduce_prototype_twoshot), \
                        dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
                        num_blocks, rank, dbuffer_list, data_offset, \
-                       flag_color); \
+                       flag_color, this->kMaxProblemSize); \
   } else if (world_size == 8) { \
     using LineCodec = __codec; \
     using AllReduceKernel = AllReduceTwoshot; \
     hipLaunchKernelGGL((allreduce_prototype_twoshot), \
                        dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
                        num_blocks, rank, dbuffer_list, data_offset, \
-                       flag_color); \
+                       flag_color, this->kMaxProblemSize); \
   }
 
 enum QuickReduceQuantLevel {
diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh
index 8fe44d9df50c..38dc9938fc8a 100644
--- a/csrc/quickreduce/quick_reduce_impl.cuh
+++ b/csrc/quickreduce/quick_reduce_impl.cuh
@@ -553,13 +553,12 @@
       int const rank,                      // rank index
       uint8_t** __restrict__ buffer_list,  // communication buffers
       uint32_t const data_offset,  // offset to start of the data buffer
-      uint32_t flag_color) {
+      uint32_t flag_color, int64_t 
data_size_per_phase) { // Topology int thread = threadIdx.x + threadIdx.y * kWavefront; uint8_t* rank_buffer = buffer_list[rank]; Codec codec(thread, rank); int block_id = blockIdx.x; - int grid_size = gridDim.x; // -------------------------------------------------------- // Read input into registers int32x4_t tA[kAtoms]; @@ -588,8 +587,7 @@ struct AllReduceTwoshot { // rank responsible for this segment. uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize; - uint32_t comm_data1_offset = - grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset; uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset; From 6687f96c24fa9b8f0cac4b8c9ef5d3b30b2772bf Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 11 Oct 2025 07:43:50 +0000 Subject: [PATCH 3/5] add test Signed-off-by: Haoyang Li --- tests/distributed/test_quick_all_reduce.py | 88 ++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index 2df88377345d..4e227e4eb611 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing import random import pytest @@ -8,6 +9,7 @@ import torch import torch.distributed as dist +from vllm import _custom_ops as ops from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.platforms import current_platform @@ -134,3 +136,89 @@ def test_custom_quick_allreduce( monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) + + +def reproduce_hang(rank, world_size): + """ + When the tensor parallelism is set to 4 or 8, frequent changes + in the input shape can cause QuickReduce to hang (this issue + has been observed with the gpt_oss model). + """ + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + qr_max_size = None # MB + _ptr = ops.init_custom_qr(rank, world_size, qr_max_size) + ranks = [] + for i in range(world_size): + ranks.append(i) + dist.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:29500", + rank=rank, + world_size=world_size, + ) + cpu_group = torch.distributed.new_group(ranks, backend="nccl") + + handle = ops.qr_get_handle(_ptr) + world_size = dist.get_world_size(group=cpu_group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=cpu_group) + ops.qr_open_handles(_ptr, handles) + + num = 1 + s1 = 1024 + while num < 50000: # 50000 is sufficient to identify issues. + dtype = torch.float16 + if num % 2 == 0: + s2 = 1024 + inp1 = torch.zeros( + (s1, s2), dtype=dtype, device=torch.cuda.current_device() + ) + else: + s2 = 2048 + inp1 = torch.ones((s1, s2), dtype=dtype, device=torch.cuda.current_device()) + result = torch.empty_like(inp1) + # FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4 + ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True) + try: + if inp1[0, 0] == 0: + assert torch.all(result == 0) + else: + assert torch.all(result == world_size) + except AssertionError: + print("Assertion failed! 
Allreduce results are incorrect.") + raise + # dist.barrier(group=cpu_group) + num += 1 + + +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="only test quick allreduce for rocm" +) +@pytest.mark.parametrize("tp_size", [4, 8]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +def test_custom_quick_allreduce_hang_error(tp_size, pipeline_parallel_size): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + multiprocessing.set_start_method("spawn", force=True) + # 60s is enough + timeout = 60 + processes = [] + for rank in range(tp_size): + p = multiprocessing.Process(target=reproduce_hang, args=(rank, tp_size)) + p.start() + processes.append((rank, p)) + for rank, p in processes: + p.join(timeout=timeout) + if p.is_alive(): + for r, proc in processes: + if proc.is_alive(): + proc.terminate() + proc.join() + raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!") + + +if __name__ == "__main__": + test_custom_quick_allreduce_hang_error(tp_size=4, pipeline_parallel_size=1) From 096e9ec383be94f7a2404301b811f44280cf0d58 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 13 Oct 2025 13:57:39 +0000 Subject: [PATCH 4/5] rename func name Signed-off-by: Haoyang Li --- tests/distributed/test_quick_all_reduce.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index 4e227e4eb611..3f8d05248de4 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -138,7 +138,7 @@ def test_custom_quick_allreduce( multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) -def reproduce_hang(rank, world_size): +def qr_variable_input(rank, world_size): """ When the tensor parallelism is set to 4 or 8, frequent changes in the input shape can cause QuickReduce to hang (this issue @@ -197,7 +197,7 @@ def reproduce_hang(rank, world_size): ) @pytest.mark.parametrize("tp_size", [4, 8]) @pytest.mark.parametrize("pipeline_parallel_size", [1]) -def test_custom_quick_allreduce_hang_error(tp_size, pipeline_parallel_size): +def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size): world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") @@ -207,7 +207,7 @@ def test_custom_quick_allreduce_hang_error(tp_size, pipeline_parallel_size): timeout = 60 processes = [] for rank in range(tp_size): - p = multiprocessing.Process(target=reproduce_hang, args=(rank, tp_size)) + p = multiprocessing.Process(target=qr_variable_input, args=(rank, tp_size)) p.start() processes.append((rank, p)) for rank, p in processes: @@ -221,4 +221,4 @@ def test_custom_quick_allreduce_hang_error(tp_size, pipeline_parallel_size): if __name__ == "__main__": - test_custom_quick_allreduce_hang_error(tp_size=4, pipeline_parallel_size=1) + test_custom_quick_allreduce_variable_input(tp_size=4, pipeline_parallel_size=1) From a19cb5f251bab00de1d2e91800b99f63a1402588 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 13 Oct 2025 10:36:17 -0400 Subject: [PATCH 5/5] debug cruft Signed-off-by: Tyler Michael Smith --- tests/distributed/test_quick_all_reduce.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index 3f8d05248de4..53d906bbc7bd 100644 --- 
a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -188,7 +188,6 @@ def qr_variable_input(rank, world_size): except AssertionError: print("Assertion failed! Allreduce results are incorrect.") raise - # dist.barrier(group=cpu_group) num += 1
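
Note on why the series fixes the hang (an illustrative model, not part of the patches): the two-shot kernel uses two communication stages per tile (data0/flags0 and data1/flags1), and the flag words in the shared buffer persist across kernel launches. Before patches 1-2, the stage-1 offsets were derived from the current launch's grid_size; since grid size tracks the input shape, launches over differently shaped tensors laid the buffer out differently, so a rank could end up polling flag words at addresses its peers never write for the current color, which plausibly is the hang the new test reproduces. The fix threads a launch-invariant capacity into the kernel (data_size_per_phase, bound to this->kMaxProblemSize at the launch sites) and derives the stage-1 flag base from data_offset / 2. The sketch below models the flag-offset half of that idea; kFlagBytes and kDataOffset are simplified stand-ins, not the vLLM identifiers.

// offset_model.cc -- a minimal sketch of the offset scheme, assuming
// simplified constants; not the actual vLLM kernel code.
// Build: g++ -std=c++17 offset_model.cc && ./a.out
#include <cassert>
#include <cstdint>

constexpr uint32_t kFlagBytes = 8 * sizeof(uint32_t);  // flag words for 8 ranks
constexpr uint32_t kDataOffset = 1u << 20;             // start of the data region

// Old scheme: the stage-1 flag base is grid_size * kFlagBytes, so it moves
// whenever the input shape (and therefore the launch's grid size) changes.
uint32_t flags1_old(uint32_t block, uint32_t grid_size) {
  return grid_size * kFlagBytes + block * kFlagBytes;
}

// New scheme (patches 1-2): the stage-1 flag base is a launch-invariant
// constant carved out of the fixed buffer layout (data_offset / 2 here).
uint32_t flags1_new(uint32_t block) {
  return (kDataOffset / 2) + block * kFlagBytes;
}

int main() {
  // Flag state persists in the communication buffer across launches. If a
  // launch with grid size 304 and a later launch with grid size 152 disagree
  // about where block 7's stage-1 flags live, the second launch polls words
  // the first never wrote.
  assert(flags1_old(7, 304) != flags1_old(7, 152));  // old: layout shifts
  assert(flags1_new(7) == flags1_new(7));            // new: layout is fixed
  return 0;
}

The same reasoning applies to the data offsets: patch 2 replaces the grid-size-derived comm_data1_offset with the fixed data_size_per_phase stride, which is why the kernel signature gains that argument and every dispatch site passes this->kMaxProblemSize.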