From e0fc559926fa60b05a94bd0eacfa65c7ba5d2211 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 29 Sep 2025 15:47:36 -0400 Subject: [PATCH 1/2] chore: remove and update imports accordingly Signed-off-by: Aaron Pham --- .../tensorizer_loader/conftest.py | 2 +- tools/pre_commit/check_pickle_imports.py | 1 - vllm/executor/executor_base.py | 10 +- vllm/executor/ray_utils.py | 2 +- vllm/executor/uniproc_executor.py | 10 +- vllm/platforms/cuda.py | 12 +- vllm/platforms/rocm.py | 12 +- vllm/v1/executor/multiproc_executor.py | 4 +- vllm/v1/worker/worker_base.py | 280 ++++++++++++++++-- vllm/worker/__init__.py | 0 vllm/worker/worker_base.py | 279 ----------------- 11 files changed, 275 insertions(+), 337 deletions(-) delete mode 100644 vllm/worker/__init__.py delete mode 100644 vllm/worker/worker_base.py diff --git a/tests/model_executor/model_loader/tensorizer_loader/conftest.py b/tests/model_executor/model_loader/tensorizer_loader/conftest.py index 571dc2e0eb50..cc02d7ecf20b 100644 --- a/tests/model_executor/model_loader/tensorizer_loader/conftest.py +++ b/tests/model_executor/model_loader/tensorizer_loader/conftest.py @@ -10,7 +10,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.executor.abstract import UniProcExecutor -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.worker.worker_base import WorkerWrapperBase MODEL_REF = "facebook/opt-125m" diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index acbbc1f181d6..c97a5b0b6c71 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -36,7 +36,6 @@ 'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py', 'benchmarks/cutlass_benchmarks/sparse_benchmarks.py', # cloudpickle - 'vllm/worker/worker_base.py', 'vllm/executor/mp_distributed_executor.py', 'vllm/executor/ray_distributed_executor.py', 'vllm/entrypoints/llm.py', diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index e3063ec2b8ab..fe80be61410c 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -19,7 +19,7 @@ from vllm.tasks import SupportedTask from vllm.utils import make_async from vllm.v1.outputs import PoolerOutput, SamplerOutput -from vllm.worker.worker_base import WorkerBase +from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -30,7 +30,7 @@ class ExecutorBase(ABC): """Base class for all executors. An executor is responsible for executing the model on one device, - or it can be a distributed executor + or it can be a distributed executor that can execute the model on multiple devices. """ @@ -83,7 +83,7 @@ def collective_rpc(self, Returns: A list containing the results from each worker. - + Note: It is recommended to use this API to only pass control messages, and set up data-plane communication to pass data. @@ -100,7 +100,7 @@ def determine_num_available_blocks(self) -> tuple[int, int]: Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where `num_gpu_blocks` are blocks that are "active" on the device and can be - appended to. + appended to. `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be appended to. """ @@ -327,7 +327,7 @@ def _run_workers( run only in the remote TP workers, not the driver worker. It will also be run asynchronously and return a list of futures rather than blocking on the results. 
- + # TODO: simplify and merge with collective_rpc """ raise NotImplementedError diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 0bdeb2856989..d8eb7977dbde 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -16,7 +16,7 @@ from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.worker.worker_base import WorkerWrapperBase if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 7a753d608a43..d669592e75f1 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -19,7 +19,7 @@ from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import AsyncModelRunnerOutput -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -160,10 +160,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ Determine the number of available KV blocks. Add an additional all_reduce to get the min across all ranks. - Note that even if we have the same `gpu_memory_utilization` and - `swap_space`, the available memory in every rank might still - differ because NCCL can take different amounts of memory in - different ranks. Therefore, it is necessary to test if all ranks + Note that even if we have the same `gpu_memory_utilization` and + `swap_space`, the available memory in every rank might still + differ because NCCL can take different amounts of memory in + different ranks. Therefore, it is necessary to test if all ranks agree on the same KV cache configuration. 
""" a, b = super().determine_num_available_blocks() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 6738d3dec286..1463fe34fc75 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -110,17 +110,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": - if vllm_config.speculative_config: - if not envs.VLLM_USE_V1: - raise NotImplementedError( - "Speculative decoding is not supported on vLLM V0.") - parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" - else: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" - else: - parallel_config.worker_cls = "vllm.worker.worker.Worker" + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1dacd026b667..f67568bf07c1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -327,17 +327,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config.block_size = 16 if parallel_config.worker_cls == "auto": - if vllm_config.speculative_config: - if not use_v1: - raise NotImplementedError( - "Speculative decoding is not supported on vLLM V0.") - parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" - else: - if use_v1: - parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" - else: - parallel_config.worker_cls = "vllm.worker.worker.Worker" + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" # Aiter rms norm perform best when CUDA Graph capture is enabled. if (use_v1 and use_aiter_rms_norm and not is_eager_execution and "-rms_norm" not in compilation_config.custom_ops): diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index ef90af263664..eecdf8def6de 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -41,7 +41,7 @@ from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput) -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -702,7 +702,7 @@ def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: def set_multiprocessing_worker_envs(): """ Set up environment variables that should be used when there are workers - in a multiprocessing environment. This should be called by the parent + in a multiprocessing environment. 
This should be called by the parent process before worker processes are created""" _maybe_force_spawn() diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 038ce4b54f96..f620343ddf4b 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -1,23 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from __future__ import annotations + +import os +from typing import Any, Callable, Optional, TypeVar, Union import torch import torch.nn as nn -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (enable_trace_function_call_for_thread, + resolve_obj_by_qualname, run_method, + update_environment_variables, + warn_for_unimplemented_methods) from vllm.v1.kv_cache_interface import KVCacheSpec -from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 +from vllm.v1.outputs import SamplerOutput logger = init_logger(__name__) +_R = TypeVar("_R") -class WorkerBase(WorkerBaseV0): - """ - Abstract class for v1 worker, mainly define some methods for v1. - For methods shared by v0 and v1, define them in v0 WorkerBase + +@warn_for_unimplemented_methods +class WorkerBase: + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. Also abstracts control plane communication, e.g., to + communicate request metadata to other workers. """ def __init__( @@ -27,20 +39,22 @@ def __init__( rank: int, distributed_init_method: str, is_driver_worker: bool = False, - ): - """ - Initialize common worker components. - - Args: - vllm_config: Complete vLLM configuration - local_rank: Local device index - rank: Global rank in distributed setup - distributed_init_method: Distributed initialization method - is_driver_worker: Whether this worker handles driver - responsibilities - """ - # Configuration storage - super().__init__(vllm_config=vllm_config) + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + self.kv_transfer_config = vllm_config.kv_transfer_config + self.compilation_config = vllm_config.compilation_config + + from vllm.platforms import current_platform + self.current_platform = current_platform self.parallel_config.rank = rank self.local_rank = local_rank @@ -63,3 +77,227 @@ def compile_or_warm_up_model(self) -> None: def check_health(self) -> None: """Basic health check (override for device-specific checks).""" return + + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. + """ + raise NotImplementedError + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. 
+ """ + raise NotImplementedError + + def get_model(self) -> nn.Module: + raise NotImplementedError + + def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: + """Apply a function on the model inside this worker.""" + return fn(self.get_model()) + + def load_model(self) -> None: + """Load model onto target device.""" + raise NotImplementedError + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[list[SamplerOutput]]: + raise NotImplementedError + + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + with self.current_platform.inference_mode(): + while True: + output = self.execute_model(execute_model_req=None) + if output is None: + return None + + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + def get_cache_block_size_bytes(self) -> int: + """Return the size of a single cache block, in bytes. Used in + speculative decoding. + """ + raise NotImplementedError + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def list_loras(self) -> set[int]: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + """Get vocabulary size from model configuration.""" + return self.model_config.get_vocab_size() + + def shutdown(self) -> None: + """Clean up resources held by the worker.""" + return + + +class WorkerWrapperBase: + """ + This class represents one process in an executor/engine. It is responsible + for lazily initializing the worker and handling the worker's lifecycle. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + """ + + def __init__( + self, + vllm_config: VllmConfig, + rpc_rank: int = 0, + ) -> None: + """ + Initialize the worker wrapper with the given vllm_config and rpc_rank. + Note: rpc_rank is the rank of the worker in the executor. In most cases, + it is also the rank of the worker in the distributed group. However, + when multiple executors work together, they can be different. + e.g. in the case of SPMD-style offline inference with TP=2, + users can launch 2 engines/executors, each with only 1 worker. + All workers have rpc_rank=0, but they have different ranks in the TP + group. + """ + self.rpc_rank = rpc_rank + self.worker: Optional[WorkerBase] = None + self.vllm_config: Optional[VllmConfig] = None + # do not store this `vllm_config`, `init_worker` will set the final + # one. TODO: investigate if we can remove this field in + # `WorkerWrapperBase`, `init_cached_hf_modules` should be + # unnecessary now. 
+ if vllm_config.model_config is not None: + # it can be None in tests + trust_remote_code = vllm_config.model_config.trust_remote_code + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + def shutdown(self) -> None: + if self.worker is not None: + self.worker.shutdown() + + def adjust_rank(self, rank_mapping: dict[int, int]) -> None: + """ + Adjust the rpc_rank based on the given mapping. + It is only used during the initialization of the executor, + to adjust the rpc_rank of workers after we create all workers. + """ + if self.rpc_rank in rank_mapping: + self.rpc_rank = rank_mapping[self.rpc_rank] + + def update_environment_variables( + self, + envs_list: list[dict[str, str]], + ) -> None: + envs = envs_list[self.rpc_rank] + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None: + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. + """ + kwargs = all_kwargs[self.rpc_rank] + self.vllm_config = kwargs.get("vllm_config") + assert self.vllm_config is not None, ( + "vllm_config is required to initialize the worker") + enable_trace_function_call_for_thread(self.vllm_config) + + from vllm.plugins import load_general_plugins + load_general_plugins() + + if isinstance(self.vllm_config.parallel_config.worker_cls, str): + worker_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_cls) + else: + raise ValueError( + "passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." 
# noqa: E501 + ) + if self.vllm_config.parallel_config.worker_extension_cls: + worker_extension_cls = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_extension_cls) + extended_calls = [] + if worker_extension_cls not in worker_class.__bases__: + # check any conflicts between worker and worker_extension_cls + for attr in dir(worker_extension_cls): + if attr.startswith("__"): + continue + assert not hasattr(worker_class, attr), ( + f"Worker class {worker_class} already has an attribute" + f" {attr}, which conflicts with the worker" + f" extension class {worker_extension_cls}.") + if callable(getattr(worker_extension_cls, attr)): + extended_calls.append(attr) + # dynamically inherit the worker extension class + worker_class.__bases__ = worker_class.__bases__ + ( + worker_extension_cls, ) + logger.info( + "Injected %s into %s for extended collective_rpc calls %s", + worker_extension_cls, worker_class, extended_calls) + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during worker initialization + self.worker = worker_class(**kwargs) + assert self.worker is not None + + def initialize_from_config(self, kv_cache_configs: list[Any]) -> None: + kv_cache_config = kv_cache_configs[self.rpc_rank] + with set_current_vllm_config(self.vllm_config): + self.worker.initialize_from_config(kv_cache_config) # type: ignore + + def init_device(self): + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during device initialization + self.worker.init_device() # type: ignore + + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + try: + # method resolution order: + # if a method is defined in this class, it will be called directly. + # otherwise, since we define `__getattr__` and redirect attribute + # query to `self.worker`, the method will be called on the worker. + return run_method(self, method, args, kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method!r}. 
" + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + + def __getattr__(self, attr): + return getattr(self.worker, attr) diff --git a/vllm/worker/__init__.py b/vllm/worker/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py deleted file mode 100644 index 20fabef4f19b..000000000000 --- a/vllm/worker/worker_base.py +++ /dev/null @@ -1,279 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, - Union) - -import cloudpickle -import torch.nn as nn - -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (enable_trace_function_call_for_thread, - resolve_obj_by_qualname, run_method, - update_environment_variables, - warn_for_unimplemented_methods) -from vllm.v1.outputs import SamplerOutput - -logger = init_logger(__name__) - -_R = TypeVar("_R") - - -@warn_for_unimplemented_methods -class WorkerBase: - """Worker interface that allows vLLM to cleanly separate implementations for - different hardware. Also abstracts control plane communication, e.g., to - communicate request metadata to other workers. - """ - - def __init__( - self, - vllm_config: VllmConfig, - ) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - self.kv_transfer_config = vllm_config.kv_transfer_config - self.compilation_config = vllm_config.compilation_config - from vllm.platforms import current_platform - self.current_platform = current_platform - - def init_device(self) -> None: - """Initialize device state, such as loading the model or other on-device - memory allocations. - """ - raise NotImplementedError - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache with the given size in blocks. - """ - raise NotImplementedError - - def get_model(self) -> nn.Module: - raise NotImplementedError - - def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: - """Apply a function on the model inside this worker.""" - return fn(self.get_model()) - - def load_model(self) -> None: - """Load model onto target device.""" - raise NotImplementedError - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - raise NotImplementedError - - def start_worker_execution_loop(self) -> None: - """Execute model loop in parallel worker. - - You can stop the loop by executing a driver worker with an empty output. - See `stop_remote_worker_execution_loop` for more details. 
- """ - with self.current_platform.inference_mode(): - while True: - output = self.execute_model(execute_model_req=None) - if output is None: - return None - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - The implementation may run profiling or other heuristics to determine - the size of caches. - - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - raise NotImplementedError - - def get_cache_block_size_bytes(self) -> int: - """Return the size of a single cache block, in bytes. Used in - speculative decoding. - """ - raise NotImplementedError - - def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - def list_loras(self) -> Set[int]: - raise NotImplementedError - - @property - def vocab_size(self) -> int: - """Get vocabulary size from model configuration.""" - return self.model_config.get_vocab_size() - - def shutdown(self) -> None: - """Clean up resources held by the worker.""" - return - - -class WorkerWrapperBase: - """ - This class represents one process in an executor/engine. It is responsible - for lazily initializing the worker and handling the worker's lifecycle. - We first instantiate the WorkerWrapper, which remembers the worker module - and class name. Then, when we call `update_environment_variables`, and the - real initialization happens in `init_worker`. - """ - - def __init__( - self, - vllm_config: VllmConfig, - rpc_rank: int = 0, - ) -> None: - """ - Initialize the worker wrapper with the given vllm_config and rpc_rank. - Note: rpc_rank is the rank of the worker in the executor. In most cases, - it is also the rank of the worker in the distributed group. However, - when multiple executors work together, they can be different. - e.g. in the case of SPMD-style offline inference with TP=2, - users can launch 2 engines/executors, each with only 1 worker. - All workers have rpc_rank=0, but they have different ranks in the TP - group. - """ - self.rpc_rank = rpc_rank - self.worker: Optional[WorkerBase] = None - self.vllm_config: Optional[VllmConfig] = None - # do not store this `vllm_config`, `init_worker` will set the final - # one. TODO: investigate if we can remove this field in - # `WorkerWrapperBase`, `init_cached_hf_modules` should be - # unnecessary now. - if vllm_config.model_config is not None: - # it can be None in tests - trust_remote_code = vllm_config.model_config.trust_remote_code - if trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - def shutdown(self) -> None: - if self.worker is not None: - self.worker.shutdown() - - def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: - """ - Adjust the rpc_rank based on the given mapping. - It is only used during the initialization of the executor, - to adjust the rpc_rank of workers after we create all workers. 
- """ - if self.rpc_rank in rank_mapping: - self.rpc_rank = rank_mapping[self.rpc_rank] - - def update_environment_variables(self, envs_list: List[Dict[str, - str]]) -> None: - envs = envs_list[self.rpc_rank] - key = 'CUDA_VISIBLE_DEVICES' - if key in envs and key in os.environ: - # overwriting CUDA_VISIBLE_DEVICES is desired behavior - # suppress the warning in `update_environment_variables` - del os.environ[key] - update_environment_variables(envs) - - def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: - """ - Here we inject some common logic before initializing the worker. - Arguments are passed to the worker class constructor. - """ - kwargs = all_kwargs[self.rpc_rank] - self.vllm_config = kwargs.get("vllm_config") - assert self.vllm_config is not None, ( - "vllm_config is required to initialize the worker") - enable_trace_function_call_for_thread(self.vllm_config) - - from vllm.plugins import load_general_plugins - load_general_plugins() - - if isinstance(self.vllm_config.parallel_config.worker_cls, str): - worker_class = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_cls) - else: - logger.warning( - "passing worker_cls as a class object is strongly deprecated," - " as the serialization of class objects can be tricky and" - " error-prone. To be safe, please keep the class in a separate" - " module and pass the qualified name of the class as a string." - ) - assert isinstance(self.vllm_config.parallel_config.worker_cls, - bytes) - worker_class = cloudpickle.loads( - self.vllm_config.parallel_config.worker_cls) - if self.vllm_config.parallel_config.worker_extension_cls: - worker_extension_cls = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_extension_cls) - extended_calls = [] - if worker_extension_cls not in worker_class.__bases__: - # check any conflicts between worker and worker_extension_cls - for attr in dir(worker_extension_cls): - if attr.startswith("__"): - continue - assert not hasattr(worker_class, attr), ( - f"Worker class {worker_class} already has an attribute" - f" {attr}, which conflicts with the worker" - f" extension class {worker_extension_cls}.") - if callable(getattr(worker_extension_cls, attr)): - extended_calls.append(attr) - # dynamically inherit the worker extension class - worker_class.__bases__ = worker_class.__bases__ + ( - worker_extension_cls, ) - logger.info( - "Injected %s into %s for extended collective_rpc calls %s", - worker_extension_cls, worker_class, extended_calls) - with set_current_vllm_config(self.vllm_config): - # To make vLLM config available during worker initialization - self.worker = worker_class(**kwargs) - assert self.worker is not None - - def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: - kv_cache_config = kv_cache_configs[self.rpc_rank] - with set_current_vllm_config(self.vllm_config): - self.worker.initialize_from_config(kv_cache_config) # type: ignore - - def init_device(self): - with set_current_vllm_config(self.vllm_config): - # To make vLLM config available during device initialization - self.worker.init_device() # type: ignore - - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - try: - # method resolution order: - # if a method is defined in this class, it will be called directly. - # otherwise, since we define `__getattr__` and redirect attribute - # query to `self.worker`, the method will be called on the worker. 
- return run_method(self, method, args, kwargs) - except Exception as e: - # if the driver worker also execute methods, - # exceptions in the rest worker may cause deadlock in rpc like ray - # see https://github.com/vllm-project/vllm/issues/3455 - # print the error and inform the user to solve the error - msg = (f"Error executing method {method!r}. " - "This might cause deadlock in distributed execution.") - logger.exception(msg) - raise e - - def __getattr__(self, attr): - return getattr(self.worker, attr) From 9d5446e5042a0d58a47dabb22e987fef8279f544 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 29 Sep 2025 16:04:18 -0400 Subject: [PATCH 2/2] chore: restore docs Signed-off-by: Aaron Pham --- vllm/v1/worker/worker_base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index f620343ddf4b..5b393ee6bf3e 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -40,6 +40,17 @@ def __init__( distributed_init_method: str, is_driver_worker: bool = False, ) -> None: + """ + Initialize common worker components. + + Args: + vllm_config: Complete vLLM configuration + local_rank: Local device index + rank: Global rank in distributed setup + distributed_init_method: Distributed initialization method + is_driver_worker: Whether this worker handles driver + responsibilities + """ self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config
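
The net effect of this series is that `vllm/worker/worker_base.py` is removed and both `WorkerBase` and `WorkerWrapperBase` now live in `vllm.v1.worker.worker_base`. A minimal migration sketch for downstream code, assuming a `VllmConfig` has already been built elsewhere (`make_wrapper` is a hypothetical helper, not part of this patch):

# Before this series: from vllm.worker.worker_base import WorkerWrapperBase
from vllm.config import VllmConfig
from vllm.v1.worker.worker_base import WorkerWrapperBase  # new import path

def make_wrapper(vllm_config: VllmConfig, rpc_rank: int = 0) -> WorkerWrapperBase:
    # With this patch, parallel_config.worker_cls must be a qualified
    # class-name string; init_worker() rejects class objects (previously
    # handled via cloudpickle) with a ValueError.
    assert isinstance(vllm_config.parallel_config.worker_cls, str)
    return WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rpc_rank)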