
[Core][Distributed] add fast broadcast for tensor dict #4757

Closed · wants to merge 31 commits

Changes from 27 commits (31 commits total)
70bd52f
add FastBroadcastTensorDict
youkaichao May 11, 2024
4807171
add subclass init
youkaichao May 11, 2024
fa717a4
add tests
youkaichao May 11, 2024
dced974
rm new
youkaichao May 11, 2024
ff95abf
update tests
youkaichao May 11, 2024
88b9e28
use FastBroadcastTensorDict in worker
youkaichao May 11, 2024
ce74136
add get_example_data
youkaichao May 11, 2024
a450b57
update tests
youkaichao May 11, 2024
94e2d70
add get_example_data in worker
youkaichao May 11, 2024
af22c7f
fix init
youkaichao May 11, 2024
0bea4ae
add get_example_metadata_list
youkaichao May 11, 2024
b060e0d
Merge branch 'main' into fbtd
youkaichao May 11, 2024
3e6ee16
rename to TensorDictWithBoundedMetadata
youkaichao May 11, 2024
a8d1d3a
use vllm.TensorMeta
youkaichao May 11, 2024
ed24009
avoid circular import
youkaichao May 11, 2024
1f9a910
add comments
youkaichao May 11, 2024
0467afb
use class attributes
youkaichao May 11, 2024
ee60d78
no need to broadcast keys
youkaichao May 11, 2024
6969587
fix key
youkaichao May 11, 2024
8158667
fix buffer size calculation
youkaichao May 11, 2024
52a59ec
use smaller align bytes
youkaichao May 12, 2024
89bd1ec
assert torch dtype initialized
youkaichao May 12, 2024
d0a43ef
fix torch dtype
youkaichao May 12, 2024
62ac962
use str
youkaichao May 12, 2024
5469842
Merge branch 'main' into fbtd
youkaichao May 13, 2024
59f094a
fix merge
youkaichao May 13, 2024
5c9f0e9
update tests
youkaichao May 13, 2024
7f8ce07
type annotation for get_example_metadata_list
youkaichao May 13, 2024
4ae2b3a
type annotation for get_example_metadata_list
youkaichao May 13, 2024
8de5ba2
add assert
youkaichao May 13, 2024
70c3664
fix name error
youkaichao May 13, 2024
56 changes: 54 additions & 2 deletions tests/distributed/test_comm_ops.py
@@ -8,9 +8,11 @@
import ray
import torch

from vllm import TensorMeta
from vllm.distributed import (broadcast_tensor_dict,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.communication_op import TensorDictWithBoundedMetadata
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)

@@ -104,12 +106,62 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert torch.allclose(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
def fast_broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int,
rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
del os.environ["CUDA_VISIBLE_DEVICES"]

# Note: it is important to define the custom data class in the worker
# the class definition might initialize torch/cuda, and might read
# environment variables CUDA_VISIBLE_DEVICES
class CustomData(TensorDictWithBoundedMetadata):

def __init__(self, a, b):
self.a = a
self.b = b

fields = ["a", "b"]

@classmethod
def get_example_metadata_list(cls):
return [
("a", TensorMeta("cuda", torch.float32, torch.Size([]))),
("b", TensorMeta("cpu", torch.float32, torch.Size([]))),
]

device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

test_dict = {
# device tensor
"a": torch.arange(0, dtype=torch.float32, device="cuda"),
# CPU tensor
"b": torch.arange(0, dtype=torch.int8, device="cpu"),
}
obj = CustomData(**test_dict)
if rank == 0:
broadcast_tensor_dict(obj.__dict__, src=0, cls=CustomData)
else:
obj = broadcast_tensor_dict(src=0, cls=CustomData)
assert len(obj.__dict__) == len(test_dict)
assert torch.allclose(obj.a, test_dict["a"])
assert torch.allclose(obj.b, test_dict["b"])


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
all_reduce_test_worker,
all_gather_test_worker,
broadcast_tensor_dict_test_worker,
fast_broadcast_tensor_dict_test_worker,
])
def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_tensor_parallel(tp_size, 1, test_target)
34 changes: 34 additions & 0 deletions vllm/__init__.py
@@ -1,5 +1,9 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

import dataclasses

import torch

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
@@ -11,6 +15,35 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams


@dataclasses.dataclass
class TensorMeta:
"""
This class is placed here to reduce the size of the qualified name,
which will be used in pickle serialization.
"""

Collaborator:
what about we just create vllm/tensor_meta.py? Is this still long?

Member Author:
Yes, vllm/tensor_meta.py will lead to vllm.tensor_meta.TensorMeta, which is longer than vllm.TensorMeta.

Collaborator:
does tensor_meta make a big difference? Feels like if it is just a small difference (double-digit microseconds), I would prefer to avoid it...

device: str
dtype: torch.dtype
size: torch.Size

# use string to avoid torch lazy import issues
# sometimes `torch.int8` is not available at bootstrapping time
torch_dtypes = [
"torch.int8", "torch.int16", "torch.int32", "torch.int64",
"torch.uint8", "torch.uint16", "torch.uint32", "torch.uint64",
"torch.float16", "torch.float32", "torch.float64", "torch.bfloat16"
]
dtype_map = {dtype: i for i, dtype in enumerate(torch_dtypes)}

def __getstate__(self):
return [self.device, self.dtype_map[str(self.dtype)], tuple(self.size)]

def __setstate__(self, state):
self.device = state[0]
self.dtype = eval(self.torch_dtypes[state[1]])
self.size = torch.Size(state[2])


__version__ = "0.4.2"

__all__ = [
@@ -27,4 +60,5 @@
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
"TensorMeta",
]
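
A quick illustration of why `__getstate__`/`__setstate__` are customized above: the dtype is pickled as a small integer index and the size as a plain tuple, so the serialized form stays small and easy to bound. The sketch below uses a local copy of the class so it runs without this branch of vLLM installed; the in-tree class is vllm.TensorMeta.

# Standalone sketch mirroring the TensorMeta serialization logic above.
import dataclasses
import pickle

import torch


@dataclasses.dataclass
class TensorMeta:
    device: str
    dtype: torch.dtype
    size: torch.Size

    # shortened dtype table, for illustration only
    torch_dtypes = [
        "torch.int8", "torch.int64", "torch.float16", "torch.float32",
        "torch.float64", "torch.bfloat16"
    ]
    dtype_map = {dtype: i for i, dtype in enumerate(torch_dtypes)}

    def __getstate__(self):
        # dtype becomes a small integer, size becomes a plain tuple
        return [self.device, self.dtype_map[str(self.dtype)], tuple(self.size)]

    def __setstate__(self, state):
        self.device = state[0]
        self.dtype = eval(self.torch_dtypes[state[1]])
        self.size = torch.Size(state[2])


meta = TensorMeta("cuda", torch.float32, torch.Size([4, 8]))
payload = pickle.dumps(meta)
restored = pickle.loads(payload)
assert (restored.device, restored.dtype, restored.size) == \
    ("cuda", torch.float32, torch.Size([4, 8]))
print(len(payload))  # small, because no torch.dtype object is embedded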
184 changes: 141 additions & 43 deletions vllm/distributed/communication_op.py
@@ -1,8 +1,9 @@
from collections import namedtuple
import pickle
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from .parallel_state import (get_cpu_world_group,
@@ -180,44 +181,121 @@ def broadcast_object_list(obj_list: List[Any],
return obj_list


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
keys: Optional[List[str]] = None,
) -> Tuple[List[Any], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
by its metadata. If keys are provided, only return the value.
2. A list of tensors.

`keys` is used to specify which keys to include in the metadata list,
which ensures that the order of the metadata list is consistent across
different ranks.
"""
from vllm import TensorMeta # import here to avoid circular import
metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
used_keys = keys or tensor_dict.keys()
for key in used_keys:
Collaborator:
Should we assert len(keys) == len(tensor_dict)?

Member Author:
Not necessary though. The current code is more flexible without the assert.

value = tensor_dict[key]
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = "cpu" if value.is_cpu else "cuda"
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
(key, TensorMeta(device, value.dtype, value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
if keys is not None:
metadata_list = [value for key, value in metadata_list]
Collaborator:
nit: why don't we just check it in line 213?

if keys is not None:
    metadata_list.append((key, value))
else:
    metadata_list.append(value)

Member Author:
I think control flow in the loop is more expensive (N control-flow checks) than control flow outside of the loop (1 check).

Collaborator:
In [12]: def control():
    ...:     b = True
    ...:     result = []
    ...:     a = [i for i in range(3000)]
    ...:     for i in a:
    ...:         if b:
    ...:             result.append((i, i))
    ...:         else:
    ...:             result.append(i)

In [16]: def copy():
    ...:     b = True
    ...:     result = []
    ...:     a = [i for i in range(3000)]
    ...:     for i in a:
    ...:         result.append((i, i))
    ...:     result = [value for key, value in result]

In [22]: timeit copy()
192 µs ± 686 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [23]: timeit control()
159 µs ± 487 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Hmm, actually I tried it and the control-flow version is faster. But the perf difference here is not very meaningful (it is premature optimization). I was asking because I thought it is easier to understand, but no strong opinion. I will leave it up to you.


return metadata_list, tensor_list
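
To make the split concrete, here is a small usage sketch of _split_tensor_dict (the dict contents are made up; the import path assumes this PR's branch):

# Illustrative only; mirrors the behavior of the function defined above.
import torch

from vllm.distributed.communication_op import _split_tensor_dict

tensor_dict = {
    "hidden": torch.zeros(4, 8),   # tensor -> replaced by its TensorMeta
    "num_seqs": 4,                 # plain object -> kept as-is
}

# Without `keys`: (key, value) pairs plus the raw tensors.
# metadata_list == [("hidden", TensorMeta("cpu", torch.float32, torch.Size([4, 8]))),
#                   ("num_seqs", 4)]
# tensor_list   == [tensor_dict["hidden"]]
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)

# With `keys`: the key names are dropped, since both sides already agree on
# the field order (cls.fields) and only the values need to be serialized.
# metadata_list == [TensorMeta("cpu", torch.float32, torch.Size([4, 8])), 4]
metadata_list, tensor_list = _split_tensor_dict(tensor_dict,
                                                keys=["hidden", "num_seqs"])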


class TensorDictWithBoundedMetadata:
"""
In the normal case, when we broadcast Python objects, we need two
collective operations: one to broadcast the length of the object after
serialization, and one to broadcast the serialized object.

This class represents a dictionary of tensors with bounded metadata.
The upperbound of the buffer size is known a priori. Therefore, we can
pre-allocate a buffer for the metadata, and invoke only one collective
operation to broadcast the metadata.
Collaborator:
Is this correct because we are now broadcasting using a CPU "tensor", so we don't need to broadcast the object size (which is an implementation detail of broadcast_object_list)?

Member Author:
The key idea is not the CPU "tensor", but that we know the maximum size of the serialization, so we don't need to broadcast the length. This is indeed an implementation detail of broadcast_object_list.

Collaborator:
Can you add a note to the comment that it relies on that implementation detail?


The main benefit is we can save one broadcast call.

Note: it depends on the feature of Python pickle that the serialized
data contains a marker for the end of the data. Therefore, as long as
the buffer size is larger than the serialized data, we can guarantee
the deserialization is correct.
"""

@classmethod
def get_max_buffer_size_for_metadata(cls):
metadata_list = cls.get_example_metadata_list()
# Note: we only need the values of the metadata list.
values = [value for key, value in metadata_list]
metadata_list_bytes = pickle.dumps(values)
ALIGN_BYTES = 128
return ((len(metadata_list_bytes) + ALIGN_BYTES - 1) //
ALIGN_BYTES) * ALIGN_BYTES

# ===== subclass overrides start =====
# A subclass should implement the `__init__` method, set the `fields`
# attribute to a list of field names, and implement the
# `get_example_metadata_list` class method to provide example metadata
# for the fields. This is used to calculate the buffer size.
fields: List[str]

def __init__(self):
pass

@classmethod
def get_example_metadata_list(cls):
# Note: in general, if the example data would contain a CUDA tensor,
# use a CPU tensor here instead, to avoid creating a CUDA context during
# the initialization of the class. The buffer-size estimate may then be
# slightly off (by about one byte per field), but that is fine because
# the buffer size is rounded up to a multiple of ALIGN_BYTES anyway.
return {}

# ===== subclass overrides ends =====
# for type annotation
size_upper_bound: int
buffer: bytearray
buffer_tensor: torch.Tensor

def __init_subclass__(subclass):
assert hasattr(subclass, "fields"), (
f"Expecting a `fields` attribute in the subclass {subclass}")
subclass.size_upper_bound = subclass.get_max_buffer_size_for_metadata()
subclass.buffer = bytearray(subclass.size_upper_bound)
subclass.buffer_tensor = torch.frombuffer(memoryview(subclass.buffer),
dtype=torch.uint8)
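
A sketch of what __init_subclass__ precomputes for a concrete subclass, and of the pickle property the class docstring relies on. The SamplingData class and its fields are hypothetical; the imports assume this PR's branch.

import pickle

import torch

from vllm import TensorMeta
from vllm.distributed.communication_op import TensorDictWithBoundedMetadata


class SamplingData(TensorDictWithBoundedMetadata):
    # hypothetical fields, analogous to CustomData in the tests
    fields = ["tokens", "temperature"]

    def __init__(self, tokens, temperature):
        self.tokens = tokens
        self.temperature = temperature

    @classmethod
    def get_example_metadata_list(cls):
        # CPU example tensors are enough to bound the pickled metadata size
        # without creating a CUDA context at class-definition time.
        return [
            ("tokens", TensorMeta("cpu", torch.int64, torch.Size([]))),
            ("temperature", TensorMeta("cpu", torch.float32, torch.Size([]))),
        ]


# __init_subclass__ has already run at class-definition time: e.g. a
# 137-byte example pickle rounds up to 256 with ALIGN_BYTES = 128.
print(SamplingData.size_upper_bound)     # a multiple of 128
print(len(SamplingData.buffer))          # == size_upper_bound
print(SamplingData.buffer_tensor.dtype)  # torch.uint8, viewing the same buffer

# Why a fixed, over-allocated buffer is safe to deserialize: pickle stops at
# its own end marker, so trailing zero bytes are simply ignored.
payload = pickle.dumps(["anything", 123])
buf = bytearray(1024)
buf[:len(payload)] = payload
assert pickle.loads(memoryview(buf)) == ["anything", 123]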


T = TypeVar("T", bound=TensorDictWithBoundedMetadata)


def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
metadata_group: Optional[ProcessGroup] = None,
cls: Optional[Type[T]] = None,
) -> Union[Dict[Any, Union[torch.Tensor, Any]], T]:
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
dtypes). If `cls` is provided, we know an upper bound on the size of the
serialized metadata and can pre-allocate a buffer for it, so broadcasting
the metadata requires only one broadcast call. Otherwise, we need to
broadcast the metadata length first, and then the metadata itself.
"""
Collaborator:
Can you add a simple example of how to use TensorDictWithBoundedMetadata in the docstring?
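
For reference, a usage sketch along the lines of this request, modeled on the test in tests/distributed/test_comm_ops.py (CustomData and the tensor values are hypothetical; only the call pattern comes from this PR):

# CustomData is a TensorDictWithBoundedMetadata subclass with fields ["a", "b"].
obj = CustomData(a=torch.ones(8, device="cuda"), b=torch.zeros(8))

if torch.distributed.get_rank() == 0:
    # source rank: broadcast the object's fields
    broadcast_tensor_dict(obj.__dict__, src=0, cls=CustomData)
else:
    # other ranks: receive the fields and reconstruct a CustomData instance
    obj = broadcast_tensor_dict(src=0, cls=CustomData)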

group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
@@ -227,52 +305,70 @@ def broadcast_tensor_dict(
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
assert tensor_dict is not None
return tensor_dict

from vllm import TensorMeta # import here to avoid circular import

rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=metadata_group)
if cls is not None:
metadata_list, tensor_list = _split_tensor_dict(tensor_dict,
keys=cls.fields)
s = pickle.dumps(metadata_list)
cls.buffer_tensor[:len(s)].copy_(
torch.frombuffer(s, dtype=torch.uint8))
dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group)
else:
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and
# deserialization, all happening on CPU. Therefore,
# we can use the CPU group.
dist.broadcast_object_list([metadata_list],
src=src,
group=metadata_group)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
handle = dist.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
handle = dist.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()

else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=metadata_group)
assert recv_metadata_list[0] is not None
if cls is None:
container = [None]
dist.broadcast_object_list(container,
src=src,
group=metadata_group)
recv_metadata_list = container[0]
assert recv_metadata_list is not None
else:
dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group)
recv_value_list = pickle.loads(memoryview(cls.buffer))
recv_metadata_list = list(zip(cls.fields, recv_value_list))
tensor_dict = {}
async_handles = []
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
for key, value in recv_metadata_list:
if isinstance(value, TensorMeta):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
@@ -282,20 +378,22 @@
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
handle = dist.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
handle = dist.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
if cls is not None:
return cls(**tensor_dict)
return tensor_dict
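
To summarize the metadata path that the `cls is not None` branches above implement, here is a single-process sketch of the sender/receiver serialization round trip (the broadcast itself is replaced by a plain in-memory copy; the vllm import assumes this PR's branch):

import pickle

import torch

from vllm import TensorMeta

fields = ["a", "b"]
tensor_dict = {"a": torch.ones(2, 3), "b": 42}

# --- sender side (rank == src) ---
values = [
    TensorMeta("cpu", v.dtype, v.size()) if isinstance(v, torch.Tensor) else v
    for v in (tensor_dict[k] for k in fields)
]
payload = pickle.dumps(values)
buffer = bytearray(256)  # in the real class, sized by get_max_buffer_size_for_metadata
buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8)
buffer_tensor[:len(payload)].copy_(torch.frombuffer(payload, dtype=torch.uint8))
# dist.broadcast(buffer_tensor, src=src, group=metadata_group) would go here.

# --- receiver side (rank != src) ---
recv_values = pickle.loads(memoryview(buffer))  # trailing bytes are ignored
recv_metadata_list = list(zip(fields, recv_values))
assert recv_metadata_list[0][1].size == torch.Size([2, 3])
assert recv_metadata_list[1] == ("b", 42)
# The actual tensors are then broadcast separately, as in the code above.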