diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index 9a7a1f07e1b..a4423bbfddf 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -16,7 +16,7 @@
 @ray.remote(num_gpus=1, max_calls=1)
-def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
+def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
                            distributed_init_port: str):
     # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
     # so that each worker can see all the GPUs
@@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     num_elements = 8
     all_tensors = [
         torch.arange(num_elements, dtype=torch.float32, device="cuda") *
-        (r + 1) for r in range(tensor_parallel_size)
+        (r + 1) for r in range(tp_size)
     ]
     expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
     t = all_tensors[rank]
@@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
 @ray.remote(num_gpus=1, max_calls=1)
-def all_gather_test_worker(tensor_parallel_size: int, rank: int,
+def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
                            distributed_init_port: str):
     # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
     # so that each worker can see all the GPUs
@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     num_dimensions = 3
     tensor_size = list(range(2, num_dimensions + 2))
@@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     all_tensors = [
         torch.arange(total_size, dtype=torch.float32,
                      device="cuda").reshape(tensor_size) * (r + 1)
-        for r in range(tensor_parallel_size)
+        for r in range(tp_size)
     ]
     expected = torch.cat(all_tensors, dim=all_gather_dimension)
     t = all_tensors[rank]
@@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
 @ray.remote(num_gpus=1, max_calls=1)
-def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
+def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
                                       distributed_init_port: str):
     # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
     # so that each worker can see all the GPUs
@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     test_dict = {
         # device tensor
@@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
-@pytest.mark.parametrize("tensor_parallel_size", [2])
+@pytest.mark.parametrize("tp_size", [2])
 @pytest.mark.parametrize("test_target", [
     all_reduce_test_worker, all_gather_test_worker,
     broadcast_tensor_dict_test_worker
 ])
-def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
-    multi_process_tensor_parallel(tensor_parallel_size, test_target)
+def test_multi_process_tensor_parallel(tp_size, test_target):
+    multi_process_tensor_parallel(tp_size, 1, test_target)
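For orientation (not part of the patch): the arithmetic that the reworked all_reduce_test_worker checks can be reproduced on CPU without ray or GPUs. A minimal sketch for the tp_size=2 case:

    import torch

    # Each rank r holds arange(8) * (r + 1); after the all-reduce every rank
    # should see the elementwise sum of all per-rank tensors.
    tp_size = 2
    all_tensors = [
        torch.arange(8, dtype=torch.float32) * (r + 1) for r in range(tp_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    assert torch.equal(expected, torch.arange(8, dtype=torch.float32) * 3)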
diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py
index 308b874280f..bdca031e39b 100644
--- a/tests/distributed/test_custom_all_reduce.py
+++ b/tests/distributed/test_custom_all_reduce.py
@@ -6,8 +6,10 @@
 import torch
 import torch.distributed as dist
 
-from vllm.distributed import tensor_model_parallel_all_reduce
-from vllm.distributed.device_communicators import custom_all_reduce
+from vllm.distributed.communication_op import (  # noqa
+    graph_capture, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
+                                             get_tp_ca_communicator)
 from vllm.test_utils import (init_test_distributed_environment,
                              multi_process_tensor_parallel)
@@ -18,17 +20,36 @@
 @ray.remote(num_gpus=1, max_calls=1)
-def graph_allreduce(world_size, rank, distributed_init_port):
+def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
-    custom_all_reduce.init_custom_ar()
+    group = get_tensor_model_parallel_group()
+
+    # A small all_reduce for warmup.
+    # this is needed because device communicators might be created lazily
+    # (e.g. NCCL). This will ensure that the communicator is initialized
+    # before any communication happens, so that this group can be used for
+    # graph capture immediately.
+    data = torch.zeros(1)
+    data = data.to(device=device)
+    torch.distributed.all_reduce(data, group=group)
+    torch.cuda.synchronize()
+    del data
+
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
+
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with custom_all_reduce.capture():
+            with graph_capture():
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
@@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port):
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(graph):
-                    out1 = tensor_model_parallel_all_reduce(inp1)
-                    # the input buffer is immediately modified to test
-                    # synchronization
-                    dist.all_reduce(inp1)
-                    out2 = tensor_model_parallel_all_reduce(inp2)
-                    dist.all_reduce(inp2)
+                    for i in range(num_communication):
+                        out1 = tensor_model_parallel_all_reduce(inp1)
+                        # the input buffer is immediately modified to test
+                        # synchronization
+                        dist.all_reduce(inp1, group=group)
+                        out2 = tensor_model_parallel_all_reduce(inp2)
+                        dist.all_reduce(inp2, group=group)
             graph.replay()
             assert torch.allclose(out1, inp1)
             assert torch.allclose(out2, inp2)
 
 @ray.remote(num_gpus=1, max_calls=1)
-def eager_allreduce(world_size, rank, distributed_init_port):
+def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
     sz = 1024
-    custom_all_reduce.init_custom_ar()
-    fa = custom_all_reduce.get_handle()
+    fa = get_tp_ca_communicator()
     inp = torch.ones(sz, dtype=torch.float32, device=device)
-    out = fa.all_reduce_unreg(inp)
-    assert torch.allclose(out, inp * world_size)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    assert torch.allclose(out, inp * (tp_size**num_communication))
 
     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
-    out = fa.all_reduce_unreg(inp)
-    assert torch.allclose(out, inp * world_size)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    assert torch.allclose(out, inp * (tp_size**num_communication))
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2,
-                    reason="Need at least 2 GPUs to run the test.")
-@pytest.mark.parametrize("tensor_parallel_size", [2])
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
 @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
-def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
-    multi_process_tensor_parallel(tensor_parallel_size, test_target)
-
-
-if __name__ == "__main__":
-    multi_process_tensor_parallel(2, graph_allreduce)
+def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
+    world_size = tp_size * pipeline_parallel_size
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
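A note on the `num_communication = rank // tp_size + 1` trick used in both workers above (not part of the patch): it only depends on how consecutive ranks are grouped into TP groups, so the mapping can be sanity-checked on a plain CPU interpreter:

    # For tp_size=2, pp_size=2, the four ranks form two TP groups. The first
    # group (ranks 0 and 1) all-reduces once, the second (ranks 2 and 3) twice,
    # which is what demonstrates that the groups communicate independently.
    tp_size, pp_size = 2, 2
    for rank in range(tp_size * pp_size):
        tp_group = rank // tp_size
        num_communication = rank // tp_size + 1
        print(f"rank {rank}: TP group {tp_group}, {num_communication} all-reduce(s)")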
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index b3e30a04344..a0f7500bf0e 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -5,7 +5,7 @@
 import torch
 
 from vllm.distributed.communication_op import (  # noqa
-    graph_capture_mode, tensor_model_parallel_all_reduce)
+    graph_mode, tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_capture_mode():
+    with graph_mode():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 32ab5694e53..9cc776f8324 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -9,12 +9,13 @@
                         get_tensor_model_parallel_group,
                         get_tensor_model_parallel_rank,
                         get_tensor_model_parallel_world_size,
+                        get_tp_ca_communicator,
                         get_tp_pynccl_communicator)
 
 @contextmanager
-def graph_capture_mode():
-    # In graph capture, we have to be very careful about the collective
+def graph_mode():
+    # In graph mode, we have to be very careful about the collective
     # operations. The current status is:
     #     allreduce \ Mode   |  Eager  |  Graph  |
     # --------------------------------------------
@@ -24,10 +25,32 @@ def graph_capture_mode():
     #
     # Note that custom allreduce will have a runtime check, if the tensor size
     # is too large, it will fallback to the next available option.
+    # In summary: When using CUDA graph, we use
+    # either custom all-reduce kernel or pynccl. When not using CUDA
+    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+    # We always prioritize using the custom all-reduce kernel but fall back
+    # to PyTorch or pynccl if it is disabled or not supported.
     pynccl_comm = get_tp_pynccl_communicator()
-    assert pynccl_comm is not None
-    with pynccl_comm.change_state(enable=True,
-                                  stream=torch.cuda.current_stream()):
+    if pynccl_comm is None:
+        context = nullcontext()
+    else:
+        context = pynccl_comm.change_state(enable=True,
+                                           stream=torch.cuda.current_stream())
+    with context:
+        yield
+
+
+@contextmanager
+def graph_capture():
+    """
+    `graph_capture` is a context manager which should include the code that
+    is capturing the CUDA graph. Its main purpose is to ensure that some
+    operations will be run after the graph is captured, before the graph
+    is replayed.
+    """
+    ca_comm = get_tp_ca_communicator()
+    context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with context:
+        yield
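For reviewers, the intended pairing of the two context managers (mirroring the vllm/worker/model_runner.py changes later in this diff) is roughly the following sketch; `run_model` is a stand-in for the captured callable, and an initialized tensor-parallel group plus a CUDA device are assumed:

    import torch

    from vllm.distributed.communication_op import graph_capture, graph_mode

    def capture_forward(run_model):
        # graph_capture() wraps the whole capture phase so the custom all-reduce
        # communicator can register its graph buffers once capture ends.
        with graph_capture():
            # Warmup run outside the graph, but with the graph-safe all-reduce
            # backend selected, so no unexpected kernels end up in the capture.
            with graph_mode():
                run_model()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                with graph_mode():
                    run_model()
        return graph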
""" - from vllm.distributed.device_communicators.custom_all_reduce import ( - custom_all_reduce) + ca_comm = get_tp_ca_communicator() # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ - out = custom_all_reduce(input_) - if out is not None: - return out + if ca_comm is not None: + out = ca_comm.custom_all_reduce(input_) + if out is not None: + return out pynccl_comm = get_tp_pynccl_communicator() if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 5d26254fb83..30ee9d1f8a1 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,154 +1,42 @@ from contextlib import contextmanager -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.distributed.parallel_state import ( + get_local_rank, get_tensor_model_parallel_cpu_group) from vllm.logger import init_logger try: import pynvml from vllm._C import custom_ar + + @contextmanager + def _nvml(): + try: + pynvml.nvmlInit() + yield + finally: + pynvml.nvmlShutdown() + except ImportError: # For AMD GPUs custom_ar = None pynvml = None -logger = init_logger(__name__) + @contextmanager + def _nvml(): + try: + yield + finally: + pass -_CA_HANDLE: Optional["CustomAllreduce"] = None -_IS_CAPTURING = False -_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - - -def init_custom_ar() -> None: - from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - - global _CA_HANDLE - if _CA_HANDLE is not None: - return - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - if world_size == 1: - # No need to initialize custom allreduce for single GPU case. - return - - if world_size not in _SUPPORTED_WORLD_SIZES: - logger.warning( - "Custom allreduce is disabled due to an unsupported world size: " - "%d. Supported world sizes: %s. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.", world_size, - str(_SUPPORTED_WORLD_SIZES)) - return - num_dev = torch.cuda.device_count() - # note: num dev can be larger than world_size if we're only using - # first few GPUs - if num_dev < world_size: - logger.warning( - "Cannot test GPU P2P because not all GPUs are visible to the " - "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" - " is set.") - return - - # we only use a subset of GPUs here - # so we only need to check the nvlink connectivity of these GPUs - num_dev = world_size - # test nvlink first, this will filter out most of the cases - # where custom allreduce is not supported - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES - if cuda_visible_devices: - device_ids = list(map(int, cuda_visible_devices.split(","))) - else: - device_ids = list(range(num_dev)) - # this checks hardware and driver support for NVLink - full_nvlink = _is_full_nvlink(device_ids) - if world_size > 2 and not full_nvlink: - logger.warning( - "Custom allreduce is disabled because it's not supported on more" - " than two PCIe-only GPUs. 
To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - # test P2P capability, this checks software/cudaruntime support - # this is expensive to compute at the first time - # then we cache the result - if not _can_p2p(rank, world_size): - logger.warning( - "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability or P2P test failed. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) - - -def begin_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = True - - -def end_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = False - - -def is_capturing() -> bool: - return _IS_CAPTURING and _CA_HANDLE is not None - - -def get_handle() -> Optional["CustomAllreduce"]: - return _CA_HANDLE - - -def is_initialized() -> bool: - return _CA_HANDLE is not None - - -@contextmanager -def capture(): - try: - begin_capture() - yield - finally: - end_capture() - handle = get_handle() - if handle is not None: - handle.register_graph_buffers() - - -def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: - ca_handle = get_handle() - # when custom allreduce is disabled, this will be None - if ca_handle is None: - return None - if is_capturing(): - if torch.cuda.is_current_stream_capturing(): - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_reg(input) - else: - if ca_handle.should_custom_ar(input): - # if warm up, mimic the allocation pattern - # since custom allreduce is out-of-place - return torch.empty_like(input) - else: - # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should - # be small(<=1% of overall latency) compared to the performance - # gains of using custom kernels - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_unreg(input) - - return None - - -@contextmanager -def _nvml(): - try: - pynvml.nvmlInit() - yield - finally: - pynvml.nvmlShutdown() + +logger = init_logger(__name__) @_nvml() @@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool: class CustomAllreduce: + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + # max_size: max supported allreduce size def __init__(self, - rank, - world_size, - full_nvlink, + group: Optional[ProcessGroup] = None, + device: Optional[Union[int, str, torch.device]] = None, max_size=8192 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if custom_ar is None: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + group = group or get_tensor_model_parallel_cpu_group() + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. 
+                "warning, specify disable_custom_all_reduce=True explicitly.",
+                world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
+            return
+
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id],
+                              dtype=torch.int,
+                              device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        full_nvlink = _is_full_nvlink(physical_device_ids)
+        if world_size > 2 and not full_nvlink:
+            logger.warning(
+                "Custom allreduce is disabled because it's not supported on"
+                " more than two PCIe-only GPUs. To silence this warning, "
+                "specify disable_custom_all_reduce=True explicitly.")
+            return
+        # test P2P capability, this checks software/cudaruntime support
+        # this is expensive to compute at the first time
+        # then we cache the result
+        if not _can_p2p(rank, world_size):
+            logger.warning(
+                "Custom allreduce is disabled because your platform lacks "
+                "GPU P2P capability or P2P test failed. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.")
+            return
+
+        self.disabled = False
         # buffers memory are owned by this Python class and passed to C++
         # meta data composes of two parts: meta data for synchronization
         # (256 bytes) and a temporary buffer for storing intermediate
         # allreduce results.
         self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                 dtype=torch.uint8,
-                                device="cuda")
+                                device=self.device)
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
+        self.buffer = torch.empty(max_size,
+                                  dtype=torch.uint8,
+                                  device=self.device)
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -211,8 +189,9 @@ def __init__(self,
         # needs less than 10000 of registered tuples.
         self.rank_data = torch.empty(8 * 1024 * 1024,
                                      dtype=torch.uint8,
-                                     device="cuda")
+                                     device=self.device)
         self.max_size = max_size
+        self.rank = rank
         self.world_size = world_size
         handles, offsets = self._get_ipc_meta(self.meta)
         self.full_nvlink = full_nvlink
@@ -221,6 +200,21 @@ def __init__(self,
                                              self.full_nvlink)
         self.register_buffer(self.buffer)
 
+    @contextmanager
+    def capture(self):
+        """
+        The main responsibility of this context manager is the
+        `register_graph_buffers` call at the end of the context.
+        It records all the buffer addresses used in the CUDA graph.
+ """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + def _get_ipc_meta(self, inp: torch.Tensor): data = inp.untyped_storage()._share_cuda_() shard_data = ( @@ -230,14 +224,29 @@ def _get_ipc_meta(self, inp: torch.Tensor): return self._gather_ipc_meta(shard_data) def _gather_ipc_meta(self, shard_data): - all_data: List[Optional[Any]] = [None] * self.world_size - dist.all_gather_object(all_data, shard_data) + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] + for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. handles = [] offsets = [] for i in range(len(all_data)): - handles.append(all_data[i][0]) # type: ignore - offsets.append(all_data[i][1]) # type: ignore + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore return handles, offsets def register_buffer(self, inp: torch.Tensor): @@ -269,8 +278,31 @@ def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + # when custom allreduce is disabled, this will be None + if self.disabled: + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + if self.should_custom_ar(input): + return self.all_reduce_reg(input) + else: + if self.should_custom_ar(input): + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + if self.should_custom_ar(input): + return self.all_reduce_unreg(input) + + return None + def close(self): - if self._ptr: + if not self.disabled and self._ptr: custom_ar.dispose(self._ptr) self._ptr = 0 diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 168d4cc2df8..092a0910329 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -96,8 +96,10 @@ def __init__( self.stream = torch.cuda.Stream() # A small all_reduce for warmup. - self.all_reduce(torch.zeros(1, device=device)) + data = torch.zeros(1, device=device) + self.all_reduce(data) self.stream.synchronize() + del data # by default it is disabled, e.g. in profiling models and prefill phase. # to use it, use under `with obj.change_state(enable=True)`, usually diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5075da11bb1..d24104e3ed2 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -13,10 +13,13 @@ logger = init_logger(__name__) +_ENABLE_CUSTOM_ALL_REDUCE = True + # Tensor model parallel group that the current rank belongs to. 
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 168d4cc2df8..092a0910329 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -96,8 +96,10 @@ def __init__(
         self.stream = torch.cuda.Stream()
 
         # A small all_reduce for warmup.
-        self.all_reduce(torch.zeros(1, device=device))
+        data = torch.zeros(1, device=device)
+        self.all_reduce(data)
         self.stream.synchronize()
+        del data
 
         # by default it is disabled, e.g. in profiling models and prefill phase.
         # to use it, use under `with obj.change_state(enable=True)`, usually
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 5075da11bb1..d24104e3ed2 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -13,10 +13,13 @@
 
 logger = init_logger(__name__)
 
+_ENABLE_CUSTOM_ALL_REDUCE = True
+
 # Tensor model parallel group that the current rank belongs to.
 _TP_DEVICE_GROUP: Optional[ProcessGroup] = None
 _TP_CPU_GROUP: Optional[ProcessGroup] = None
 _TP_PYNCCL_COMMUNICATOR = None
+_TP_CA_COMMUNICATOR = None
 
 # Pipeline model parallel group that the current rank belongs to.
 _PP_DEVICE_GROUP: Optional[ProcessGroup] = None
@@ -47,11 +50,21 @@
 _LOCAL_RANK = -1
 
+def set_custom_all_reduce(enable: bool):
+    global _ENABLE_CUSTOM_ALL_REDUCE
+    _ENABLE_CUSTOM_ALL_REDUCE = enable
+
+
 def get_tp_pynccl_communicator():
     global _TP_PYNCCL_COMMUNICATOR
     return _TP_PYNCCL_COMMUNICATOR
 
+def get_tp_ca_communicator():
+    global _TP_CA_COMMUNICATOR
+    return _TP_CA_COMMUNICATOR
+
+
 def get_local_rank():
     global _LOCAL_RANK
     return _LOCAL_RANK
@@ -100,6 +113,9 @@ def init_distributed_environment(
     if torch.cuda.is_available():
         data = data.to(device=f"cuda:{local_rank}")
     torch.distributed.all_reduce(data)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    del data
 
 
 def initialize_model_parallel(
@@ -149,7 +165,8 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
     assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
@@ -168,6 +185,15 @@ def initialize_model_parallel(
         device=_LOCAL_RANK,
     )
 
+    # Initialize a custom fast all-reduce implementation.
+    if _ENABLE_CUSTOM_ALL_REDUCE:
+        from vllm.distributed.device_communicators.custom_all_reduce import (
+            CustomAllreduce)
+        _TP_CA_COMMUNICATOR = CustomAllreduce(
+            group=_TP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
+
     # Build the pipeline model-parallel groups.
     global _PP_DEVICE_GROUP
     global _PP_GLOBAL_RANKS
diff --git a/vllm/test_utils.py b/vllm/test_utils.py
index 0cf23e4bb7e..addd8ec1c26 100644
--- a/vllm/test_utils.py
+++ b/vllm/test_utils.py
@@ -6,24 +6,24 @@
 def init_test_distributed_environment(
-    pipeline_parallel_size: int,
-    tensor_parallel_size: int,
+    tp_size: int,
+    pp_size: int,
     rank: int,
     distributed_init_port: str,
     local_rank: int = -1,
 ) -> None:
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     init_distributed_environment(
-        world_size=pipeline_parallel_size * tensor_parallel_size,
+        world_size=pp_size * tp_size,
         rank=rank,
         distributed_init_method=distributed_init_method,
         local_rank=local_rank)
-    ensure_model_parallel_initialized(tensor_parallel_size,
-                                      pipeline_parallel_size)
+    ensure_model_parallel_initialized(tp_size, pp_size)
 
 
 def multi_process_tensor_parallel(
-    tensor_parallel_size: int,
+    tp_size: int,
+    pp_size: int,
     test_target,
 ) -> None:
     # Using ray helps debugging the error when it failed
@@ -32,10 +32,9 @@ def multi_process_tensor_parallel(
     distributed_init_port = get_open_port()
     refs = []
-    for rank in range(tensor_parallel_size):
+    for rank in range(tp_size * pp_size):
         refs.append(
-            test_target.remote(tensor_parallel_size, rank,
-                               distributed_init_port))
+            test_target.remote(tp_size, pp_size, rank, distributed_init_port))
     ray.get(refs)
     ray.shutdown()
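Worth flagging for downstream test authors (not part of the patch): the first two parameters of init_test_distributed_environment changed both name and order, so positional callers have to be updated rather than merely renamed. A sketch of a worker/test pair written against the new signatures; the names are illustrative and the worker body is trimmed:

    import ray

    from vllm.test_utils import (init_test_distributed_environment,
                                 multi_process_tensor_parallel)

    @ray.remote(num_gpus=1, max_calls=1)
    def my_test_worker(tp_size: int, pp_size: int, rank: int,
                       distributed_init_port: str):
        # tp_size first, pp_size second -- the reverse of the old
        # (pipeline_parallel_size, tensor_parallel_size) ordering.
        init_test_distributed_environment(tp_size, pp_size, rank,
                                          distributed_init_port)
        # ... exercise collectives here ...

    def test_my_feature():
        # Spawns tp_size * pp_size ray workers, one GPU each.
        multi_process_tensor_parallel(2, 1, my_test_worker)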
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 21d76fd531e..f46b475bdc2 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -12,8 +12,7 @@
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.communication_op import graph_capture_mode
-from vllm.distributed.device_communicators import custom_all_reduce
+from vllm.distributed.communication_op import graph_capture, graph_mode
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -942,13 +941,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
-        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
-        # either custom all-reduce kernel or pynccl. When not using CUDA
-        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
-        # We always prioritize using custom all-reduce kernel but fall back
-        # to PyTorch or pynccl if it is disabled or not supported.
-        with custom_all_reduce.capture():
+        with graph_capture():
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -1040,7 +1033,7 @@ def capture(
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_capture_mode():
+        with graph_mode():
             self.model(
                 input_ids,
                 positions,
@@ -1055,7 +1048,7 @@ def capture(
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_capture_mode():
+            with graph_mode():
                 hidden_states = self.model(
                     input_ids,
                     positions,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index e4fbc877b8c..82cf58101a9 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -11,9 +11,8 @@
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
-                              init_distributed_environment)
-from vllm.distributed.device_communicators.custom_all_reduce import (
-    init_custom_ar)
+                              init_distributed_environment,
+                              set_custom_all_reduce)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
@@ -302,16 +301,14 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
-    # Initialize a custom fast all-reduce implementation.
-    if not parallel_config.disable_custom_all_reduce:
-        init_custom_ar()
-
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
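Taken together, the initialization order that worker.py now relies on is: decide whether the custom kernel is allowed, then build the distributed environment and parallel groups, after which the communicator (or None) is reachable through parallel_state. A condensed sketch of that flow; the parameter names are stand-ins for the corresponding config fields:

    from vllm.distributed import (ensure_model_parallel_initialized,
                                  init_distributed_environment,
                                  set_custom_all_reduce)
    from vllm.distributed.parallel_state import get_tp_ca_communicator

    def init_worker(world_size: int, rank: int, init_method: str,
                    local_rank: int, tp_size: int, pp_size: int,
                    disable_custom_all_reduce: bool):
        # Must run before the model-parallel groups are built, because the
        # CustomAllreduce instance is only created inside
        # initialize_model_parallel() while the flag is still enabled.
        set_custom_all_reduce(not disable_custom_all_reduce)
        init_distributed_environment(world_size, rank, init_method, local_rank)
        ensure_model_parallel_initialized(tp_size, pp_size)
        # None when the kernel is disabled, unsupported, or world size is 1.
        return get_tp_ca_communicator()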