diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py
index 630b25a3b613..186f9faa6bfb 100644
--- a/tests/distributed/test_custom_all_reduce.py
+++ b/tests/distributed/test_custom_all_reduce.py
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
 
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture():
+            with graph_capture() as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
                                      device=torch.cuda.current_device())
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph):
+                with torch.cuda.graph(graph,
+                                      stream=graph_capture_context.stream):
                     for i in range(num_communication):
                         out1 = tensor_model_parallel_all_reduce(inp1)
                         # the input buffer is immediately modified to test
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index a0f7500bf0ee..529e75fb2c9e 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -5,7 +5,7 @@
 import torch
 
 from vllm.distributed.communication_op import (  # noqa
-    graph_mode, tensor_model_parallel_all_reduce)
+    graph_capture, tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_mode():
+    with graph_capture():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 9cc776f8324f..a2f2a1681c7b 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -1,5 +1,6 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -13,45 +14,54 @@
                              get_tp_pynccl_communicator)
 
 
-@contextmanager
-def graph_mode():
-    # In graph mode, we have to be very careful about the collective
-    # operations. The current status is:
-    # allreduce \ Mode   |  Eager  |  Graph  |
-    # --------------------------------------------
-    # custom allreduce   | enabled | enabled |
-    # PyNccl             | disabled| enabled |
-    # torch.distributed  | enabled | disabled|
-    #
-    # Note that custom allreduce will have a runtime check, if the tensor size
-    # is too large, it will fallback to the next available option.
-    # In summary: When using CUDA graph, we use
-    # either custom all-reduce kernel or pynccl. When not using CUDA
-    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
-    # We always prioritize using custom all-reduce kernel but fall back
-    # to PyTorch or pynccl if it is disabled or not supported.
-    pynccl_comm = get_tp_pynccl_communicator()
-    if pynccl_comm is None:
-        context = nullcontext()
-    else:
-        context = pynccl_comm.change_state(enable=True,
-                                           stream=torch.cuda.current_stream())
-    with context:
-        yield
+@dataclass
+class GraphCaptureContext:
+    stream: torch.cuda.Stream
 
 
 @contextmanager
 def graph_capture():
     """
-    `graph_capture` is a context manager which should include the code that
+    `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the some
     operations will be run after the graph is captured, before the graph
-    is replayed.
+    is replayed. It returns a `GraphCaptureContext` object which contains the
+    necessary data for the graph capture. Currently, it only contains the
+    stream that the graph capture is running on. This stream is set to the
+    current CUDA stream when the context manager is entered and reset to the
+    default stream when the context manager is exited. This is to ensure that
+    the graph capture runs on a separate stream from the default stream, in
+    order to explicitly distinguish the kernels to capture from other kernels
+    possibly launched in the background on the default stream.
     """
+    stream = torch.cuda.Stream()
+    graph_capture_context = GraphCaptureContext(stream)
     ca_comm = get_tp_ca_communicator()
-    context = nullcontext() if ca_comm is None else ca_comm.capture()
-    with context:
-        yield
+    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with torch.cuda.stream(stream), maybe_ca_context:
+        # In graph mode, we have to be very careful about the collective
+        # operations. The current status is:
+        # allreduce \ Mode   |  Eager  |  Graph  |
+        # --------------------------------------------
+        # custom allreduce   | enabled | enabled |
+        # PyNccl             | disabled| enabled |
+        # torch.distributed  | enabled | disabled|
+        #
+        # Note that custom allreduce will have a runtime check: if the tensor
+        # size is too large, it will fall back to the next available option.
+        # In summary: When using CUDA graph, we use
+        # either custom all-reduce kernel or pynccl. When not using CUDA
+        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        # We always prioritize using custom all-reduce kernel but fall back
+        # to PyTorch or pynccl if it is disabled or not supported.
+        pynccl_comm = get_tp_pynccl_communicator()
+        if pynccl_comm is None:
+            maybe_pynccl_context = nullcontext()
+        else:
+            maybe_pynccl_context = pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream())
+        with maybe_pynccl_context:
+            yield graph_capture_context
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index dcdd4b962454..80a2269514d9 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -10,7 +10,7 @@
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.communication_op import graph_capture, graph_mode
+from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -827,7 +827,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        with graph_capture():
+        with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -863,6 +863,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
                     kv_caches,
                     attn_metadata,
                     memory_pool=self.graph_memory_pool,
+                    stream=graph_capture_context.stream,
                 )
                 self.graph_memory_pool = graph_runner.graph.pool()
                 self.graph_runners[batch_size] = graph_runner
@@ -907,15 +908,27 @@ def capture(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        memory_pool,
+        memory_pool: Optional[Tuple[int, int]],
+        stream: torch.cuda.Stream,
         **kwargs,
     ) -> None:
         assert self._graph is None
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_mode():
-            self.model(
+        self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            **kwargs,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
+            hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -924,21 +937,6 @@ def capture(
             )
             torch.cuda.synchronize()
 
-        # Capture the graph.
-        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
-        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self._graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_mode():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    attn_metadata,
-                    **kwargs,
-                )
-            torch.cuda.synchronize()
-
         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,
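Usage note (not part of the patch): the sketch below illustrates how the new `graph_capture` API is meant to be used after this change, mirroring the updated test and `ModelRunner` code above. It is a minimal, illustrative example; the helper name `build_allreduce_graph` is made up for the illustration, and it assumes vLLM's distributed environment and tensor-parallel groups have already been initialized, as the tests do before capturing.

```python
# Illustrative sketch only -- assumes the distributed environment and
# tensor-parallel groups are already initialized, as in the tests above.
import torch

from vllm.distributed.communication_op import (
    graph_capture, tensor_model_parallel_all_reduce)


def build_allreduce_graph(inp: torch.Tensor):
    """Capture a CUDA graph that all-reduces `inp`; returns (graph, output)."""
    graph = torch.cuda.CUDAGraph()
    # graph_capture() switches to a dedicated side stream (and enables the
    # pynccl backend) and yields a GraphCaptureContext carrying that stream.
    with graph_capture() as graph_capture_context:
        # Capture on the same stream that graph_capture() switched to, so the
        # kernels being captured stay separate from default-stream work.
        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
            out = tensor_model_parallel_all_reduce(inp)
    return graph, out
```

Replaying `graph` later re-runs the captured all-reduce on the current contents of `inp` and writes the result into `out`, which is the same pattern the updated `CUDAGraphRunner.capture` follows with `pool=memory_pool, stream=stream`.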