6 changes: 6 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -68,10 +68,16 @@ class EngineCoreOutputs(msgspec.Struct,
outputs: List[EngineCoreOutput]


@dataclass
class EngineCoreProfile:
is_start: bool


class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so they can be sent over sockets
without a separate encoding step.
"""
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'
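
Because each request type is a single byte, a client can use it verbatim as the first frame of a ZMQ multipart message and the receiver can dispatch on it before touching the payload. A minimal sketch of that framing, assuming an already-connected zmq socket named `input_socket` (the socket setup is not part of this hunk):

```python
import pickle

from vllm.v1.engine import EngineCoreProfile, EngineCoreRequestType

# Frame 0: the one-byte request type; frame 1: the serialized payload.
# PROFILE messages are tiny and rare, so plain pickle is good enough here.
request = EngineCoreProfile(is_start=True)
msg = (EngineCoreRequestType.PROFILE.value, pickle.dumps(request))
# input_socket.send_multipart(msg, copy=False)
```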
4 changes: 2 additions & 2 deletions vllm/v1/engine/async_llm.py
@@ -346,10 +346,10 @@ async def check_health(self) -> None:
logger.debug("Called check_health.")

async def start_profile(self) -> None:
- raise ValueError("Not supported on V1 yet.")
+ await self.engine_core.profile(True)

async def stop_profile(self) -> None:
- raise ValueError("Not supported on V1 yet.")
+ await self.engine_core.profile(False)

@property
def is_running(self) -> bool:
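With the ValueError stubs replaced, profiling can be toggled from the v1 AsyncLLM front end. A hedged usage sketch (the engine construction and the exact generate() arguments are assumed, not shown in this diff): bracket the work you want traced with the two calls.

```python
from vllm import SamplingParams

async def profiled_generate(engine, prompt: str, request_id: str):
    # Start the torch profiler in the worker, run one request, then stop it
    # so the trace covers exactly this generation.
    await engine.start_profile()
    final_output = None
    async for output in engine.generate(prompt, SamplingParams(max_tokens=32),
                                        request_id):
        final_output = output
    await engine.stop_profile()
    return final_output
```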
14 changes: 12 additions & 2 deletions vllm/v1/engine/core.py
@@ -1,4 +1,5 @@
import multiprocessing
import pickle
import queue
import threading
import time
@@ -16,7 +17,8 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
- EngineCoreRequest, EngineCoreRequestType)
+ EngineCoreProfile, EngineCoreRequest,
+ EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
@@ -126,6 +128,9 @@ def step(self) -> List[EngineCoreOutput]:
scheduler_output, output)
return engine_core_outputs

def profile(self, is_start=True):
self.model_executor.worker.profile(is_start)


class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
@@ -312,11 +317,14 @@ def _log_stats(self):
self._last_logging_time = now

def _handle_client_request(
- self, request: Union[EngineCoreRequest, List[str]]) -> None:
+ self, request: Union[EngineCoreRequest, EngineCoreProfile,
+ List[str]]) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""

if isinstance(request, EngineCoreRequest):
self.add_request(request)
elif isinstance(request, EngineCoreProfile):
self.model_executor.worker.profile(request.is_start)
else:
# TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list)
@@ -341,6 +349,8 @@ def process_input_socket(self, input_path: str):
request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
elif request_type == EngineCoreRequestType.PROFILE.value:
request = pickle.loads(request_data)
else:
raise ValueError(f"Unknown RequestType: {request_type}")

28 changes: 23 additions & 5 deletions vllm/v1/engine/core_client.py
@@ -9,7 +9,8 @@
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
- EngineCoreRequest, EngineCoreRequestType)
+ EngineCoreProfile, EngineCoreRequest,
+ EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.serial_utils import PickleEncoder

@@ -58,6 +59,9 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def profile(self, is_start=True) -> None:
raise NotImplementedError

def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError

@@ -95,6 +99,9 @@ def add_request(self, request: EngineCoreRequest) -> None:
def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)

async def profile(self, is_start=True) -> None:
self.engine_core.profile(is_start)


class MPClient(EngineCoreClient):
"""
@@ -177,8 +184,10 @@ def get_output(self) -> List[EngineCoreOutput]:
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
return engine_core_outputs

- def _send_input(self, request_type: EngineCoreRequestType,
- request: Union[EngineCoreRequest, List[str]]) -> None:
+ def _send_input(
+ self, request_type: EngineCoreRequestType,
+ request: Union[EngineCoreRequest, EngineCoreProfile,
+ List[str]]) -> None:

# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
@@ -190,6 +199,10 @@ def add_request(self, request: EngineCoreRequest) -> None:
def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.ABORT, request_ids)

async def profile(self, is_start=True) -> None:
self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))


class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@@ -205,8 +218,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
return engine_core_outputs

async def _send_input(
- self, request_type: EngineCoreRequestType,
- request: Union[EngineCoreRequest, List[str]]) -> None:
+ self, request_type: EngineCoreRequestType,
+ request: Union[EngineCoreRequest, EngineCoreProfile,
+ List[str]]) -> None:

msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
@@ -217,3 +231,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
async def abort_requests_async(self, request_ids: List[str]) -> None:
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)

async def profile(self, is_start=True) -> None:
await self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
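
Both client flavors now expose the same `profile()` coroutine: `InprocClient` calls `EngineCore.profile()` directly, while `MPClient`/`AsyncMPClient` wrap the flag in an `EngineCoreProfile` and push it through `_send_input` onto the ZMQ input socket. Callers are transport-agnostic; a minimal sketch (the helper function name is hypothetical):

```python
from vllm.v1.engine.core_client import EngineCoreClient

async def capture_trace(client: EngineCoreClient) -> None:
    # Same call either way: in-process it is a direct method call; with the
    # multiprocess clients it becomes a PROFILE message handled by EngineCoreProc.
    await client.profile(True)
    # ... drive the engine for the steps you want in the trace ...
    await client.profile(False)
```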
25 changes: 25 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -6,6 +6,7 @@
import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
@@ -56,6 +57,22 @@ def __init__(
init_cached_hf_modules()

self.model_runner = GPUModelRunner(vllm_config)
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None

def initialize(self):
if self.device_config.device.type == "cuda":
@@ -184,6 +201,14 @@ def execute_model(
# TODO(woosuk): Send the output to the engine process.
return output

def profile(self, is_start=True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()


def init_worker_distributed_environment(
parallel_config: ParallelConfig,
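The worker only constructs a profiler when `VLLM_TORCH_PROFILER_DIR` is set, and `Worker.profile(is_start)` simply starts or stops it. The underlying mechanism is plain `torch.profiler`; a standalone sketch of the same configuration, with the trace directory and the profiled work as placeholders:

```python
import torch

trace_dir = "/tmp/vllm_torch_profile"  # stand-in for VLLM_TORCH_PROFILER_DIR

profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    with_stack=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir,
                                                            use_gzip=True))

profiler.start()
# ... run the forward passes / decode steps to capture ...
profiler.stop()  # the trace handler writes a gzipped trace into trace_dir
```

The resulting `.pt.trace.json.gz` files can be inspected with TensorBoard's PyTorch profiler plugin.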