From 70bd52f3a0a3cfddb539a4825766bbc26da18161 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 21:44:14 -0700 Subject: [PATCH 01/29] add FastBroadcastTensorDict --- vllm/distributed/communication_op.py | 122 ++++++++++++++++++++------- 1 file changed, 90 insertions(+), 32 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 32ab5694e53..6dde76c5e3b 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,8 +1,10 @@ +import pickle from collections import namedtuple from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import torch +import torch.distributed as dist from torch.distributed import ProcessGroup from .parallel_state import (get_cpu_world_group, @@ -185,16 +187,57 @@ def _split_tensor_dict( return metadata_list, tensor_list +class FastBroadcastTensorDict: + + @staticmethod + def get_max_buffer_size_for_metadata(fields: List[str]): + metadata_list = [(f, + TensorMetadata("cuda", torch.float32, + torch.Size((1, 2, 3, 4, 5)))) + for f in fields] + metadata_list_bytes = pickle.dumps(metadata_list) + ALIGN_BYTES = 256 + return ((len(metadata_list_bytes) + ALIGN_BYTES - 1) // + ALIGN_BYTES) * ALIGN_BYTES + + # ===== subclass overrides starts ===== + # subclass should implement the `__init__` method, and set the `fields` + # attribute to a list of field names. Then repeat the following code + # snippet in the subclass to set the buffer size and buffer tensor. + def __init__(self): + pass + + fields: List[str] = [] + size_upper_bound = get_max_buffer_size_for_metadata(fields) + buffer = bytearray(size_upper_bound) + buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8) + + # ===== subclass overrides ends ===== + + @staticmethod + def __new__(cls, tensor_dict: Dict[str, torch.Tensor]): + obj = object.__new__(cls) + obj.__dict__.update(tensor_dict) + return obj + + +T = TypeVar("T", bound=FastBroadcastTensorDict) + + def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None -) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + metadata_group: Optional[ProcessGroup] = None, + cls: Optional[Type[T]] = None, +) -> Union[Dict[Any, Union[torch.Tensor, Any]], FastBroadcastTensorDict]: """Broadcast the input tensor dictionary. `group` is used to broadcast the tensors, while `metadata_group` is used to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, - dtypes). + dtypes). If `cls` is provided, we can know the length of the metadata + roughly and allocate a buffer for it, then broadcasting metadata requires + only one broadcast call. Otherwise, we need to broadcast the metadata + length first, then broadcast the metadata. """ group = group or torch.distributed.group.WORLD metadata_group = metadata_group or get_cpu_world_group() @@ -204,6 +247,7 @@ def broadcast_tensor_dict( # Bypass the function if we are using only 1 GPU. world_size = torch.distributed.get_world_size(group=group) if world_size == 1: + assert tensor_dict is not None return tensor_dict rank = torch.distributed.get_rank() @@ -213,12 +257,19 @@ def broadcast_tensor_dict( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `broadcast_object_list` involves serialization and deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - torch.distributed.broadcast_object_list([metadata_list], - src=src, - group=metadata_group) + if cls is not None: + s = pickle.dumps(metadata_list) + cls.buffer_tensor[:len(s)].copy_( + torch.frombuffer(s, dtype=torch.uint8)) + dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group) + else: + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` involves serialization and + # deserialization, all happening on CPU. Therefore, + # we can use the CPU group. + dist.broadcast_object_list([metadata_list], + src=src, + group=metadata_group) async_handles = [] for tensor in tensor_list: if tensor.numel() == 0: @@ -226,29 +277,34 @@ def broadcast_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=metadata_group, - async_op=True) + handle = dist.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) else: # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True) + handle = dist.broadcast(tensor, + src=src, + group=group, + async_op=True) async_handles.append(handle) for async_handle in async_handles: async_handle.wait() else: - recv_metadata_list = [None] - torch.distributed.broadcast_object_list(recv_metadata_list, - src=src, - group=metadata_group) - assert recv_metadata_list[0] is not None + if cls is None: + container = [None] + dist.broadcast_object_list(container, + src=src, + group=metadata_group) + recv_metadata_list = container[0] + assert recv_metadata_list is not None + else: + dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group) + recv_metadata_list = pickle.loads(memoryview(cls.buffer)) tensor_dict = {} async_handles = [] - for key, value in recv_metadata_list[0]: + for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, @@ -259,20 +315,22 @@ def broadcast_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=metadata_group, - async_op=True) + handle = dist.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) else: # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True) + handle = dist.broadcast(tensor, + src=src, + group=group, + async_op=True) async_handles.append(handle) tensor_dict[key] = tensor else: tensor_dict[key] = value for async_handle in async_handles: async_handle.wait() + if cls is not None: + return cls.__new__(cls, tensor_dict) return tensor_dict From 48071710cda926499c2c9624550d1966755b4dcc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 22:54:05 -0700 Subject: [PATCH 02/29] add subclass init --- vllm/distributed/communication_op.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 6dde76c5e3b..e60754a07d3 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -202,17 +202,26 @@ def get_max_buffer_size_for_metadata(fields: List[str]): # ===== subclass overrides starts ===== # subclass should implement the `__init__` method, and set the `fields` - # attribute to a list of field names. Then repeat the following code - # snippet in the subclass to set the buffer size and buffer tensor. + # attribute to a list of field names. def __init__(self): pass - fields: List[str] = [] - size_upper_bound = get_max_buffer_size_for_metadata(fields) - buffer = bytearray(size_upper_bound) - buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8) + fields: List[str] # ===== subclass overrides ends ===== + # for type annotation + size_upper_bound: int + buffer: bytearray + buffer_tensor: torch.Tensor + + def __init_subclass__(subclass): + assert hasattr(subclass, "fields"), ( + f"Expecting a `fields` attribute in the subclass {subclass}") + subclass.size_upper_bound = subclass.get_max_buffer_size_for_metadata( + subclass.fields) + subclass.buffer = bytearray(subclass.size_upper_bound) + subclass.buffer_tensor = torch.frombuffer(memoryview(subclass.buffer), + dtype=torch.uint8) @staticmethod def __new__(cls, tensor_dict: Dict[str, torch.Tensor]): From fa717a4c64e1ab2e477d99f50bab52dc1669437a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 22:59:50 -0700 Subject: [PATCH 03/29] add tests --- tests/distributed/test_comm_ops.py | 47 ++++++++++++++++++++++++++-- vllm/distributed/communication_op.py | 2 +- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 9a7a1f07e1b..deec3429981 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,6 +11,7 @@ from vllm.distributed import (broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import FastBroadcastTensorDict from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -104,12 +105,54 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, assert torch.allclose(recv_dict["f"], test_dict["f"]) +class CustomData(FastBroadcastTensorDict): + + def __init__(self, a, b): + self.a = a + self.b = b + + fields = ["a", "b"] + + +@ray.remote(num_gpus=1, max_calls=1) +def fast_broadcast_tensor_dict_test_worker(tensor_parallel_size: int, + rank: int, + distributed_init_port: str): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(1, tensor_parallel_size, rank, + distributed_init_port) + + test_dict = { + # device tensor + "a": torch.arange(0, dtype=torch.float32, device="cuda"), + # CPU tensor + "b": torch.arange(0, dtype=torch.int8, device="cpu"), + } + + if rank == 0: + obj = CustomData(**test_dict) + broadcast_tensor_dict(obj.__dict__, src=0, cls=CustomData) + else: + obj = broadcast_tensor_dict(src=0, cls=CustomData) + recv_dict = obj.__dict__ + assert len(recv_dict) == len(test_dict) + assert torch.allclose(recv_dict["a"], test_dict["a"]) + assert torch.allclose(recv_dict["b"], test_dict["b"]) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("tensor_parallel_size", [2]) @pytest.mark.parametrize("test_target", [ - all_reduce_test_worker, all_gather_test_worker, - broadcast_tensor_dict_test_worker + all_reduce_test_worker, + all_gather_test_worker, + broadcast_tensor_dict_test_worker, + fast_broadcast_tensor_dict_test_worker, ]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): multi_process_tensor_parallel(tensor_parallel_size, test_target) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index e60754a07d3..fd80faa169a 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -239,7 +239,7 @@ def broadcast_tensor_dict( group: Optional[ProcessGroup] = None, metadata_group: Optional[ProcessGroup] = None, cls: Optional[Type[T]] = None, -) -> Union[Dict[Any, Union[torch.Tensor, Any]], FastBroadcastTensorDict]: +) -> Union[Dict[Any, Union[torch.Tensor, Any]], T]: """Broadcast the input tensor dictionary. `group` is used to broadcast the tensors, while `metadata_group` is used to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, From dced974a2c9747bfe03f034d52c8c5b6bb0ceb10 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 23:07:36 -0700 Subject: [PATCH 04/29] rm new --- vllm/distributed/communication_op.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index fd80faa169a..47f98cf2045 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -223,12 +223,6 @@ def __init_subclass__(subclass): subclass.buffer_tensor = torch.frombuffer(memoryview(subclass.buffer), dtype=torch.uint8) - @staticmethod - def __new__(cls, tensor_dict: Dict[str, torch.Tensor]): - obj = object.__new__(cls) - obj.__dict__.update(tensor_dict) - return obj - T = TypeVar("T", bound=FastBroadcastTensorDict) @@ -341,5 +335,5 @@ def broadcast_tensor_dict( for async_handle in async_handles: async_handle.wait() if cls is not None: - return cls.__new__(cls, tensor_dict) + return cls(**tensor_dict) return tensor_dict From ff95abf4dd2e46ab356e356b69fd8407b674a067 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 23:12:52 -0700 Subject: [PATCH 05/29] update tests --- tests/distributed/test_comm_ops.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index deec3429981..25ef3b15374 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -133,16 +133,14 @@ def fast_broadcast_tensor_dict_test_worker(tensor_parallel_size: int, # CPU tensor "b": torch.arange(0, dtype=torch.int8, device="cpu"), } - + obj = CustomData(**test_dict) if rank == 0: - obj = CustomData(**test_dict) broadcast_tensor_dict(obj.__dict__, src=0, cls=CustomData) else: obj = broadcast_tensor_dict(src=0, cls=CustomData) - recv_dict = obj.__dict__ - assert len(recv_dict) == len(test_dict) - assert torch.allclose(recv_dict["a"], test_dict["a"]) - assert torch.allclose(recv_dict["b"], test_dict["b"]) + assert len(obj.__dict__) == len(test_dict) + assert torch.allclose(obj.a, test_dict["a"]) + assert torch.allclose(obj.b, test_dict["b"]) @pytest.mark.skipif(torch.cuda.device_count() < 2, From 88b9e2806264a61f68d5ebc27a2a986e0bce56ca Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 23:28:27 -0700 Subject: [PATCH 06/29] use FastBroadcastTensorDict in worker --- vllm/worker/worker.py | 45 ++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0ca9c2b64cf..297667fc47e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import List, Optional, Set, Tuple import torch import torch.distributed @@ -12,6 +12,7 @@ from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed.communication_op import FastBroadcastTensorDict from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest @@ -22,6 +23,25 @@ from vllm.worker.worker_base import WorkerBase +class BlockMetaData(FastBroadcastTensorDict): + """ + Use BlockMetaData to save one broadcasted in broadcast Python object. + """ + + def __init__(self, num_seq_groups: int, blocks_to_swap_in: torch.Tensor, + blocks_to_swap_out: torch.Tensor, + blocks_to_copy: torch.Tensor): + self.num_seq_groups = num_seq_groups + self.blocks_to_swap_in = blocks_to_swap_in + self.blocks_to_swap_out = blocks_to_swap_out + self.blocks_to_copy = blocks_to_copy + + fields = [ + "num_seq_groups", "blocks_to_swap_in", "blocks_to_swap_out", + "blocks_to_copy" + ] + + class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. @@ -219,6 +239,7 @@ def execute_model( blocks_to_swap_in: torch.Tensor blocks_to_swap_out: torch.Tensor blocks_to_copy: torch.Tensor + data: BlockMetaData if self.is_driver_worker: assert seq_group_metadata_list is not None assert execute_model_req is not None @@ -239,19 +260,17 @@ def execute_model( blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) + data = BlockMetaData(num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy) + broadcast_tensor_dict(data.__dict__, src=0, cls=BlockMetaData) else: - data = broadcast_tensor_dict(src=0) - num_seq_groups = data["num_seq_groups"] - blocks_to_swap_in = data["blocks_to_swap_in"] - blocks_to_swap_out = data["blocks_to_swap_out"] - blocks_to_copy = data["blocks_to_copy"] + data = broadcast_tensor_dict(src=0, cls=BlockMetaData) + num_seq_groups = data.num_seq_groups + blocks_to_swap_in = data.blocks_to_swap_in + blocks_to_swap_out = data.blocks_to_swap_out + blocks_to_copy = data.blocks_to_copy self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) From ce741361223028b51ca8b701aa93821ba44a71f2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 23:50:34 -0700 Subject: [PATCH 07/29] add get_example_data --- vllm/distributed/communication_op.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 47f98cf2045..03104beca86 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -189,12 +189,10 @@ def _split_tensor_dict( class FastBroadcastTensorDict: - @staticmethod - def get_max_buffer_size_for_metadata(fields: List[str]): - metadata_list = [(f, - TensorMetadata("cuda", torch.float32, - torch.Size((1, 2, 3, 4, 5)))) - for f in fields] + @classmethod + def get_max_buffer_size_for_metadata(cls): + data = cls.get_example_data() + metadata_list, _ = _split_tensor_dict(data) metadata_list_bytes = pickle.dumps(metadata_list) ALIGN_BYTES = 256 return ((len(metadata_list_bytes) + ALIGN_BYTES - 1) // @@ -202,11 +200,22 @@ def get_max_buffer_size_for_metadata(fields: List[str]): # ===== subclass overrides starts ===== # subclass should implement the `__init__` method, and set the `fields` - # attribute to a list of field names. + # attribute to a list of field names, and implement the + # `get_example_metadata` class method to provide an example metadata for + # the fields. This is used to calculate the buffer size. + fields: List[str] + def __init__(self): pass - fields: List[str] + @classmethod + def get_example_data(cls): + # Note: in general, if the example data contains cuda tensor, + # use cpu tensor here to avoid creating cuda context during + # the initialization of the class. The estimation of the buffer size + # might be inaccurate (by one byte per field), but it is fine because + # the buffer size will be aligned to 256 bytes. + return {} # ===== subclass overrides ends ===== # for type annotation @@ -217,8 +226,7 @@ def __init__(self): def __init_subclass__(subclass): assert hasattr(subclass, "fields"), ( f"Expecting a `fields` attribute in the subclass {subclass}") - subclass.size_upper_bound = subclass.get_max_buffer_size_for_metadata( - subclass.fields) + subclass.size_upper_bound = subclass.get_max_buffer_size_for_metadata() subclass.buffer = bytearray(subclass.size_upper_bound) subclass.buffer_tensor = torch.frombuffer(memoryview(subclass.buffer), dtype=torch.uint8) From a450b57690b34eac325e05cd03234489fad697ba Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 23:51:17 -0700 Subject: [PATCH 08/29] update tests --- tests/distributed/test_comm_ops.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 25ef3b15374..03cf168844c 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -113,6 +113,13 @@ def __init__(self, a, b): fields = ["a", "b"] + @classmethod + def get_example_data(cls): + return { + "a": torch.tensor([], dtype=torch.float32, device="cpu"), + "b": torch.tensor([], dtype=torch.int8, device="cpu"), + } + @ray.remote(num_gpus=1, max_calls=1) def fast_broadcast_tensor_dict_test_worker(tensor_parallel_size: int, From 94e2d707e342710e78dcef85ff3553d1f1f1da0f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 23:53:03 -0700 Subject: [PATCH 09/29] add get_example_data in worker --- vllm/worker/worker.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 297667fc47e..32b9a7a294b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -41,6 +41,19 @@ def __init__(self, num_seq_groups: int, blocks_to_swap_in: torch.Tensor, "blocks_to_copy" ] + @classmethod + def get_example_data(cls): + return { + "num_seq_groups": + 0, + "blocks_to_swap_in": + torch.randn((3, 2), dtype=torch.int64, device="cpu"), + "blocks_to_swap_out": + torch.randn((3, 2), dtype=torch.int64, device="cpu"), + "blocks_to_copy": + torch.randn((3, 2), dtype=torch.int64, device="cpu"), + } + class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. From af22c7f8428d4b0c946df11858b3f05d5a54b476 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 23:56:09 -0700 Subject: [PATCH 10/29] fix init --- vllm/worker/worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 32b9a7a294b..eb7c65c3f64 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -47,11 +47,11 @@ def get_example_data(cls): "num_seq_groups": 0, "blocks_to_swap_in": - torch.randn((3, 2), dtype=torch.int64, device="cpu"), + torch.zeros((3, 2), dtype=torch.int64, device="cpu"), "blocks_to_swap_out": - torch.randn((3, 2), dtype=torch.int64, device="cpu"), + torch.zeros((3, 2), dtype=torch.int64, device="cpu"), "blocks_to_copy": - torch.randn((3, 2), dtype=torch.int64, device="cpu"), + torch.zeros((3, 2), dtype=torch.int64, device="cpu"), } From 0bea4aeac160723a0209cc3da49de571b4021e49 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 00:05:25 -0700 Subject: [PATCH 11/29] add get_example_metadata_list --- tests/distributed/test_comm_ops.py | 13 +++++++------ vllm/distributed/communication_op.py | 5 ++--- vllm/worker/worker.py | 24 ++++++++++++------------ 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 03cf168844c..a1e33aab9d2 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,7 +11,8 @@ from vllm.distributed import (broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.communication_op import FastBroadcastTensorDict +from vllm.distributed.communication_op import (FastBroadcastTensorDict, + TensorMetadata) from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -114,11 +115,11 @@ def __init__(self, a, b): fields = ["a", "b"] @classmethod - def get_example_data(cls): - return { - "a": torch.tensor([], dtype=torch.float32, device="cpu"), - "b": torch.tensor([], dtype=torch.int8, device="cpu"), - } + def get_example_metadata_list(cls): + return [ + ("a", TensorMetadata("cuda", torch.float32, torch.Size([]))), + ("b", TensorMetadata("cpu", torch.float32, torch.Size([]))), + ] @ray.remote(num_gpus=1, max_calls=1) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 03104beca86..1ec01100bf6 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -191,8 +191,7 @@ class FastBroadcastTensorDict: @classmethod def get_max_buffer_size_for_metadata(cls): - data = cls.get_example_data() - metadata_list, _ = _split_tensor_dict(data) + metadata_list = cls.get_example_metadata_list() metadata_list_bytes = pickle.dumps(metadata_list) ALIGN_BYTES = 256 return ((len(metadata_list_bytes) + ALIGN_BYTES - 1) // @@ -209,7 +208,7 @@ def __init__(self): pass @classmethod - def get_example_data(cls): + def get_example_metadata_list(cls): # Note: in general, if the example data contains cuda tensor, # use cpu tensor here to avoid creating cuda context during # the initialization of the class. The estimation of the buffer size diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index eb7c65c3f64..ddf98711065 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -12,7 +12,8 @@ from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.communication_op import FastBroadcastTensorDict +from vllm.distributed.communication_op import (FastBroadcastTensorDict, + TensorMetadata) from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest @@ -42,17 +43,16 @@ def __init__(self, num_seq_groups: int, blocks_to_swap_in: torch.Tensor, ] @classmethod - def get_example_data(cls): - return { - "num_seq_groups": - 0, - "blocks_to_swap_in": - torch.zeros((3, 2), dtype=torch.int64, device="cpu"), - "blocks_to_swap_out": - torch.zeros((3, 2), dtype=torch.int64, device="cpu"), - "blocks_to_copy": - torch.zeros((3, 2), dtype=torch.int64, device="cpu"), - } + def get_example_metadata_list(cls): + return [ + ("num_seq_groups", 1), + ("blocks_to_swap_in", + TensorMetadata("cpu", torch.int64, torch.Size([1, 2]))), + ("blocks_to_swap_out", + TensorMetadata("cpu", torch.int64, torch.Size([1, 2]))), + ("blocks_to_copy", + TensorMetadata("cuda", torch.int64, torch.Size([1, 2]))), + ] class Worker(WorkerBase): From 3e6ee16c97dc8566ada08e94c14451cd0227613b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 15:46:49 -0700 Subject: [PATCH 12/29] rename to TensorDictWithBoundedMetadata --- tests/distributed/test_comm_ops.py | 4 ++-- vllm/distributed/communication_op.py | 21 +++++++++++++++++++-- vllm/worker/worker.py | 4 ++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index a1e33aab9d2..a8c2fe0f370 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,7 +11,7 @@ from vllm.distributed import (broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.communication_op import (FastBroadcastTensorDict, +from vllm.distributed.communication_op import (TensorDictWithBoundedMetadata, TensorMetadata) from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -106,7 +106,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, assert torch.allclose(recv_dict["f"], test_dict["f"]) -class CustomData(FastBroadcastTensorDict): +class CustomData(TensorDictWithBoundedMetadata): def __init__(self, a, b): self.a = a diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 1ec01100bf6..bf10a08acc0 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -187,7 +187,24 @@ def _split_tensor_dict( return metadata_list, tensor_list -class FastBroadcastTensorDict: +class TensorDictWithBoundedMetadata: + """ + In the normal case, when we broadcast Python objects, we need two + collective operations: one to broadcast the length of the object after + serialization, and one to broadcast the serialized object. + + This class represents a dictionary of tensors with bounded metadata. + The upperbound of the buffer size is known a priori. Therefore, we can + pre-allocate a buffer for the metadata, and invoke only one collective + operation to broadcast the metadata. + + The main benefit is we can save one broadcast call. + + Note: it depends on the feature of Python pickle that the serialized + data contains a marker for the end of the data. Therefore, as long as + the buffer size is larger than the serialized data, we can guarantee + the deserialization is correct. + """ @classmethod def get_max_buffer_size_for_metadata(cls): @@ -231,7 +248,7 @@ def __init_subclass__(subclass): dtype=torch.uint8) -T = TypeVar("T", bound=FastBroadcastTensorDict) +T = TypeVar("T", bound=TensorDictWithBoundedMetadata) def broadcast_tensor_dict( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5c4160675d2..460a157dfc2 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -12,7 +12,7 @@ from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.communication_op import (FastBroadcastTensorDict, +from vllm.distributed.communication_op import (TensorDictWithBoundedMetadata, TensorMetadata) from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) @@ -25,7 +25,7 @@ from vllm.worker.worker_base import WorkerBase -class BlockMetaData(FastBroadcastTensorDict): +class BlockMetaData(TensorDictWithBoundedMetadata): """ Use BlockMetaData to save one broadcasted in broadcast Python object. """ From a8d1d3a699b8acfa61f2e87284ce4a6357cb5bfb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 16:13:19 -0700 Subject: [PATCH 13/29] use vllm.TensorMeta --- tests/distributed/test_comm_ops.py | 8 ++++---- vllm/__init__.py | 27 +++++++++++++++++++++++++++ vllm/distributed/communication_op.py | 10 ++++------ vllm/worker/worker.py | 10 +++++----- 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index a8c2fe0f370..533b70176f6 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -8,11 +8,11 @@ import ray import torch +from vllm import TensorMeta from vllm.distributed import (broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.communication_op import (TensorDictWithBoundedMetadata, - TensorMetadata) +from vllm.distributed.communication_op import TensorDictWithBoundedMetadata from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -117,8 +117,8 @@ def __init__(self, a, b): @classmethod def get_example_metadata_list(cls): return [ - ("a", TensorMetadata("cuda", torch.float32, torch.Size([]))), - ("b", TensorMetadata("cpu", torch.float32, torch.Size([]))), + ("a", TensorMeta("cuda", torch.float32, torch.Size([]))), + ("b", TensorMeta("cpu", torch.float32, torch.Size([]))), ] diff --git a/vllm/__init__.py b/vllm/__init__.py index 74674ca0d12..8a7d69668ce 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,5 +1,9 @@ """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +import dataclasses + +import torch + from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine @@ -11,6 +15,28 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams +torch_dtypes = [ + getattr(torch, attr) for attr in dir(torch) + if isinstance(getattr(torch, attr), torch.dtype) +] +dtype_map = {dtype: i for i, dtype in enumerate(torch_dtypes)} + + +@dataclasses.dataclass +class TensorMeta: + device: str + dtype: torch.dtype + size: torch.Size + + def __getstate__(self): + return [self.device, dtype_map[self.dtype], tuple(self.size)] + + def __setstate__(self, state): + self.device = state[0] + self.dtype = torch_dtypes[state[1]] + self.size = torch.Size(state[2]) + + __version__ = "0.4.2" __all__ = [ @@ -27,4 +53,5 @@ "AsyncEngineArgs", "initialize_ray_cluster", "PoolingParams", + "TensorMeta", ] diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index bf10a08acc0..c6c66e9014c 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,4 @@ import pickle -from collections import namedtuple from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union @@ -7,6 +6,8 @@ import torch.distributed as dist from torch.distributed import ProcessGroup +from vllm import TensorMeta + from .parallel_state import (get_cpu_world_group, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, @@ -159,9 +160,6 @@ def broadcast_object_list(obj_list: List[Any], return obj_list -TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) - - def _split_tensor_dict( tensor_dict: Dict[Any, Union[torch.Tensor, Any]] ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: @@ -180,7 +178,7 @@ def _split_tensor_dict( # receiving side will set the device index. device = "cpu" if value.is_cpu else "cuda" metadata_list.append( - (key, TensorMetadata(device, value.dtype, value.size()))) + (key, TensorMeta(device, value.dtype, value.size()))) tensor_list.append(value) else: metadata_list.append((key, value)) @@ -332,7 +330,7 @@ def broadcast_tensor_dict( tensor_dict = {} async_handles = [] for key, value in recv_metadata_list: - if isinstance(value, TensorMetadata): + if isinstance(value, TensorMeta): tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 460a157dfc2..776eb2b0538 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -6,14 +6,14 @@ import torch import torch.distributed +from vllm import TensorMeta from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.communication_op import (TensorDictWithBoundedMetadata, - TensorMetadata) +from vllm.distributed.communication_op import TensorDictWithBoundedMetadata from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest @@ -48,11 +48,11 @@ def get_example_metadata_list(cls): return [ ("num_seq_groups", 1), ("blocks_to_swap_in", - TensorMetadata("cpu", torch.int64, torch.Size([1, 2]))), + TensorMeta("cpu", torch.int64, torch.Size([1, 2]))), ("blocks_to_swap_out", - TensorMetadata("cpu", torch.int64, torch.Size([1, 2]))), + TensorMeta("cpu", torch.int64, torch.Size([1, 2]))), ("blocks_to_copy", - TensorMetadata("cuda", torch.int64, torch.Size([1, 2]))), + TensorMeta("cuda", torch.int64, torch.Size([1, 2]))), ] From ed2400903047def1c06ad291237062915ebe37d6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 16:16:30 -0700 Subject: [PATCH 14/29] avoid circular import --- vllm/distributed/communication_op.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index c6c66e9014c..d23fa522bbe 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -6,8 +6,6 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from vllm import TensorMeta - from .parallel_state import (get_cpu_world_group, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, @@ -168,6 +166,7 @@ def _split_tensor_dict( by its metadata. 2. A list of tensors. """ + from vllm import TensorMeta # import here to avoid circular import metadata_list = [] tensor_list = [] for key, value in tensor_dict.items(): @@ -275,6 +274,8 @@ def broadcast_tensor_dict( assert tensor_dict is not None return tensor_dict + from vllm import TensorMeta # import here to avoid circular import + rank = torch.distributed.get_rank() if rank == src: metadata_list: List[Tuple[Any, Any]] = [] From 1f9a910842ca6bc0020bee7994703a2aac6cee80 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 16:17:52 -0700 Subject: [PATCH 15/29] add comments --- vllm/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/__init__.py b/vllm/__init__.py index 8a7d69668ce..b8111ed74b3 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -24,6 +24,10 @@ @dataclasses.dataclass class TensorMeta: + """ + This class is placed here to reduce the size of qualified name, + which will be used in pickle serialization. + """ device: str dtype: torch.dtype size: torch.Size From 0467afbee1ef5f6d760e8b10d28e839e8a2a2f16 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 16:35:58 -0700 Subject: [PATCH 16/29] use class attributes --- vllm/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/__init__.py b/vllm/__init__.py index b8111ed74b3..fed96903265 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -15,12 +15,6 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -torch_dtypes = [ - getattr(torch, attr) for attr in dir(torch) - if isinstance(getattr(torch, attr), torch.dtype) -] -dtype_map = {dtype: i for i, dtype in enumerate(torch_dtypes)} - @dataclasses.dataclass class TensorMeta: @@ -32,12 +26,18 @@ class TensorMeta: dtype: torch.dtype size: torch.Size + torch_dtypes = [ + getattr(torch, attr) for attr in dir(torch) + if isinstance(getattr(torch, attr), torch.dtype) + ] + dtype_map = {dtype: i for i, dtype in enumerate(torch_dtypes)} + def __getstate__(self): - return [self.device, dtype_map[self.dtype], tuple(self.size)] + return [self.device, self.dtype_map[self.dtype], tuple(self.size)] def __setstate__(self, state): self.device = state[0] - self.dtype = torch_dtypes[state[1]] + self.dtype = self.torch_dtypes[state[1]] self.size = torch.Size(state[2]) From ee60d7815d17a99502f251d0e2ede1ab19c1583d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 16:38:06 -0700 Subject: [PATCH 17/29] no need to broadcast keys --- vllm/distributed/communication_op.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index d23fa522bbe..4f80e47995f 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -159,25 +159,31 @@ def broadcast_object_list(obj_list: List[Any], def _split_tensor_dict( - tensor_dict: Dict[Any, Union[torch.Tensor, Any]] -) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + keys: Optional[List[str]] = None, +) -> Tuple[List[Any], List[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced - by its metadata. + by its metadata. If keys are provided, only return the value. 2. A list of tensors. + + `keys` is used to specify the keys to be included in the metadata list, + which can make sure the order of the metadata list is consistent across + different ranks. """ from vllm import TensorMeta # import here to avoid circular import metadata_list = [] tensor_list = [] - for key, value in tensor_dict.items(): + used_keys = keys or tensor_dict.keys() + for key in used_keys: + value = tensor_dict[key] if isinstance(value, torch.Tensor): # Note: we cannot use `value.device` here, # because it contains not only the device type but also the device # index (e.g. "cuda:0"). We only need the device type. # receiving side will set the device index. device = "cpu" if value.is_cpu else "cuda" - metadata_list.append( - (key, TensorMeta(device, value.dtype, value.size()))) + metadata_list.append(TensorMeta(device, value.dtype, value.size())) tensor_list.append(value) else: metadata_list.append((key, value)) @@ -282,13 +288,15 @@ def broadcast_tensor_dict( assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) if cls is not None: + metadata_list, tensor_list = _split_tensor_dict(tensor_dict, + keys=cls.fields) s = pickle.dumps(metadata_list) cls.buffer_tensor[:len(s)].copy_( torch.frombuffer(s, dtype=torch.uint8)) dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group) else: + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `broadcast_object_list` involves serialization and # deserialization, all happening on CPU. Therefore, @@ -327,7 +335,8 @@ def broadcast_tensor_dict( assert recv_metadata_list is not None else: dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group) - recv_metadata_list = pickle.loads(memoryview(cls.buffer)) + recv_value_list = pickle.loads(memoryview(cls.buffer)) + recv_metadata_list = list(zip(cls.fields, recv_value_list)) tensor_dict = {} async_handles = [] for key, value in recv_metadata_list: From 6969587a5197f4ceff7edb42b6b44ae930b77e2d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 16:45:06 -0700 Subject: [PATCH 18/29] fix key --- vllm/distributed/communication_op.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 4f80e47995f..d13a2cc69f7 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -183,10 +183,13 @@ def _split_tensor_dict( # index (e.g. "cuda:0"). We only need the device type. # receiving side will set the device index. device = "cpu" if value.is_cpu else "cuda" - metadata_list.append(TensorMeta(device, value.dtype, value.size())) + metadata_list.append( + (key, TensorMeta(device, value.dtype, value.size()))) tensor_list.append(value) else: metadata_list.append((key, value)) + if keys is not None: + metadata_list = [value for key, value in metadata_list] return metadata_list, tensor_list From 81586671c707963ba68f2c08ac1818565cdd490b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 16:52:04 -0700 Subject: [PATCH 19/29] fix buffer size calculation --- vllm/distributed/communication_op.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index d13a2cc69f7..a305de50cd3 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -215,7 +215,9 @@ class TensorDictWithBoundedMetadata: @classmethod def get_max_buffer_size_for_metadata(cls): metadata_list = cls.get_example_metadata_list() - metadata_list_bytes = pickle.dumps(metadata_list) + # Note: we only need the values of the metadata list. + values = [value for key, value in metadata_list] + metadata_list_bytes = pickle.dumps(values) ALIGN_BYTES = 256 return ((len(metadata_list_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES) * ALIGN_BYTES From 52a59ecb28f79fec99996e727450b4814f7475e9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 17:00:52 -0700 Subject: [PATCH 20/29] use smaller align bytes --- vllm/distributed/communication_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index a305de50cd3..34b7b5ca146 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -218,7 +218,7 @@ def get_max_buffer_size_for_metadata(cls): # Note: we only need the values of the metadata list. values = [value for key, value in metadata_list] metadata_list_bytes = pickle.dumps(values) - ALIGN_BYTES = 256 + ALIGN_BYTES = 128 return ((len(metadata_list_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES) * ALIGN_BYTES From 89bd1ec110d442ce59a8913767792cc926122fc5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 May 2024 17:36:13 -0700 Subject: [PATCH 21/29] assert torch dtype initialized --- vllm/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/__init__.py b/vllm/__init__.py index fed96903265..a12632f2a86 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -26,6 +26,9 @@ class TensorMeta: dtype: torch.dtype size: torch.Size + # this is a hack to make sure that torch.dtype is initialized + assert isinstance(torch.dtype, type) + torch_dtypes = [ getattr(torch, attr) for attr in dir(torch) if isinstance(getattr(torch, attr), torch.dtype) From d0a43ef5d2be091374d34f4d785971fb7b504a82 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 May 2024 12:39:21 -0700 Subject: [PATCH 22/29] fix torch dtype --- vllm/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/__init__.py b/vllm/__init__.py index a12632f2a86..79a70ffc592 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -26,12 +26,10 @@ class TensorMeta: dtype: torch.dtype size: torch.Size - # this is a hack to make sure that torch.dtype is initialized - assert isinstance(torch.dtype, type) - torch_dtypes = [ - getattr(torch, attr) for attr in dir(torch) - if isinstance(getattr(torch, attr), torch.dtype) + torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, + torch.uint16, torch.uint32, torch.uint64, torch.float16, torch.float32, + torch.float64, torch.bfloat16 ] dtype_map = {dtype: i for i, dtype in enumerate(torch_dtypes)} From 62ac9621ff8dc1eaf5a28268a0049d4298b19f9a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 May 2024 13:10:27 -0700 Subject: [PATCH 23/29] use str --- vllm/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/__init__.py b/vllm/__init__.py index 79a70ffc592..8d1a9abb087 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -26,19 +26,21 @@ class TensorMeta: dtype: torch.dtype size: torch.Size + # use string to avoid torch lazy import issues + # sometimes `torch.int8` is not available at bootstrapping time torch_dtypes = [ - torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, - torch.uint16, torch.uint32, torch.uint64, torch.float16, torch.float32, - torch.float64, torch.bfloat16 + "torch.int8", "torch.int16", "torch.int32", "torch.int64", + "torch.uint8", "torch.uint16", "torch.uint32", "torch.uint64", + "torch.float16", "torch.float32", "torch.float64", "torch.bfloat16" ] dtype_map = {dtype: i for i, dtype in enumerate(torch_dtypes)} def __getstate__(self): - return [self.device, self.dtype_map[self.dtype], tuple(self.size)] + return [self.device, self.dtype_map[str(self.dtype)], tuple(self.size)] def __setstate__(self, state): self.device = state[0] - self.dtype = self.torch_dtypes[state[1]] + self.dtype = eval(self.torch_dtypes[state[1]]) self.size = torch.Size(state[2]) From 59f094acd4e6c39a2d299e6fb8eeffda3aea544b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 May 2024 18:11:46 -0700 Subject: [PATCH 24/29] fix merge --- vllm/worker/worker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 1ee33f7ef89..ee54172e8b5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -15,8 +15,6 @@ init_distributed_environment, set_custom_all_reduce) from vllm.distributed.communication_op import TensorDictWithBoundedMetadata -from vllm.distributed.device_communicators.custom_all_reduce import ( - init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput From 5c9f0e90b106718358da1c4beccd328bca28303a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 May 2024 18:42:00 -0700 Subject: [PATCH 25/29] update tests --- tests/distributed/test_comm_ops.py | 39 ++++++++++++++++-------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index bd90233ee3b..df0580fc5fd 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -106,33 +106,36 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, assert torch.allclose(recv_dict["f"], test_dict["f"]) -class CustomData(TensorDictWithBoundedMetadata): - - def __init__(self, a, b): - self.a = a - self.b = b - - fields = ["a", "b"] - - @classmethod - def get_example_metadata_list(cls): - return [ - ("a", TensorMeta("cuda", torch.float32, torch.Size([]))), - ("b", TensorMeta("cpu", torch.float32, torch.Size([]))), - ] - - @ray.remote(num_gpus=1, max_calls=1) -def fast_broadcast_tensor_dict_test_worker(tensor_parallel_size: int, +def fast_broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU del os.environ["CUDA_VISIBLE_DEVICES"] + + # Note: it is important to define the custom data class in the worker + # the class definition might initialize torch/cuda, and might read + # environment variables CUDA_VISIBLE_DEVICES + class CustomData(TensorDictWithBoundedMetadata): + + def __init__(self, a, b): + self.a = a + self.b = b + + fields = ["a", "b"] + + @classmethod + def get_example_metadata_list(cls): + return [ + ("a", TensorMeta("cuda", torch.float32, torch.Size([]))), + ("b", TensorMeta("cpu", torch.float32, torch.Size([]))), + ] + device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { From 7f8ce07649bb91d5353d9f54eec4d4b15f58becb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 13 May 2024 11:09:42 -0700 Subject: [PATCH 26/29] type annotation for get_example_metadata_list --- tests/distributed/test_comm_ops.py | 2 +- vllm/distributed/communication_op.py | 16 +++++++++------- vllm/worker/worker.py | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index df0580fc5fd..e8b39ef157f 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -127,7 +127,7 @@ def __init__(self, a, b): fields = ["a", "b"] @classmethod - def get_example_metadata_list(cls): + def get_example_metadata_list(cls) -> List[Tuple[str, Any]]: return [ ("a", TensorMeta("cuda", torch.float32, torch.Size([]))), ("b", TensorMeta("cpu", torch.float32, torch.Size([]))), diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 67dc6a1a193..e33da177977 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -256,13 +256,15 @@ def __init__(self): pass @classmethod - def get_example_metadata_list(cls): - # Note: in general, if the example data contains cuda tensor, - # use cpu tensor here to avoid creating cuda context during - # the initialization of the class. The estimation of the buffer size - # might be inaccurate (by one byte per field), but it is fine because - # the buffer size will be aligned to 256 bytes. - return {} + def get_example_metadata_list(cls) -> List[Tuple[str, Any]]: + """ + Return an example metadata list for the fields. The format is + a list of (key, value) pairs. We use list rather than dict to + make sure the order is consistent across different ranks. + If the value is a normal Python object, leave it as is. If the value + is a tensor, replace it with its metadata. + """ + return [] # ===== subclass overrides ends ===== # for type annotation diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ee54172e8b5..9e571a38a90 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import List, Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, Union, Any import torch import torch.distributed @@ -43,7 +43,7 @@ def __init__(self, num_seq_groups: int, blocks_to_swap_in: torch.Tensor, ] @classmethod - def get_example_metadata_list(cls): + def get_example_metadata_list(cls) -> List[Tuple[str, Any]]: return [ ("num_seq_groups", 1), ("blocks_to_swap_in", From 4ae2b3a59902a8e4a8e1e09972fbc5e91a79ccfc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 13 May 2024 11:10:08 -0700 Subject: [PATCH 27/29] type annotation for get_example_metadata_list --- tests/distributed/test_comm_ops.py | 1 + vllm/worker/worker.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index e8b39ef157f..e0981445223 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -3,6 +3,7 @@ Run `pytest tests/distributed/test_comm_ops.py`. """ import os +from typing import Any, List, Tuple import pytest import ray diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9e571a38a90..5ae75943311 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import List, Optional, Set, Tuple, Union, Any +from typing import Any, List, Optional, Set, Tuple, Union import torch import torch.distributed From 8de5ba2f49462c451f88d53f44539167c1cfc6eb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 13 May 2024 11:14:42 -0700 Subject: [PATCH 28/29] add assert --- vllm/distributed/communication_op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index e33da177977..873da701a9b 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -322,6 +322,9 @@ def broadcast_tensor_dict( metadata_list, tensor_list = _split_tensor_dict(tensor_dict, keys=cls.fields) s = pickle.dumps(metadata_list) + assert len(s) <= cls.size_upper_bound, ( + f"Object size after serialization {len(s)} exceeds the upper" + f" bound {cls.size_upper_bound}") cls.buffer_tensor[:len(s)].copy_( torch.frombuffer(s, dtype=torch.uint8)) dist.broadcast(cls.buffer_tensor, src=src, group=metadata_group) From 70c36646e24b1da557677ae8064c5e07889b3490 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 13 May 2024 12:00:42 -0700 Subject: [PATCH 29/29] fix name error --- tests/distributed/test_comm_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index e0981445223..df0580fc5fd 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -3,7 +3,6 @@ Run `pytest tests/distributed/test_comm_ops.py`. """ import os -from typing import Any, List, Tuple import pytest import ray @@ -128,7 +127,7 @@ def __init__(self, a, b): fields = ["a", "b"] @classmethod - def get_example_metadata_list(cls) -> List[Tuple[str, Any]]: + def get_example_metadata_list(cls): return [ ("a", TensorMeta("cuda", torch.float32, torch.Size([]))), ("b", TensorMeta("cpu", torch.float32, torch.Size([]))),