
Conversation

@haoyangli-amd
Contributor

@haoyangli-amd haoyangli-amd commented Sep 29, 2025

When the tensor parallelism (TP) degree is set to 4 or 8, frequent changes in the input shape can cause QuickReduce to hang (this issue has been observed with the gpt_oss model).
We have identified that the root cause is overlapping flag memory addresses between consecutive AllReduce operations.

For most models, the hidden size remains relatively stable, so this issue does not occur.

Our current solution is to allocate separate memory regions for the flags and data of the two AllReduce phases in each operation.
(Note: The data region must also be separated, as overlapping would lead to correctness issues.)
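
Conceptually, the fix can be pictured with the following Python sketch (illustrative only; the region names and sizes are invented here and do not reflect the actual QuickReduce/HIP buffer layout):

# Illustrative only: give each phase of an allreduce its own flag and data
# region inside the shared buffer, so consecutive allreduce operations can
# never write flags into a region that another, still in-flight phase is polling.
FLAG_REGION_BYTES = 4096                 # assumed per-phase flag region size
DATA_REGION_BYTES = 16 * 1024 * 1024     # assumed per-phase data region size

def phase_regions(base_offset: int) -> dict[str, int]:
    """Return disjoint start offsets for the flags and data of both phases."""
    return {
        "phase1_flags": base_offset,
        "phase2_flags": base_offset + FLAG_REGION_BYTES,
        "phase1_data": base_offset + 2 * FLAG_REGION_BYTES,
        "phase2_data": base_offset + 2 * FLAG_REGION_BYTES + DATA_REGION_BYTES,
    }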

To reproduce the error:

1. git clone https://github.com/vllm-project/vllm.git
2. python3 setup.py develop
3. python3 this_script.py

import torch
import multiprocessing
import argparse
import torch.distributed as dist
from vllm import _custom_ops as ops
def worker(rank, world_size, comm_handles, comm_handle_dict):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    qr_max_size = None # MB
    _ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
    ranks = list(range(world_size))
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size
    )
    cpu_group = torch.distributed.new_group(ranks, backend="nccl")

    handle = ops.qr_get_handle(_ptr)
    world_size = dist.get_world_size(group=cpu_group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=cpu_group)
    ops.qr_open_handles(_ptr, handles)
    num = 1
    s1 = 1024
    while s1 > 0:
        dtype = torch.float16
        if num % 60 == 0:
            s1 = s1 // 2
        # Alternate between two shapes so the message size changes every iteration
        if num % 2 == 0:
            s2 = 1024
            inp1 = torch.zeros((s1, s2),
                dtype=dtype,
                device=torch.cuda.current_device())
        else:
            s2 = 2048
            inp1 = torch.ones((s1, s2),
                dtype=dtype,
                device=torch.cuda.current_device())
        print(f"num:{num}, rank:{rank}, shape:{inp1.shape}")
        result = torch.empty_like(inp1)
        # FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4
        ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True)
        try:
            if inp1[0,0] == 0:
                assert torch.all(result == 0)
            else:
                assert torch.all(result == world_size)
        except AssertionError:
            torch.save(result, "result_failed.pth")
            print("Assertion failed! Saved result to result_failed.pth")
            raise
        # dist.barrier(group=cpu_group)
        num += 1
        if s1 < 100:
            s1 = 8 * 1024
    print("done")
def run_multiprocessing(world_size):
    with multiprocessing.Manager() as manager:
        comm_handle_dict = manager.dict()
        comm_handles = manager.Barrier(world_size)

        processes = []
        for rank in range(world_size):
            p = multiprocessing.Process(
                target=worker,
                args=(rank, world_size, comm_handles, comm_handle_dict)
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--world_size",
        type=int,
        default=4,      
        help="number of processes / GPUs to use"
    )
    args = parser.parse_args()

    multiprocessing.set_start_method("spawn")
    run_multiprocessing(world_size=args.world_size)

We can also run this script to check whether the hang issue is resolved and whether the results are correct.
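
For example, to exercise TP8 instead of the default TP4 (the --world_size flag is defined in the script above):

python3 this_script.py --world_size 8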

A more detailed explanation

Why does a frequently changing input shape cause problems?

1. I have collected some logs (see the excerpt below). They show that the program is not actually stuck at the [256, 2048] stage, but at the previous one, [256, 1024].
I suspect that the n-th allreduce and the (n+1)-th allreduce overlap in time: when the (n+1)-th allreduce executes its phase 1, it modifies the flags still being used by the n-th allreduce's phase 2.
For [256, 2048], the phase 1 address range completely overlaps with the phase 1 + phase 2 address range of [256, 1024].
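
As a rough sanity check of that overlap claim, here is a sketch under the simplifying assumption that the phase 1 and phase 2 regions sit back to back and are each sized by the message being reduced:

# Rough size arithmetic for the two shapes in the log (fp16 = 2 bytes/element).
# Assumed simplified layout: the phase 1 region starts at offset 0 and the
# phase 2 region follows it immediately, each sized by the message length.
bytes_small = 256 * 1024 * 2   # [256, 1024] -> 524,288 bytes per phase
bytes_large = 256 * 2048 * 2   # [256, 2048] -> 1,048,576 bytes per phase

# Phase 1 of the larger message spans [0, 1 MiB), which is exactly the combined
# phase 1 + phase 2 span of the smaller message, so its flag writes can clobber
# flags that the previous, still running allreduce is spinning on.
assert bytes_large == bytes_small * 2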

num:6451, rank:3, shape:torch.Size([256, 1024])
flag_color1:6451
flag_color1:6451
num_blocks:16
flag_color1:6451
flag_color2:6452
flag_color2:6452
flag_color2:6452
num:6452, rank:0, shape:torch.Size([256, 2048])
num_blocks:32
flag_color1:6452
num:6452, rank:1, shape:torch.Size([256, 2048])
num:6452, rank:3, shape:torch.Size([256, 2048])
num_blocks:32
num_blocks:32
flag_color1:6452
flag_color1:6452
2, block:12, thread:0, flag_color:6451, flag_ptr:6452
2, block:14, thread:0, flag_color:6451, flag_ptr:6452

2. Previously, we used hipStreamSynchronize(stream) and hipDeviceSynchronize(). These only synchronize the local device and do not guarantee that all ranks block at the same point. Now we use dist.barrier(group=cpu_group), and even after running for an hour the program does not hang.
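
A minimal standalone sketch of the difference (assumes a torch.distributed process group has already been initialized; this is not the vLLM test code itself):

import torch
import torch.distributed as dist

def lockstep_step(group):
    # Device-level sync: only waits until *this* rank's queued GPU work is done.
    torch.cuda.synchronize()
    # Collective sync: returns only after *every* rank in the group has reached
    # this call, so no rank can run ahead into the next allreduce.
    dist.barrier(group=group)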

3. Referring to vLLM's communication reduction (CR) implementation, it is necessary to use isolated addresses to distinguish the different phases of different allreduce batches, so that the n-th and (n+1)-th allreduce operations cannot interfere with each other.

Why don't other models have this issue?
For typical models, the hidden size is fixed, so the input shape does not change frequently, and phase 2 of the n-th allreduce and phase 1 of the (n+1)-th allreduce do not share addresses. However, for models like GPT-OSS with variable-length inputs, such conflicts can occur. What we need to do is completely isolate the addresses so that no conflict is possible.

@haoyangli-amd haoyangli-amd marked this pull request as ready for review October 10, 2025 09:14
@haoyangli-amd haoyangli-amd changed the title [fix] fix qr error when different inp shape [Bugfix][Rocm] fix qr error when different inp shape Oct 10, 2025
@mergify mergify bot added the rocm Related to AMD ROCm label Oct 10, 2025
@ilmarkov
Contributor

@haoyangli-amd Please add a test case covering this issue to test_quick_all_reduce.py, using the approach from your repro script and a timeout to detect a hang.
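
A hedged sketch of what such a hang-detecting test could look like (all names here are illustrative, the timeout marker requires the pytest-timeout plugin, and the actual test added to test_quick_all_reduce.py may differ):

import pytest
import torch.multiprocessing as mp

@pytest.mark.timeout(180)  # fail the test instead of hanging CI forever
def test_quick_all_reduce_varying_input_shape():
    world_size = 4
    # Spawn one worker per GPU and run the alternating-shape allreduce loop
    # from the repro script for a bounded number of iterations; if any rank
    # deadlocks on a stale flag, the timeout above turns the hang into a failure.
    mp.spawn(_worker_loop, args=(world_size,), nprocs=world_size, join=True)

def _worker_loop(rank, world_size):
    ...  # body as in the repro script above, with a finite iteration count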

@haoyangli-amd
Contributor Author

Hi @ilmarkov,
could you please help review this PR? Thank you so much.


num = 1
s1 = 1024
while num < 50000: # 50000 is sufficient to identify issues.
Contributor

How long does the test take?

Contributor

@mgoin Do we have CI for MI300? Should we add tests/distributed/test_quick_all_reduce.py to test-pipeline?

Contributor Author

How long does the test take?

About 23 seconds for TP4 and TP8.

haoyangli-amd and others added 4 commits October 13, 2025 13:53
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 13, 2025
@vllm-bot vllm-bot merged commit 134f70b into vllm-project:main Oct 13, 2025
82 of 85 checks passed
1994 pushed a commit to 1994/vllm that referenced this pull request Oct 14, 2025
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Nov 12, 2025