tests/distributed/test_custom_all_reduce.py (5 changes: 3 additions & 2 deletions)

@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):

     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture():
+            with graph_capture() as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
                                      device=torch.cuda.current_device())
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph):
+                with torch.cuda.graph(graph,
+                                      stream=graph_capture_context.stream):
                     for i in range(num_communication):
                         out1 = tensor_model_parallel_all_reduce(inp1)
                         # the input buffer is immediately modified to test
tests/distributed/test_pynccl.py (4 changes: 2 additions & 2 deletions)

@@ -5,7 +5,7 @@

 import torch

 from vllm.distributed.communication_op import (  # noqa
-    graph_mode, tensor_model_parallel_all_reduce)
+    graph_capture, tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_mode():
+    with graph_capture():
Review comment (Collaborator): Here, how do we make sure it's not using custom all-reduce?

Reply (Member, PR author): Actually, even before this PR, we could not guarantee that custom all-reduce is not used. It only holds in CI because our CI machines do not have custom all-reduce available.

To solve this properly, another refactor is needed: we need to expose a new function that creates TP groups with different communicators. That's my next PR!
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
vllm/distributed/communication_op.py (70 changes: 40 additions & 30 deletions)

@@ -1,5 +1,6 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -13,45 +14,54 @@
     get_tp_pynccl_communicator)


-@contextmanager
-def graph_mode():
-    # In graph mode, we have to be very careful about the collective
-    # operations. The current status is:
-    # allreduce \ Mode   |  Eager  |  Graph  |
-    # --------------------------------------------
-    # custom allreduce   | enabled | enabled |
-    # PyNccl             | disabled| enabled |
-    # torch.distributed  | enabled | disabled|
-    #
-    # Note that custom allreduce will have a runtime check, if the tensor size
-    # is too large, it will fallback to the next available option.
-    # In summary: When using CUDA graph, we use
-    # either custom all-reduce kernel or pynccl. When not using CUDA
-    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
-    # We always prioritize using custom all-reduce kernel but fall back
-    # to PyTorch or pynccl if it is disabled or not supported.
-    pynccl_comm = get_tp_pynccl_communicator()
-    if pynccl_comm is None:
-        context = nullcontext()
-    else:
-        context = pynccl_comm.change_state(enable=True,
-                                           stream=torch.cuda.current_stream())
-    with context:
-        yield
+@dataclass
+class GraphCaptureContext:
+    stream: torch.cuda.Stream
Review comment on lines +17 to +19 (Collaborator): How does this work for non-CUDA backends?

Reply (Member, PR author): For XPU, this will be torch.xpu.Stream.
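To make the discussion above concrete, here is a minimal sketch (not part of this PR) of how the context's stream type could be chosen per backend. The names BackendGraphCaptureContext and make_graph_capture_context are hypothetical; torch.xpu is assumed to exist only on XPU-enabled PyTorch builds, hence the hasattr guard.

from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class BackendGraphCaptureContext:
    # torch.cuda.Stream on CUDA builds, torch.xpu.Stream on XPU builds.
    stream: Any


def make_graph_capture_context() -> BackendGraphCaptureContext:
    if torch.cuda.is_available():
        return BackendGraphCaptureContext(torch.cuda.Stream())
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return BackendGraphCaptureContext(torch.xpu.Stream())
    raise RuntimeError("no supported accelerator backend is available")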



 @contextmanager
 def graph_capture():
     """
-    `graph_capture` is a context manager which should include the code that
+    `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that
     some operations will be run after the graph is captured, before the graph
-    is replayed.
+    is replayed. It returns a `GraphCaptureContext` object which contains the
+    necessary data for the graph capture. Currently, it only contains the
+    stream that the graph capture is running on. This stream is set to the
+    current CUDA stream when the context manager is entered and reset to the
+    default stream when the context manager is exited. This is to ensure that
+    the graph capture is running on a separate stream from the default stream,
+    in order to explicitly distinguish the kernels to capture from other
+    kernels possibly launched in the background on the default stream.
     """
+    stream = torch.cuda.Stream()
+    graph_capture_context = GraphCaptureContext(stream)
     ca_comm = get_tp_ca_communicator()
-    context = nullcontext() if ca_comm is None else ca_comm.capture()
-    with context:
-        yield
+    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with torch.cuda.stream(stream), maybe_ca_context:
+        # In graph mode, we have to be very careful about the collective
+        # operations. The current status is:
+        # allreduce \ Mode   |  Eager  |  Graph  |
+        # --------------------------------------------
+        # custom allreduce   | enabled | enabled |
+        # PyNccl             | disabled| enabled |
+        # torch.distributed  | enabled | disabled|
+        #
+        # Note that custom allreduce will have a runtime check: if the tensor
+        # size is too large, it will fall back to the next available option.
+        # In summary: When using CUDA graph, we use
+        # either custom all-reduce kernel or pynccl. When not using CUDA
+        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        # We always prioritize using custom all-reduce kernel but fall back
+        # to PyTorch or pynccl if it is disabled or not supported.
+        pynccl_comm = get_tp_pynccl_communicator()
+        if pynccl_comm is None:
+            maybe_pynccl_context = nullcontext()
+        else:
+            maybe_pynccl_context = pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream())
+        with maybe_pynccl_context:
+            yield graph_capture_context


 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
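For reference, the intended calling pattern of the new graph_capture() mirrors the test changes above. A minimal usage sketch, assuming it runs inside a worker where vLLM's distributed state and the tensor-parallel group are already initialized (e.g. via ensure_model_parallel_initialized) and a CUDA device is available:

import torch

from vllm.distributed.communication_op import (
    graph_capture, tensor_model_parallel_all_reduce)

inp = torch.ones(1024, dtype=torch.float32, device="cuda")
graph = torch.cuda.CUDAGraph()
with graph_capture() as graph_capture_context:
    # Capture on the stream provided by graph_capture() so that kernels
    # launched on the default stream are not pulled into the graph.
    with torch.cuda.graph(graph, stream=graph_capture_context.stream):
        out = tensor_model_parallel_all_reduce(inp)

# Replaying the graph re-runs the captured all-reduce on the same buffers.
graph.replay()
torch.cuda.synchronize()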
vllm/worker/model_runner.py (38 changes: 18 additions & 20 deletions)

@@ -10,7 +10,7 @@
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.communication_op import graph_capture, graph_mode
+from vllm.distributed.communication_op import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -827,7 +827,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]

-        with graph_capture():
+        with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -863,6 +863,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
                     kv_caches,
                     attn_metadata,
                     memory_pool=self.graph_memory_pool,
+                    stream=graph_capture_context.stream,
                 )
                 self.graph_memory_pool = graph_runner.graph.pool()
                 self.graph_runners[batch_size] = graph_runner
@@ -907,15 +908,27 @@ def capture(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        memory_pool,
+        memory_pool: Optional[Tuple[int, int]],
+        stream: torch.cuda.Stream,
         **kwargs,
     ) -> None:
         assert self._graph is None
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_mode():
-            self.model(
+        self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            **kwargs,
+        )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
+            hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -924,21 +937,6 @@ def capture(
             )
         torch.cuda.synchronize()

-        # Capture the graph.
-        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
-        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self._graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_mode():
-                hidden_states = self.model(
-                    input_ids,
-                    positions,
-                    kv_caches,
-                    attn_metadata,
-                    **kwargs,
-                )
-        torch.cuda.synchronize()
-
         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,
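The capture() changes above follow the standard warm-up, capture, replay pattern. A self-contained sketch of that pattern in plain PyTorch (a toy run function stands in for the model forward pass, and torch.cuda.graph_pool_handle() stands in for self.graph_memory_pool), assuming a CUDA device is available:

import torch


def run(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model forward pass.
    return (x * 2 + 1).relu()


x = torch.randn(8, device="cuda")
stream = torch.cuda.Stream()

# Warm-up pass outside the graph, so one-time kernel launches
# (e.g. autotuning) are not recorded.
run(x)
torch.cuda.synchronize()

# Capture on a dedicated stream and a shared memory pool.
graph = torch.cuda.CUDAGraph()
pool = torch.cuda.graph_pool_handle()
with torch.cuda.graph(graph, pool=pool, stream=stream):
    y = run(x)

# Replay with new values written into the captured input buffer.
x.copy_(torch.randn(8, device="cuda"))
graph.replay()
torch.cuda.synchronize()
print(y)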