[Core][Distributed] add fast broadcast for tensor dict #4757
Changes from 27 commits
@@ -1,8 +1,9 @@
 from collections import namedtuple
+import pickle
 from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

 import torch
+import torch.distributed as dist
 from torch.distributed import ProcessGroup

 from .parallel_state import (get_cpu_world_group,
@@ -180,44 +181,121 @@ def broadcast_object_list(obj_list: List[Any],
     return obj_list


-TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
-
-
 def _split_tensor_dict(
-    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    tensor_dict: Dict[str, Union[torch.Tensor, Any]],
+    keys: Optional[List[str]] = None,
+) -> Tuple[List[Any], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
-       by its metadata.
+       by its metadata. If keys are provided, only return the value.
     2. A list of tensors.
+
+    `keys` is used to specify the keys to be included in the metadata list,
+    which can make sure the order of the metadata list is consistent across
+    different ranks.
     """
+    from vllm import TensorMeta  # import here to avoid circular import
     metadata_list = []
     tensor_list = []
-    for key, value in tensor_dict.items():
+    used_keys = keys or tensor_dict.keys()
+    for key in used_keys:

Review comment: Should we assert keys == len(tensor_dict)?
Author reply: Not necessary though. The current code is more flexible without the assert.

+        value = tensor_dict[key]
         if isinstance(value, torch.Tensor):
             # Note: we cannot use `value.device` here,
             # because it contains not only the device type but also the
             # device index (e.g. "cuda:0"). We only need the device type;
             # the receiving side will set the device index.
             device = "cpu" if value.is_cpu else "cuda"
             metadata_list.append(
-                (key, TensorMetadata(device, value.dtype, value.size())))
+                (key, TensorMeta(device, value.dtype, value.size())))
             tensor_list.append(value)
         else:
             metadata_list.append((key, value))
+    if keys is not None:
+        metadata_list = [value for key, value in metadata_list]

Review comment: nit; why don't we just check it in line 213?
Author reply: I think control flow in the loop is more expensive (N control-flow checks) than control flow outside of the loop (1 check).
Review comment: Hmm, actually I tried and it looks like the control flow in the loop is faster. But the perf diff here is not very meaningful (it is premature optimization). I was asking because I thought it is easier to understand, but no strong opinion. I will leave it up to you.
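
To make the trade-off concrete, a rough micro-benchmark sketch (not part of the PR; names and sizes are made up) comparing the two variants under discussion:

import timeit

tensor_dict = {f"key{i}": i for i in range(64)}
keys = list(tensor_dict)

def strip_after_loop():
    # the PR's variant: always build (key, value) pairs,
    # then drop the keys in one pass after the loop
    metadata_list = [(k, tensor_dict[k]) for k in keys]
    return [v for _, v in metadata_list]

def branch_in_loop():
    # the reviewer's variant: decide per iteration what to append
    metadata_list = []
    for k in keys:
        v = tensor_dict[k]
        metadata_list.append(v if keys is not None else (k, v))
    return metadata_list

print(timeit.timeit(strip_after_loop, number=10_000))
print(timeit.timeit(branch_in_loop, number=10_000))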
     return metadata_list, tensor_list
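
For intuition, a simplified, self-contained copy of the split logic (using a local stand-in for vllm's TensorMeta) shows what the two return shapes look like:

from collections import namedtuple

import torch

TensorMeta = namedtuple("TensorMeta", ["device", "dtype", "size"])

def split_tensor_dict(tensor_dict, keys=None):
    # simplified standalone copy of the PR's _split_tensor_dict
    metadata_list, tensor_list = [], []
    for key in (keys or tensor_dict.keys()):
        value = tensor_dict[key]
        if isinstance(value, torch.Tensor):
            device = "cpu" if value.is_cpu else "cuda"
            metadata_list.append(
                (key, TensorMeta(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    if keys is not None:
        # keys are implied by the agreed order, so keep only the values
        metadata_list = [value for key, value in metadata_list]
    return metadata_list, tensor_list

d = {"x": torch.zeros(2, 3), "flag": True}
print(split_tensor_dict(d))
# [('x', TensorMeta(device='cpu', dtype=torch.float32, size=torch.Size([2, 3]))), ('flag', True)]
print(split_tensor_dict(d, keys=["x", "flag"]))
# [TensorMeta(...), True]  -- values only, in the agreed key order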


+class TensorDictWithBoundedMetadata:
+    """
+    In the normal case, when we broadcast Python objects, we need two
+    collective operations: one to broadcast the length of the object after
+    serialization, and one to broadcast the serialized object.
+
+    This class represents a dictionary of tensors with bounded metadata.
+    The upper bound of the buffer size is known a priori. Therefore, we can
+    pre-allocate a buffer for the metadata, and invoke only one collective
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this correct because we are now broadcasting using cpu "tensor", we don't need to broadcast the object size (which is the implementation detail of broadcast_object_list)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The key idea is not cpu "tensor", but we know the maximum size of the serialization, so we don't need to broadcast the length. This is indeed an implementation detail of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add this to the comment that it relies on that implementation detail? |
+
+    The main benefit is we can save one broadcast call.
+
+    Note: it depends on the feature of Python pickle that the serialized
+    data contains a marker for the end of the data. Therefore, as long as
+    the buffer size is larger than the serialized data, we can guarantee
+    the deserialization is correct.
+    """
+
+    @classmethod
+    def get_max_buffer_size_for_metadata(cls):
+        metadata_list = cls.get_example_metadata_list()
+        # Note: we only need the values of the metadata list.
+        values = [value for key, value in metadata_list]
+        metadata_list_bytes = pickle.dumps(values)
+        ALIGN_BYTES = 128
+        return ((len(metadata_list_bytes) + ALIGN_BYTES - 1) //
+                ALIGN_BYTES) * ALIGN_BYTES
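
The return statement is the usual round-up-to-a-multiple idiom; for example, a 300-byte serialized payload pads to 384 bytes:

ALIGN_BYTES = 128
n = 300  # serialized size in bytes
padded = ((n + ALIGN_BYTES - 1) // ALIGN_BYTES) * ALIGN_BYTES
assert padded == 384  # 3 * 128, the smallest multiple of 128 >= 300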
+
+    # ===== subclass overrides start =====
+    # A subclass should implement the `__init__` method, set the `fields`
+    # attribute to a list of field names, and implement the
+    # `get_example_metadata_list` class method to provide example metadata
+    # for the fields. This is used to calculate the buffer size.
+    fields: List[str]
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_example_metadata_list(cls):
+        # Note: in general, if the example data contains a cuda tensor,
+        # use a cpu tensor here instead, to avoid creating a cuda context
+        # during the initialization of the class. The estimation of the
+        # buffer size might be inaccurate (by one byte per field), but that
+        # is fine because the buffer size is aligned to 128 bytes.
+        return {}
+
+    # ===== subclass overrides end =====
+    # for type annotation
+    size_upper_bound: int
+    buffer: bytearray
+    buffer_tensor: torch.Tensor
+
+    def __init_subclass__(subclass):
+        assert hasattr(subclass, "fields"), (
+            f"Expecting a `fields` attribute in the subclass {subclass}")
+        subclass.size_upper_bound = subclass.get_max_buffer_size_for_metadata()
+        subclass.buffer = bytearray(subclass.size_upper_bound)
+        subclass.buffer_tensor = torch.frombuffer(memoryview(subclass.buffer),
+                                                  dtype=torch.uint8)
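
`__init_subclass__` runs once per subclass definition and allocates the shared buffer; `torch.frombuffer` returns a tensor view that shares memory with the bytearray, which is what lets the receiving side later `pickle.loads` straight from `cls.buffer` after broadcasting into `cls.buffer_tensor`. A minimal sketch (not from the PR) of that zero-copy sharing:

import pickle

import torch

buffer = bytearray(64)
buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8)

# writing into the tensor (as dist.broadcast would) is visible
# through the original bytearray, with no copy
payload = pickle.dumps("hello")
buffer_tensor[:len(payload)].copy_(
    torch.frombuffer(payload, dtype=torch.uint8))
assert pickle.loads(memoryview(buffer)) == "hello"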
+
+
+T = TypeVar("T", bound=TensorDictWithBoundedMetadata)


 def broadcast_tensor_dict(
     tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
     src: int = 0,
     group: Optional[ProcessGroup] = None,
-    metadata_group: Optional[ProcessGroup] = None
-) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    metadata_group: Optional[ProcessGroup] = None,
+    cls: Optional[Type[T]] = None,
+) -> Union[Dict[Any, Union[torch.Tensor, Any]], T]:
     """Broadcast the input tensor dictionary.
     `group` is used to broadcast the tensors, while `metadata_group` is used
     to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
-    dtypes).
+    dtypes). If `cls` is provided, the length of the serialized metadata is
+    bounded and a buffer can be pre-allocated for it, so broadcasting the
+    metadata requires only one broadcast call. Otherwise, we need to
+    broadcast the metadata length first, then the metadata itself.
     """

Review comment: Can you add a simple example of how to use TensorDictWithBoundedMetadata in the docstring?
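
A sketch of what such an example might look like (hypothetical subclass and import paths, assuming the PR's module layout; not code from the PR):

import torch

from vllm.distributed.communication_op import (
    TensorDictWithBoundedMetadata, _split_tensor_dict, broadcast_tensor_dict)

class SamplingData(TensorDictWithBoundedMetadata):
    # hypothetical subclass: every rank agrees on the field names ahead
    # of time, so only the values need to travel over the wire
    fields = ["temperatures", "top_p"]

    def __init__(self, temperatures, top_p):
        self.temperatures = temperatures
        self.top_p = top_p

    @classmethod
    def get_example_metadata_list(cls):
        # representative values, used only to size the metadata buffer;
        # cpu tensors avoid creating a cuda context at class-definition time
        example = {
            "temperatures": torch.zeros(256),
            "top_p": torch.zeros(256),
        }
        metadata_list, _ = _split_tensor_dict(example)
        return metadata_list

# source rank:
#   broadcast_tensor_dict({"temperatures": t, "top_p": p},
#                         src=0, cls=SamplingData)
# other ranks get a SamplingData instance back:
#   data = broadcast_tensor_dict(src=0, cls=SamplingData)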
     group = group or torch.distributed.group.WORLD
     metadata_group = metadata_group or get_cpu_world_group()

@@ -227,52 +305,70 @@ def broadcast_tensor_dict(
     # Bypass the function if we are using only 1 GPU.
     world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         assert tensor_dict is not None
         return tensor_dict

+    from vllm import TensorMeta  # import here to avoid circular import
+
     rank = torch.distributed.get_rank()
     if rank == src:
-        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict,
             dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
-        # `metadata_list` lives in CPU memory.
-        # `broadcast_object_list` involves serialization and deserialization,
-        # all happening on CPU. Therefore, we can use the CPU group.
-        torch.distributed.broadcast_object_list([metadata_list],
-                                                src=src,
-                                                group=metadata_group)
+        if cls is not None:
+            metadata_list, tensor_list = _split_tensor_dict(tensor_dict,
+                                                            keys=cls.fields)
+            s = pickle.dumps(metadata_list)
+            cls.buffer_tensor[:len(s)].copy_(
+                torch.frombuffer(s, dtype=torch.uint8))
+            dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group)
+        else:
+            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+            # `metadata_list` lives in CPU memory.
+            # `broadcast_object_list` involves serialization and
+            # deserialization, all happening on CPU. Therefore,
+            # we can use the CPU group.
+            dist.broadcast_object_list([metadata_list],
+                                       src=src,
+                                       group=metadata_group)
         async_handles = []
         for tensor in tensor_list:
             if tensor.numel() == 0:
                 # Skip broadcasting empty tensors.
                 continue
             if tensor.is_cpu:
                 # use metadata_group for CPU tensors
-                handle = torch.distributed.broadcast(tensor,
-                                                     src=src,
-                                                     group=metadata_group,
-                                                     async_op=True)
+                handle = dist.broadcast(tensor,
+                                        src=src,
+                                        group=metadata_group,
+                                        async_op=True)
             else:
                 # use group for GPU tensors
-                handle = torch.distributed.broadcast(tensor,
-                                                     src=src,
-                                                     group=group,
-                                                     async_op=True)
+                handle = dist.broadcast(tensor,
+                                        src=src,
+                                        group=group,
+                                        async_op=True)
             async_handles.append(handle)
         for async_handle in async_handles:
             async_handle.wait()

     else:
-        recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list,
-                                                src=src,
-                                                group=metadata_group)
-        assert recv_metadata_list[0] is not None
+        if cls is None:
+            container = [None]
+            dist.broadcast_object_list(container,
+                                       src=src,
+                                       group=metadata_group)
+            recv_metadata_list = container[0]
+            assert recv_metadata_list is not None
+        else:
+            dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group)
+            recv_value_list = pickle.loads(memoryview(cls.buffer))
+            recv_metadata_list = list(zip(cls.fields, recv_value_list))
         tensor_dict = {}
         async_handles = []
-        for key, value in recv_metadata_list[0]:
-            if isinstance(value, TensorMetadata):
+        for key, value in recv_metadata_list:
+            if isinstance(value, TensorMeta):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
                                      device=value.device)
@@ -282,20 +378,22 @@ def broadcast_tensor_dict(
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
-                    handle = torch.distributed.broadcast(tensor,
-                                                         src=src,
-                                                         group=metadata_group,
-                                                         async_op=True)
+                    handle = dist.broadcast(tensor,
+                                            src=src,
+                                            group=metadata_group,
+                                            async_op=True)
                 else:
                     # use group for GPU tensors
-                    handle = torch.distributed.broadcast(tensor,
-                                                         src=src,
-                                                         group=group,
-                                                         async_op=True)
+                    handle = dist.broadcast(tensor,
+                                            src=src,
+                                            group=group,
+                                            async_op=True)
                 async_handles.append(handle)
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
         for async_handle in async_handles:
             async_handle.wait()
+    if cls is not None:
+        return cls(**tensor_dict)
     return tensor_dict
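
On both paths, the tensor broadcasts are launched with `async_op=True` and waited on together, so CPU-group and GPU-group transfers can overlap rather than run back to back. A self-contained sketch of the launch-then-wait pattern (single-process gloo group on an arbitrary local port, just to make it runnable):

import torch
import torch.distributed as dist

# single-process group so the snippet runs standalone
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        world_size=1, rank=0)

tensors = [torch.arange(4), torch.ones(8)]
handles = [dist.broadcast(t, src=0, async_op=True) for t in tensors]
# all broadcasts are now in flight; block only once at the end
for h in handles:
    h.wait()
dist.destroy_process_group()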

Review comment: What about we just create vllm/tensor_meta.py? Is this still long?
Author reply: Yes, vllm/tensor_meta.py will lead to vllm.tensor_meta.TensorMeta, which is longer than vllm.TensorMeta.
Review comment: Does tensor_meta make a big difference? Feels like if it is just a small difference (like two-digit microseconds), I would prefer to avoid it...