11 changes: 6 additions & 5 deletions csrc/quickreduce/quick_reduce.h
@@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) {
uint32_t data_offset, uint32_t flag_color,
int64_t data_size_per_phase) {
int block = blockIdx.x;
int grid = gridDim.x;

while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color);
flag_color, data_size_per_phase);
block += grid;
flag_color++;
}
@@ -41,21 +41,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
flag_color, this->kMaxProblemSize); \
}

enum QuickReduceQuantLevel {
9 changes: 3 additions & 6 deletions csrc/quickreduce/quick_reduce_impl.cuh
@@ -553,13 +553,12 @@ struct AllReduceTwoshot {
int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) {
uint32_t flag_color, int64_t data_size_per_phase) {
// Topology
int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank);
int block_id = blockIdx.x;
int grid_size = gridDim.x;
// --------------------------------------------------------
// Read input into registers
int32x4_t tA[kAtoms];
@@ -588,12 +587,10 @@
// rank responsible for this segment.
uint32_t comm_data0_offset =
data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset =
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;

uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset =
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;

for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =
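Note on the kernel change above: the second-phase data offset used to be derived from the launch's grid size (grid_size * Codec::kTransmittedTileSize), and the second-phase flag offset scaled with grid_size as well. After this patch both come from values that are constant across launches (data_size_per_phase, i.e. kMaxProblemSize, and data_offset / 2). Presumably this is what keeps back-to-back all-reduces with different shapes, and therefore different grid sizes, from relocating the phase-1 regions that peer ranks may still be polling, which is the hang the new test reproduces. Below is a minimal arithmetic sketch of the difference; the constants and helper names are illustrative only, not vLLM code.

# Minimal arithmetic sketch of the offset change (illustrative constants and
# hypothetical names only; this is not vLLM code).

TILE = 2048          # stands in for Codec::kTransmittedTileSize (bytes)
DATA_OFFSET = 4096   # stands in for data_offset (start of the data region)
PER_PHASE = 1 << 20  # stands in for data_size_per_phase / kMaxProblemSize


def old_phase_offsets(block_id: int, grid_size: int) -> tuple[int, int]:
    """Pre-patch: the phase-1 offset moves whenever the launch grid changes."""
    data0 = DATA_OFFSET + block_id * TILE
    data1 = grid_size * TILE + data0
    return data0, data1


def new_phase_offsets(block_id: int) -> tuple[int, int]:
    """Post-patch: the phase-1 offset uses a fixed per-phase stride."""
    data0 = DATA_OFFSET + block_id * TILE
    data1 = PER_PHASE + data0
    return data0, data1


# Two back-to-back all-reduces with different input shapes launch different
# grid sizes, so the same block lands on a different phase-1 address under the
# old scheme; with the fixed stride it stays put.
print(old_phase_offsets(block_id=3, grid_size=64))   # (10240, 141312)
print(old_phase_offsets(block_id=3, grid_size=128))  # (10240, 272384)
print(new_phase_offsets(block_id=3))                 # (10240, 1058816)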
87 changes: 87 additions & 0 deletions tests/distributed/test_quick_all_reduce.py
@@ -1,13 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import multiprocessing
import random

import pytest
import ray
import torch
import torch.distributed as dist

from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.parallel_state import get_tp_group, graph_capture
from vllm.platforms import current_platform
@@ -134,3 +136,88 @@ def test_custom_quick_allreduce(
    monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)

    multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)


def qr_variable_input(rank, world_size):
    """
    When the tensor parallelism is set to 4 or 8, frequent changes
    in the input shape can cause QuickReduce to hang (this issue
    has been observed with the gpt_oss model).
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    qr_max_size = None  # MB
    _ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
    ranks = []
    for i in range(world_size):
        ranks.append(i)
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    cpu_group = torch.distributed.new_group(ranks, backend="nccl")

    handle = ops.qr_get_handle(_ptr)
    world_size = dist.get_world_size(group=cpu_group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=cpu_group)
    ops.qr_open_handles(_ptr, handles)

    num = 1
    s1 = 1024
    while num < 50000:  # 50000 is sufficient to identify issues.

Contributor: How long does the test take?

Contributor: @mgoin Do we have CI for MI300? Should we add tests/distributed/test_quick_all_reduce.py to test-pipeline?

Contributor (author), replying to "How long does the test take?": about 23s for tp4 and tp8

        dtype = torch.float16
        if num % 2 == 0:
            s2 = 1024
            inp1 = torch.zeros(
                (s1, s2), dtype=dtype, device=torch.cuda.current_device()
            )
        else:
            s2 = 2048
            inp1 = torch.ones((s1, s2), dtype=dtype, device=torch.cuda.current_device())
        result = torch.empty_like(inp1)
        # FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4
        ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True)
        try:
            if inp1[0, 0] == 0:
                assert torch.all(result == 0)
            else:
                assert torch.all(result == world_size)
        except AssertionError:
            print("Assertion failed! Allreduce results are incorrect.")
            raise
        num += 1


@pytest.mark.skipif(
not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("tp_size", [4, 8])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")

multiprocessing.set_start_method("spawn", force=True)
# 60s is enough
timeout = 60
processes = []
for rank in range(tp_size):
p = multiprocessing.Process(target=qr_variable_input, args=(rank, tp_size))
p.start()
processes.append((rank, p))
for rank, p in processes:
p.join(timeout=timeout)
if p.is_alive():
for r, proc in processes:
if proc.is_alive():
proc.terminate()
proc.join()
raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!")


if __name__ == "__main__":
test_custom_quick_allreduce_variable_input(tp_size=4, pipeline_parallel_size=1)
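
For reference, the correctness check inside qr_variable_input rests on plain sum-all-reduce arithmetic: an all-zeros input must come back as zeros, and an all-ones input must come back as world_size on every rank. A tiny CPU-only sketch of that expectation (illustrative only, no GPUs or vLLM involved):

def allreduce_sum(per_rank_values):
    # Every rank contributes one value; every rank receives the same sum.
    return sum(per_rank_values)


world_size = 8
assert allreduce_sum([0.0] * world_size) == 0.0         # zeros stay zeros
assert allreduce_sum([1.0] * world_size) == world_size  # ones sum to world_size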