Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Distributed] use cpu group to broadcast metadata in cpu #4444

Merged
merged 11 commits into from
Apr 29, 2024
12 changes: 9 additions & 3 deletions vllm/distributed/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import torch
from torch.distributed import ProcessGroup

from .parallel_state import (get_tensor_model_parallel_group,
from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)
Expand Down Expand Up @@ -146,6 +147,11 @@ def broadcast_tensor_dict(
group: Optional[ProcessGroup] = None,
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary."""
if group is None:
group = torch.distributed.group.WORLD
cpu_group = get_cpu_world_group()
else:
cpu_group = group
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
Expand All @@ -172,7 +178,7 @@ def broadcast_tensor_dict(
metadata_list.append((key, value))
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=group)
group=cpu_group)
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
Expand All @@ -189,7 +195,7 @@ def broadcast_tensor_dict(
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
group=cpu_group)
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
Expand Down
Loading