diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 326fe6dd048a..837a409670f6 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -20,8 +20,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.linear import UnquantizedLinearMethod -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import ( diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 087c5004bde0..3964eca7d36b 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) from vllm.attention.selector import get_attn_backend -from vllm.config import CacheConfig, QuantizationConfig +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, make_local_attention_virtual_batches, subclass_attention_backend) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index ccb91999d370..c909265c071d 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,29 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# ruff: noqa: F401 -import ast -import copy -import hashlib -import inspect -import json -import os -import textwrap -from contextlib import contextmanager -from dataclasses import field, fields, is_dataclass, replace -from functools import cached_property, lru_cache -from pathlib import Path -from typing import (TYPE_CHECKING, Any, Literal, Optional, Protocol, TypeVar, - Union, cast) - -import regex as re -import torch -from pydantic import ConfigDict, SkipValidation -from pydantic.dataclasses import dataclass -from typing_extensions import runtime_checkable - -import vllm.envs as envs -from vllm import version from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, @@ -48,806 +25,82 @@ from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.structured_outputs import StructuredOutputsConfig -from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field -from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.transformers_utils.runai_utils import is_runai_obj_uri -from vllm.utils import random_uuid - -if TYPE_CHECKING: - from _typeshed import DataclassInstance - from transformers.configuration_utils import PretrainedConfig - - from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -else: - DataclassInstance = Any - PretrainedConfig = Any - QuantizationConfig = Any - QuantizationMethods = Any - BaseModelLoader = Any - LogitsProcessor = Any - -logger = init_logger(__name__) -DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance) - - -@runtime_checkable -class 
SupportsHash(Protocol): - - def compute_hash(self) -> str: - ... - - -class SupportsMetricsInfo(Protocol): - - def metrics_info(self) -> dict[str, str]: - ... - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class VllmConfig: - """Dataclass which contains all vllm-related configuration. This - simplifies passing around the distinct configurations in the codebase. - """ - - # TODO: use default_factory once default constructing ModelConfig doesn't - # try to download a model - model_config: ModelConfig = None # type: ignore - """Model configuration.""" - cache_config: CacheConfig = field(default_factory=CacheConfig) - """Cache configuration.""" - parallel_config: ParallelConfig = field(default_factory=ParallelConfig) - """Parallel configuration.""" - scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) - """Scheduler configuration.""" - device_config: DeviceConfig = field(default_factory=DeviceConfig) - """Device configuration.""" - load_config: LoadConfig = field(default_factory=LoadConfig) - """Load configuration.""" - lora_config: Optional[LoRAConfig] = None - """LoRA configuration.""" - speculative_config: Optional[SpeculativeConfig] = None - """Speculative decoding configuration.""" - structured_outputs_config: StructuredOutputsConfig = field( - default_factory=StructuredOutputsConfig) - """Structured outputs configuration.""" - observability_config: Optional[ObservabilityConfig] = None - """Observability configuration.""" - quant_config: Optional[QuantizationConfig] = None - """Quantization configuration.""" - compilation_config: CompilationConfig = field( - default_factory=CompilationConfig) - """`torch.compile` and cudagraph capture configuration for the model. - - As a shorthand, `-O` can be used to directly specify the compilation - level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). - Currently, -O and -O= are supported as well but this will likely be - removed in favor of clearer -O syntax in the future. - - NOTE: level 0 is the default level without any optimization. level 1 and 2 - are for internal testing only. level 3 is the recommended level for - production, also default in V1. - - You can specify the full compilation config like so: - `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` - """ - kv_transfer_config: Optional[KVTransferConfig] = None - """The configurations for distributed KV cache transfer.""" - kv_events_config: Optional[KVEventsConfig] = None - """The configurations for event publishing.""" - # some opaque config, only used to provide additional information - # for the hash computation, mainly used for testing, debugging or out of - # tree config registration. - additional_config: Union[dict, SupportsHash] = field(default_factory=dict) - """Additional config for specified platform. Different platforms may - support different configs. Make sure the configs are valid for the platform - you are using. Contents must be hashable.""" - instance_id: str = "" - """The ID of the vLLM instance.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. 
- """ - factors: list[Any] = [] - - # summarize vllm config - vllm_factors: list[Any] = [] - from vllm import __version__ - vllm_factors.append(__version__) - vllm_factors.append(envs.VLLM_USE_V1) - if self.model_config: - vllm_factors.append(self.model_config.compute_hash()) - else: - vllm_factors.append("None") - if self.cache_config: - vllm_factors.append(self.cache_config.compute_hash()) - else: - vllm_factors.append("None") - if self.parallel_config: - vllm_factors.append(self.parallel_config.compute_hash()) - else: - vllm_factors.append("None") - if self.scheduler_config: - vllm_factors.append(self.scheduler_config.compute_hash()) - else: - vllm_factors.append("None") - if self.device_config: - vllm_factors.append(self.device_config.compute_hash()) - else: - vllm_factors.append("None") - if self.load_config: - vllm_factors.append(self.load_config.compute_hash()) - else: - vllm_factors.append("None") - if self.lora_config: - vllm_factors.append(self.lora_config.compute_hash()) - # LoRA creates static buffers based on max_num_batched_tokens. - # The tensor sizes and strides get captured in the torch.compile - # graph explicitly. - vllm_factors.append( - str(self.scheduler_config.max_num_batched_tokens)) - else: - vllm_factors.append("None") - if self.speculative_config: - vllm_factors.append(self.speculative_config.compute_hash()) - else: - vllm_factors.append("None") - if self.structured_outputs_config: - vllm_factors.append(self.structured_outputs_config.compute_hash()) - else: - vllm_factors.append("None") - if self.observability_config: - vllm_factors.append(self.observability_config.compute_hash()) - else: - vllm_factors.append("None") - if self.quant_config: - pass # should be captured by model_config.quantization - if self.compilation_config: - vllm_factors.append(self.compilation_config.compute_hash()) - else: - vllm_factors.append("None") - if self.kv_transfer_config: - vllm_factors.append(self.kv_transfer_config.compute_hash()) - else: - vllm_factors.append("None") - if self.additional_config: - if isinstance(additional_config := self.additional_config, dict): - additional_config_hash = hashlib.md5( - json.dumps(additional_config, sort_keys=True).encode(), - usedforsecurity=False, - ).hexdigest() - else: - additional_config_hash = additional_config.compute_hash() - vllm_factors.append(additional_config_hash) - else: - vllm_factors.append("None") - factors.append(vllm_factors) - - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] - return hash_str - - def pad_for_cudagraph(self, batch_size: int) -> int: - # if batch_size > self.compilation_config.max_capture_size, - # it should raise an IndexError. 
- # the caller should make sure the batch_size is within the range, - # i.e., batch_size <= self.compilation_config.max_capture_size - return self.compilation_config.bs_to_padded_graph_size[batch_size] - - @staticmethod - def _get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: - """Get the quantization config.""" - from vllm.platforms import current_platform - if model_config.quantization is not None: - from vllm.model_executor.model_loader.weight_utils import ( - get_quant_config) - quant_config = get_quant_config(model_config, load_config) - capability_tuple = current_platform.get_device_capability() - - if capability_tuple is not None: - capability = capability_tuple.to_int() - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} " - "is not supported for the current GPU. Minimum " - f"capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") - quant_config.maybe_update_config(model_config.model) - return quant_config - return None - - @staticmethod - def get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: - import copy - - # For some reason, the _ version of this modifies the model_config - # object, so using deepcopy to avoid this problem. - return VllmConfig._get_quantization_config(copy.deepcopy(model_config), - load_config) - - def with_hf_config( - self, - hf_config: PretrainedConfig, - architectures: Optional[list[str]] = None, - ) -> "VllmConfig": - if architectures is not None: - hf_config = copy.deepcopy(hf_config) - hf_config.architectures = architectures - - model_config = copy.deepcopy(self.model_config) - model_config.hf_config = hf_config - - return replace(self, model_config=model_config) - - def __post_init__(self): - """Verify configs are valid & consistent with each other. - """ - - self.try_verify_and_update_config() - - if self.model_config is not None: - self.model_config.verify_with_parallel_config(self.parallel_config) - self.model_config.verify_dual_chunk_attention_config( - self.load_config) - - self.cache_config.verify_with_parallel_config(self.parallel_config) - - if self.lora_config is not None: - self.lora_config.verify_with_cache_config(self.cache_config) - self.lora_config.verify_with_model_config(self.model_config) - - if self.quant_config is None and self.model_config is not None: - self.quant_config = VllmConfig._get_quantization_config( - self.model_config, self.load_config) - - from vllm.platforms import current_platform - if self.model_config is not None and \ - self.scheduler_config.chunked_prefill_enabled and \ - self.model_config.dtype == torch.float32 and \ - current_platform.get_device_capability() == (7, 5): - logger.warning_once( - "Turing devices tensor cores do not support float32 matmul. " - "To workaround this limitation, vLLM will set 'ieee' input " - "precision for chunked prefill triton kernels.") - - # If the user does not explicitly set a compilation level, then - # we use the default level. The default level depends on other - # settings (see the below code). 
- if self.compilation_config.level is None: - if envs.VLLM_USE_V1: - if (self.model_config is not None - and not self.model_config.enforce_eager): - self.compilation_config.level = CompilationLevel.PIECEWISE - else: - self.compilation_config.level = \ - CompilationLevel.NO_COMPILATION - - else: - # NB: Passing both --enforce-eager and a compilation level - # in V0 means the compilation level wins out. - self.compilation_config.level = CompilationLevel.NO_COMPILATION - - # async tp is built on top of sequence parallelism - # and requires it to be enabled. - if self.compilation_config.pass_config.enable_async_tp: - self.compilation_config.pass_config.enable_sequence_parallelism = \ - True - if self.compilation_config.pass_config.enable_sequence_parallelism: - self.compilation_config.custom_ops.append("+rms_norm") - - if current_platform.support_static_graph_mode(): - # if cudagraph_mode is not explicitly set by users, set default - # value - if self.compilation_config.cudagraph_mode is None: - if envs.VLLM_USE_V1 and self.compilation_config.level \ - == CompilationLevel.PIECEWISE: - # default to full and piecewise for most models - self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.FULL_AND_PIECEWISE - - # pooling models and encoder-decoder models - # do not support full cudagraphs - if self.model_config is not None and \ - (self.model_config.pooler_config is not None - or self.model_config.is_encoder_decoder): - self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE - else: - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - - # disable cudagraph when enforce eager execution - if self.model_config is not None and \ - self.model_config.enforce_eager: - logger.info("Cudagraph is disabled under eager mode") - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - elif envs.VLLM_USE_V1: - self.compilation_config.cudagraph_num_of_warmups = 1 - - self._set_cudagraph_sizes() - else: - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - - if self.cache_config.kv_sharing_fast_prefill: - - if self.speculative_config is not None and \ - self.speculative_config.use_eagle(): - raise NotImplementedError( - "Fast prefill optimization for KV sharing is not " - "compatible with EAGLE as EAGLE requires correct logits " - "for all tokens while fast prefill gives incorrect logits " - "for prompt tokens.") - - logger.warning_once( - "--kv-sharing-fast-prefill requires changes on model side for " - "correctness and to realize prefill savings. 
") - - disable_chunked_prefill_reasons: list[str] = [] - - if self.model_config: - if self.model_config.pooler_config: - pooling_type = self.model_config.pooler_config.pooling_type - if pooling_type is None or pooling_type.lower() != "last": - disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") - if not getattr(self.model_config.hf_config, "is_causal", True): - disable_chunked_prefill_reasons.append( - "Only models using causal attention supports chunked " - "prefill and prefix caching; disabling both.") - elif self.model_config.is_encoder_decoder: - self.scheduler_config.max_num_encoder_input_tokens = \ - MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) - logger.debug( - "Encoder-decoder model detected: setting " - "`max_num_encoder_input_tokens` to encoder length (%s)", - self.scheduler_config.max_num_encoder_input_tokens) - self.scheduler_config.disable_chunked_mm_input = True - disable_chunked_prefill_reasons.append( - "Encoder-decoder models do not support chunked prefill nor" - " prefix caching; disabling both.") - if (self.model_config.architecture - == "WhisperForConditionalGeneration" - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") - != "spawn"): - logger.warning( - "Whisper is known to have issues with " - "forked workers. If startup is hanging, " - "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " - "to 'spawn'.") - - if disable_chunked_prefill_reasons: - for reason in disable_chunked_prefill_reasons: - logger.info(reason) - self.scheduler_config.chunked_prefill_enabled = False - self.scheduler_config.long_prefill_token_threshold = 0 - - if self.cache_config is not None: - self.cache_config.enable_prefix_caching = False - - if (self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events - and not self.cache_config.enable_prefix_caching): - logger.warning( - "KV cache events are on, but prefix caching is not enabled." - "Use --enable-prefix-caching to enable.") - if (self.kv_events_config is not None - and self.kv_events_config.publisher != "null" - and not self.kv_events_config.enable_kv_cache_events): - logger.warning("KV cache events are disabled," - "but the scheduler is configured to publish them." - "Modify KVEventsConfig.enable_kv_cache_events" - "to True to enable.") - current_platform.check_and_update_config(self) - - # Do this after all the updates to compilation_config.level - if envs.VLLM_USE_V1 and \ - self.compilation_config.level == CompilationLevel.PIECEWISE: - self.compilation_config.set_splitting_ops_for_v1() - - # final check of cudagraph mode after all possible updates - if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): - if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\ - and self.model_config is not None and \ - not self.model_config.disable_cascade_attn and\ - not self.compilation_config.cudagraph_mode.\ - has_piecewise_cudagraphs(): - logger.warning_once( - "No piecewise cudagraph for executing cascade attention." 
- " Will fall back to eager execution if a batch runs " - "into cascade attentions") - - if self.compilation_config.cudagraph_mode\ - .requires_piecewise_compilation(): - assert self.compilation_config.level == \ - CompilationLevel.PIECEWISE, \ - "Compilation level should be CompilationLevel.PIECEWISE "\ - "when cudagraph_mode piecewise cudagraphs is used, "\ - f"cudagraph_mode={self.compilation_config.cudagraph_mode}" - - # final migrate the deprecated flags - self.compilation_config.use_cudagraph = self.compilation_config.\ - cudagraph_mode!= CUDAGraphMode.NONE - self.compilation_config.full_cuda_graph = self.compilation_config.\ - cudagraph_mode.has_full_cudagraphs() - - if self.parallel_config.enable_dbo: - a2a_backend = envs.VLLM_ALL2ALL_BACKEND - assert a2a_backend in \ - ["deepep_low_latency", "deepep_high_throughput"], \ - "Microbatching currently only supports the deepep_low_latency and "\ - f"deepep_high_throughput all2all backend. {a2a_backend} is not "\ - "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\ - "variable to deepep_low_latency or deepep_high_throughput and "\ - "install the DeepEP kernels." - - if not self.model_config.disable_cascade_attn: - self.model_config.disable_cascade_attn = True - logger.warning_once( - "Disabling cascade attention when DBO is enabled.") - - if not self.instance_id: - self.instance_id = random_uuid()[:5] - - if (envs.VLLM_USE_V1 - and not self.scheduler_config.disable_hybrid_kv_cache_manager): - # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not now, so we don't log - # warning message here and will log it later. - if not current_platform.support_hybrid_kv_cache(): - # Hybrid KV cache manager is not supported on non-GPU platforms. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_transfer_config is not None: - # Hybrid KV cache manager is not compatible with KV transfer. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_events_config is not None: - # Hybrid KV cache manager is not compatible with KV events. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.model_config is not None and \ - self.model_config.attention_chunk_size is not None: - if self.speculative_config is not None and \ - self.speculative_config.use_eagle(): - # Hybrid KV cache manager is not yet supported with chunked - # local attention + eagle. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif \ - not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: - logger.warning( - "There is a latency regression when using chunked local" - " attention with the hybrid KV cache manager. Disabling" - " it, by default. To enable it, set the environment " - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." - ) - # Hybrid KV cache manager is not yet supported with chunked - # local attention. 
- self.scheduler_config.disable_hybrid_kv_cache_manager = True - - if self.compilation_config.debug_dump_path: - self.compilation_config.debug_dump_path = \ - self.compilation_config.debug_dump_path.absolute().expanduser() - if envs.VLLM_DEBUG_DUMP_PATH is not None: - env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser() - if self.compilation_config.debug_dump_path: - logger.warning( - "Config-specified debug dump path is overridden" - " by VLLM_DEBUG_DUMP_PATH to %s", env_path) - self.compilation_config.debug_dump_path = env_path - - def update_sizes_for_sequence_parallelism(self, - possible_sizes: list) -> list: - # remove the sizes that not multiple of tp_size when - # enable sequence parallelism - removed_sizes = [ - size for size in possible_sizes - if size % self.parallel_config.tensor_parallel_size != 0 - ] - if removed_sizes: - logger.warning( - "Batch sizes %s are removed because they are not " - "multiple of tp_size %d when " - "sequence parallelism is enabled", removed_sizes, - self.parallel_config.tensor_parallel_size) - - return [ - size for size in possible_sizes - if size % self.parallel_config.tensor_parallel_size == 0 - ] - - def _set_cudagraph_sizes(self): - """ - vLLM defines the default candidate list of batch sizes for CUDA graph - capture as: - - ```python - max_graph_size = min(max_num_seqs * 2, 512) - # 1, 2, 4, then multiples of 8 up to max_graph_size - cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size] - - In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` - will be the final sizes to capture cudagraph (in descending order). - - These sizes are used to capture and reuse CUDA graphs for - performance-critical paths (e.g., decoding). Capturing enables - significantly faster kernel dispatch by avoiding Python overhead. The - list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on - most GPUs), which controls the total allowed number of tokens in a - batch. Since each sequence may have a variable number of tokens, the - maximum usable batch size will depend on actual sequence lengths. - - Example: - With `max_num_batched_tokens = 8192`, and typical sequences - averaging ~32 tokens, most practical batch sizes fall below 256. - However, the system will still allow capture sizes up to 512 if - shape and memory permit. - - Note: - If users explicitly specify cudagraph capture sizes in the - compilation config, those will override this default logic. - At runtime: - - - If batch size <= one of the `cudagraph_capture_sizes`, the closest - padded CUDA graph will be used. - - If batch size > largest `cudagraph_capture_sizes`, cudagraph will - not be used. 
- """ - - # calculate the default `batch_size_capture_list` - batch_size_capture_list = [] - if self.model_config is not None and \ - not self.model_config.enforce_eager: - cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes - if len(cuda_graph_sizes) == 1: - batch_size_capture_list = [1, 2, 4] + [ - i for i in range(8, cuda_graph_sizes[0] + 1, 8) - ] - elif len(cuda_graph_sizes) > 1: - batch_size_capture_list = sorted(cuda_graph_sizes) - else: - raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") - if self.parallel_config.tensor_parallel_size > 1 and \ - self.compilation_config.pass_config.enable_sequence_parallelism: - batch_size_capture_list = \ - self.update_sizes_for_sequence_parallelism(batch_size_capture_list) - max_num_tokens = self.scheduler_config.max_num_batched_tokens - batch_size_capture_list = [ - size for size in batch_size_capture_list - if size <= max_num_tokens - ] - - self.compilation_config.init_with_cudagraph_sizes( - batch_size_capture_list) - - def recalculate_max_model_len(self, max_model_len: int): - # Can only be called in try_verify_and_update_config - model_config = self.model_config - max_model_len = model_config.get_and_verify_max_len(max_model_len) - self.model_config.max_model_len = max_model_len - self.scheduler_config.max_model_len = max_model_len - - def try_verify_and_update_config(self): - if self.model_config is None: - return - - # Avoid running try_verify_and_update_config multiple times - if getattr(self.model_config, "config_updated", False): - return - self.model_config.config_updated = True - - architecture = self.model_config.architecture - if architecture is None: - return - - from vllm.model_executor.models.config import ( - MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) - cls = MODELS_CONFIG_MAP.get(architecture, None) - if cls is not None: - cls.verify_and_update_config(self) - - if self.model_config.is_hybrid: - HybridAttentionMambaModelConfig.verify_and_update_config(self) - - if self.model_config.convert_type == "classify": - # Maybe convert ForCausalLM into ForSequenceClassification model. - from vllm.model_executor.models.adapters import ( - SequenceClassificationConfig) - SequenceClassificationConfig.verify_and_update_config(self) - - if hasattr(self.model_config, "model_weights") and is_runai_obj_uri( - self.model_config.model_weights): - if self.load_config.load_format == "auto": - logger.info("Detected Run:ai model config. " - "Overriding `load_format` to 'runai_streamer'") - self.load_config.load_format = "runai_streamer" - elif self.load_config.load_format != "runai_streamer": - raise ValueError(f"To load a model from S3, 'load_format' " - f"must be 'runai_streamer', " - f"but got '{self.load_config.load_format}'. " - f"Model: {self.model_config.model}") - - def compile_debug_dump_path(self) -> Optional[Path]: - """Returns a rank-aware path for dumping - torch.compile debug information. 
- """ - if self.compilation_config.debug_dump_path is None: - return None - tp_rank = self.parallel_config.rank - dp_rank = self.parallel_config.data_parallel_rank - data_parallel_size = self.parallel_config.data_parallel_size - append_path = f"rank_{tp_rank}" if data_parallel_size == 1 \ - else f"rank_{tp_rank}_dp_{dp_rank}" - path = self.compilation_config.debug_dump_path / append_path - return path - - def __str__(self): - return ( - f"model={self.model_config.model!r}, " - f"speculative_config={self.speculative_config!r}, " - f"tokenizer={self.model_config.tokenizer!r}, " - f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " - f"tokenizer_mode={self.model_config.tokenizer_mode}, " - f"revision={self.model_config.revision}, " - f"tokenizer_revision={self.model_config.tokenizer_revision}, " - f"trust_remote_code={self.model_config.trust_remote_code}, " - f"dtype={self.model_config.dtype}, " - f"max_seq_len={self.model_config.max_model_len}, " - f"download_dir={self.load_config.download_dir!r}, " - f"load_format={self.load_config.load_format}, " - f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa - f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa - f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa - f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa - f"quantization={self.model_config.quantization}, " - f"enforce_eager={self.model_config.enforce_eager}, " - f"kv_cache_dtype={self.cache_config.cache_dtype}, " - f"device_config={self.device_config.device}, " - f"structured_outputs_config={self.structured_outputs_config!r}, " - f"observability_config={self.observability_config!r}, " - f"seed={self.model_config.seed}, " - f"served_model_name={self.model_config.served_model_name}, " - f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " - f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa - f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}") - - -_current_vllm_config: Optional[VllmConfig] = None -_current_prefix: Optional[str] = None - - -@contextmanager -def set_current_vllm_config(vllm_config: VllmConfig, - check_compile=False, - prefix: Optional[str] = None): - """ - Temporarily set the current vLLM config. - Used during model initialization. - We save the current vLLM config in a global variable, - so that all modules can access it, e.g. custom ops - can access the vLLM config to determine how to dispatch. - """ - global _current_vllm_config, _current_prefix - old_vllm_config = _current_vllm_config - old_prefix = _current_prefix - from vllm.compilation.counter import compilation_counter - num_models_seen = compilation_counter.num_models_seen - try: - _current_vllm_config = vllm_config - _current_prefix = prefix - yield - except Exception: - raise - else: - if check_compile: - vllm_config.compilation_config.custom_op_log_check() - - if check_compile and \ - vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ - and compilation_counter.num_models_seen == num_models_seen: - # If the model supports compilation, - # compilation_counter.num_models_seen should be increased - # by at least 1. - # If it is not increased, it means the model does not support - # compilation (does not have @support_torch_compile decorator). - logger.warning( - "`torch.compile` is turned on, but the model %s" - " does not support it. 
Please open an issue on GitHub" - " if you want it to be supported.", - vllm_config.model_config.model) - finally: - _current_vllm_config = old_vllm_config - _current_prefix = old_prefix - # Clear the compilation config cache when context changes - get_cached_compilation_config.cache_clear() - - -@lru_cache(maxsize=1) -def get_cached_compilation_config(): - """Cache config to avoid repeated calls to get_current_vllm_config()""" - return get_current_vllm_config().compilation_config - - -def get_current_vllm_config() -> VllmConfig: - if _current_vllm_config is None: - # in ci, usually when we test custom ops/modules directly, - # we don't set the vllm config. In that case, we set a default - # config. - logger.warning("Current vLLM config is not set.") - from vllm.config import VllmConfig - return VllmConfig() - return _current_vllm_config - - -def get_current_model_prefix() -> str: - """ - Get the prefix of the model that's currently being initialized. - """ - assert _current_prefix is not None, \ - "Current model prefix is not set. " - return _current_prefix - - -T = TypeVar("T") - - -def get_layers_from_vllm_config( - vllm_config: VllmConfig, - layer_type: type[T], - layer_names: Optional[list[str]] = None) -> dict[str, T]: - """ - Get layers from the vLLM config. - - Args: - vllm_config: The vLLM config. - layer_type: The type of the layer to get. - layer_names: The names of the layers to get. If None, return all layers. - """ - - if layer_names is None: - layer_names = list( - vllm_config.compilation_config.static_forward_context.keys()) - - forward_context = vllm_config.compilation_config.static_forward_context - - return { - layer_name: forward_context[layer_name] - for layer_name in layer_names - if isinstance(forward_context[layer_name], layer_type) - } - - -def update_config(config: DataclassInstanceT, - overrides: dict[str, Any]) -> DataclassInstanceT: - processed_overrides = {} - for field_name, value in overrides.items(): - assert hasattr( - config, field_name), f"{type(config)} has no field `{field_name}`" - current_value = getattr(config, field_name) - if is_dataclass(current_value) and not is_dataclass(value): - assert isinstance(value, dict), ( - f"Overrides to {type(config)}.{field_name} must be a dict" - f" or {type(current_value)}, but got {type(value)}") - value = update_config( - current_value, # type: ignore[type-var] - value) - processed_overrides[field_name] = value - return replace(config, **processed_overrides) +from vllm.config.utils import (ConfigType, SupportsMetricsInfo, config, + get_attr_docs, is_init_field, update_config) +from vllm.config.vllm import (VllmConfig, get_cached_compilation_config, + get_current_vllm_config, + get_layers_from_vllm_config, + set_current_vllm_config) + +__all__ = [ + # From vllm.config.cache + "BlockSize", + "CacheConfig", + "CacheDType", + "MambaDType", + "PrefixCachingHashAlgo", + # From vllm.config.compilation + "CompilationConfig", + "CompilationLevel", + "CUDAGraphMode", + "PassConfig", + # From vllm.config.device + "Device", + "DeviceConfig", + # From vllm.config.kv_events + "KVEventsConfig", + # From vllm.config.kv_transfer + "KVTransferConfig", + # From vllm.config.load + "LoadConfig", + # From vllm.config.lora + "LoRAConfig", + # From vllm.config.model + "ConvertOption", + "HfOverrides", + "LogprobsMode", + "ModelConfig", + "ModelDType", + "ModelImpl", + "RunnerOption", + "TaskOption", + "TokenizerMode", + "iter_architecture_defaults", + "try_match_architecture_defaults", + # From vllm.config.multimodal + "MMCacheType", + 
"MMEncoderTPMode", + "MultiModalConfig", + # From vllm.config.observability + "DetailedTraceModules", + "ObservabilityConfig", + # From vllm.config.parallel + "DistributedExecutorBackend", + "EPLBConfig", + "ParallelConfig", + # From vllm.config.pooler + "PoolerConfig", + # From vllm.config.scheduler + "RunnerType", + "SchedulerConfig", + "SchedulerPolicy", + # From vllm.config.speculative + "SpeculativeConfig", + # From vllm.config.speech_to_text + "SpeechToTextConfig", + # From vllm.config.structured_outputs + "StructuredOutputsConfig", + # From vllm.config.utils + "ConfigType", + "SupportsMetricsInfo", + "config", + "get_attr_docs", + "is_init_field", + "update_config", + # From vllm.config.vllm + "VllmConfig", + "get_cached_compilation_config", + "get_current_vllm_config", + "set_current_vllm_config", + "get_layers_from_vllm_config", +] diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 91e61b330273..2da30cbf149c 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -1,21 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +"""Utility functions for vLLM config dataclasses.""" import ast import inspect import textwrap -from dataclasses import MISSING, Field, field, fields, is_dataclass -from typing import TYPE_CHECKING, Any, TypeVar +from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from typing import TYPE_CHECKING, Any, Protocol, TypeVar import regex as re +from typing_extensions import runtime_checkable if TYPE_CHECKING: from _typeshed import DataclassInstance - - ConfigType = type[DataclassInstance] else: - ConfigType = type + DataclassInstance = Any +ConfigType = type[DataclassInstance] ConfigT = TypeVar("ConfigT", bound=ConfigType) @@ -143,3 +143,33 @@ def pairwise(iterable): def is_init_field(cls: ConfigType, name: str) -> bool: return next(f for f in fields(cls) if f.name == name).init + + +@runtime_checkable +class SupportsHash(Protocol): + + def compute_hash(self) -> str: + ... + + +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> dict[str, str]: + ... 
+ + +def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: + processed_overrides = {} + for field_name, value in overrides.items(): + assert hasattr( + config, field_name), f"{type(config)} has no field `{field_name}`" + current_value = getattr(config, field_name) + if is_dataclass(current_value) and not is_dataclass(value): + assert isinstance(value, dict), ( + f"Overrides to {type(config)}.{field_name} must be a dict" + f" or {type(current_value)}, but got {type(value)}") + value = update_config( + current_value, # type: ignore[type-var] + value) + processed_overrides[field_name] = value + return replace(config, **processed_overrides) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py new file mode 100644 index 000000000000..7336f5756527 --- /dev/null +++ b/vllm/config/vllm.py @@ -0,0 +1,789 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import hashlib +import json +import os +from contextlib import contextmanager +from dataclasses import field, replace +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union + +import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.transformers_utils.runai_utils import is_runai_obj_uri +from vllm.utils import random_uuid + +from .cache import CacheConfig +from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode +from .device import DeviceConfig +from .kv_events import KVEventsConfig +from .kv_transfer import KVTransferConfig +from .load import LoadConfig +from .lora import LoRAConfig +from .model import ModelConfig +from .observability import ObservabilityConfig +from .parallel import ParallelConfig +from .scheduler import SchedulerConfig +from .speculative import SpeculativeConfig +from .structured_outputs import StructuredOutputsConfig +from .utils import SupportsHash, config + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +else: + PretrainedConfig = Any + + QuantizationConfig = Any + +logger = init_logger(__name__) + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This + simplifies passing around the distinct configurations in the codebase. 
+ """ + + # TODO: use default_factory once default constructing ModelConfig doesn't + # try to download a model + model_config: ModelConfig = None # type: ignore + """Model configuration.""" + cache_config: CacheConfig = field(default_factory=CacheConfig) + """Cache configuration.""" + parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + """Parallel configuration.""" + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + """Scheduler configuration.""" + device_config: DeviceConfig = field(default_factory=DeviceConfig) + """Device configuration.""" + load_config: LoadConfig = field(default_factory=LoadConfig) + """Load configuration.""" + lora_config: Optional[LoRAConfig] = None + """LoRA configuration.""" + speculative_config: Optional[SpeculativeConfig] = None + """Speculative decoding configuration.""" + structured_outputs_config: StructuredOutputsConfig = field( + default_factory=StructuredOutputsConfig) + """Structured outputs configuration.""" + observability_config: Optional[ObservabilityConfig] = None + """Observability configuration.""" + quant_config: Optional[QuantizationConfig] = None + """Quantization configuration.""" + compilation_config: CompilationConfig = field( + default_factory=CompilationConfig) + """`torch.compile` and cudagraph capture configuration for the model. + + As a shorthand, `-O` can be used to directly specify the compilation + level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). + Currently, -O and -O= are supported as well but this will likely be + removed in favor of clearer -O syntax in the future. + + NOTE: level 0 is the default level without any optimization. level 1 and 2 + are for internal testing only. level 3 is the recommended level for + production, also default in V1. + + You can specify the full compilation config like so: + `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + """ + kv_transfer_config: Optional[KVTransferConfig] = None + """The configurations for distributed KV cache transfer.""" + kv_events_config: Optional[KVEventsConfig] = None + """The configurations for event publishing.""" + # some opaque config, only used to provide additional information + # for the hash computation, mainly used for testing, debugging or out of + # tree config registration. + additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + """Additional config for specified platform. Different platforms may + support different configs. Make sure the configs are valid for the platform + you are using. Contents must be hashable.""" + instance_id: str = "" + """The ID of the vLLM instance.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + + # summarize vllm config + vllm_factors: list[Any] = [] + from vllm import __version__ + vllm_factors.append(__version__) + vllm_factors.append(envs.VLLM_USE_V1) + if self.model_config: + vllm_factors.append(self.model_config.compute_hash()) + else: + vllm_factors.append("None") + if self.cache_config: + vllm_factors.append(self.cache_config.compute_hash()) + else: + vllm_factors.append("None") + if self.parallel_config: + vllm_factors.append(self.parallel_config.compute_hash()) + else: + vllm_factors.append("None") + if self.scheduler_config: + vllm_factors.append(self.scheduler_config.compute_hash()) + else: + vllm_factors.append("None") + if self.device_config: + vllm_factors.append(self.device_config.compute_hash()) + else: + vllm_factors.append("None") + if self.load_config: + vllm_factors.append(self.load_config.compute_hash()) + else: + vllm_factors.append("None") + if self.lora_config: + vllm_factors.append(self.lora_config.compute_hash()) + # LoRA creates static buffers based on max_num_batched_tokens. + # The tensor sizes and strides get captured in the torch.compile + # graph explicitly. + vllm_factors.append( + str(self.scheduler_config.max_num_batched_tokens)) + else: + vllm_factors.append("None") + if self.speculative_config: + vllm_factors.append(self.speculative_config.compute_hash()) + else: + vllm_factors.append("None") + if self.structured_outputs_config: + vllm_factors.append(self.structured_outputs_config.compute_hash()) + else: + vllm_factors.append("None") + if self.observability_config: + vllm_factors.append(self.observability_config.compute_hash()) + else: + vllm_factors.append("None") + if self.quant_config: + pass # should be captured by model_config.quantization + if self.compilation_config: + vllm_factors.append(self.compilation_config.compute_hash()) + else: + vllm_factors.append("None") + if self.kv_transfer_config: + vllm_factors.append(self.kv_transfer_config.compute_hash()) + else: + vllm_factors.append("None") + if self.additional_config: + if isinstance(additional_config := self.additional_config, dict): + additional_config_hash = hashlib.md5( + json.dumps(additional_config, sort_keys=True).encode(), + usedforsecurity=False, + ).hexdigest() + else: + additional_config_hash = additional_config.compute_hash() + vllm_factors.append(additional_config_hash) + else: + vllm_factors.append("None") + factors.append(vllm_factors) + + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def pad_for_cudagraph(self, batch_size: int) -> int: + # if batch_size > self.compilation_config.max_capture_size, + # it should raise an IndexError. 
+ # the caller should make sure the batch_size is within the range, + # i.e., batch_size <= self.compilation_config.max_capture_size + return self.compilation_config.bs_to_padded_graph_size[batch_size] + + @staticmethod + def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + from vllm.platforms import current_platform + if model_config.quantization is not None: + from vllm.model_executor.model_loader.weight_utils import ( + get_quant_config) + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. Minimum " + f"capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + quant_config.maybe_update_config(model_config.model) + return quant_config + return None + + @staticmethod + def get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + import copy + + # For some reason, the _ version of this modifies the model_config + # object, so using deepcopy to avoid this problem. + return VllmConfig._get_quantization_config(copy.deepcopy(model_config), + load_config) + + def with_hf_config( + self, + hf_config: PretrainedConfig, + architectures: Optional[list[str]] = None, + ) -> "VllmConfig": + if architectures is not None: + hf_config = copy.deepcopy(hf_config) + hf_config.architectures = architectures + + model_config = copy.deepcopy(self.model_config) + model_config.hf_config = hf_config + + return replace(self, model_config=model_config) + + def __post_init__(self): + """Verify configs are valid & consistent with each other. + """ + + self.try_verify_and_update_config() + + if self.model_config is not None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_dual_chunk_attention_config( + self.load_config) + + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config is not None: + self.lora_config.verify_with_cache_config(self.cache_config) + self.lora_config.verify_with_model_config(self.model_config) + + if self.quant_config is None and self.model_config is not None: + self.quant_config = VllmConfig._get_quantization_config( + self.model_config, self.load_config) + + from vllm.platforms import current_platform + if self.model_config is not None and \ + self.scheduler_config.chunked_prefill_enabled and \ + self.model_config.dtype == torch.float32 and \ + current_platform.get_device_capability() == (7, 5): + logger.warning_once( + "Turing devices tensor cores do not support float32 matmul. " + "To workaround this limitation, vLLM will set 'ieee' input " + "precision for chunked prefill triton kernels.") + + # If the user does not explicitly set a compilation level, then + # we use the default level. The default level depends on other + # settings (see the below code). 
+ if self.compilation_config.level is None: + if envs.VLLM_USE_V1: + if (self.model_config is not None + and not self.model_config.enforce_eager): + self.compilation_config.level = CompilationLevel.PIECEWISE + else: + self.compilation_config.level = \ + CompilationLevel.NO_COMPILATION + + else: + # NB: Passing both --enforce-eager and a compilation level + # in V0 means the compilation level wins out. + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + # async tp is built on top of sequence parallelism + # and requires it to be enabled. + if self.compilation_config.pass_config.enable_async_tp: + self.compilation_config.pass_config.enable_sequence_parallelism = \ + True + if self.compilation_config.pass_config.enable_sequence_parallelism: + self.compilation_config.custom_ops.append("+rms_norm") + + if current_platform.support_static_graph_mode(): + # if cudagraph_mode is not explicitly set by users, set default + # value + if self.compilation_config.cudagraph_mode is None: + if envs.VLLM_USE_V1 and self.compilation_config.level \ + == CompilationLevel.PIECEWISE: + # default to full and piecewise for most models + self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_AND_PIECEWISE + + # pooling models and encoder-decoder models + # do not support full cudagraphs + if self.model_config is not None and \ + (self.model_config.pooler_config is not None + or self.model_config.is_encoder_decoder): + self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + # disable cudagraph when enforce eager execution + if self.model_config is not None and \ + self.model_config.enforce_eager: + logger.info("Cudagraph is disabled under eager mode") + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + elif envs.VLLM_USE_V1: + self.compilation_config.cudagraph_num_of_warmups = 1 + + self._set_cudagraph_sizes() + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + if self.cache_config.kv_sharing_fast_prefill: + + if self.speculative_config is not None and \ + self.speculative_config.use_eagle(): + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not " + "compatible with EAGLE as EAGLE requires correct logits " + "for all tokens while fast prefill gives incorrect logits " + "for prompt tokens.") + + logger.warning_once( + "--kv-sharing-fast-prefill requires changes on model side for " + "correctness and to realize prefill savings. 
") + + disable_chunked_prefill_reasons: list[str] = [] + + if self.model_config: + if self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + "Only \"last\" pooling supports chunked " + "prefill and prefix caching; disabling both.") + if not getattr(self.model_config.hf_config, "is_causal", True): + disable_chunked_prefill_reasons.append( + "Only models using causal attention supports chunked " + "prefill and prefix caching; disabling both.") + elif self.model_config.is_encoder_decoder: + from vllm.multimodal import MULTIMODAL_REGISTRY + self.scheduler_config.max_num_encoder_input_tokens = \ + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + logger.debug( + "Encoder-decoder model detected: setting " + "`max_num_encoder_input_tokens` to encoder length (%s)", + self.scheduler_config.max_num_encoder_input_tokens) + self.scheduler_config.disable_chunked_mm_input = True + disable_chunked_prefill_reasons.append( + "Encoder-decoder models do not support chunked prefill nor" + " prefix caching; disabling both.") + if (self.model_config.architecture + == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") + != "spawn"): + logger.warning( + "Whisper is known to have issues with " + "forked workers. If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'.") + + if disable_chunked_prefill_reasons: + for reason in disable_chunked_prefill_reasons: + logger.info(reason) + self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.long_prefill_token_threshold = 0 + + if self.cache_config is not None: + self.cache_config.enable_prefix_caching = False + + if (self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events + and not self.cache_config.enable_prefix_caching): + logger.warning( + "KV cache events are on, but prefix caching is not enabled." + "Use --enable-prefix-caching to enable.") + if (self.kv_events_config is not None + and self.kv_events_config.publisher != "null" + and not self.kv_events_config.enable_kv_cache_events): + logger.warning("KV cache events are disabled," + "but the scheduler is configured to publish them." + "Modify KVEventsConfig.enable_kv_cache_events" + "to True to enable.") + current_platform.check_and_update_config(self) + + # Do this after all the updates to compilation_config.level + if envs.VLLM_USE_V1 and \ + self.compilation_config.level == CompilationLevel.PIECEWISE: + self.compilation_config.set_splitting_ops_for_v1() + + # final check of cudagraph mode after all possible updates + if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): + if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\ + and self.model_config is not None and \ + not self.model_config.disable_cascade_attn and\ + not self.compilation_config.cudagraph_mode.\ + has_piecewise_cudagraphs(): + logger.warning_once( + "No piecewise cudagraph for executing cascade attention." 
+ " Will fall back to eager execution if a batch runs " + "into cascade attentions") + + if self.compilation_config.cudagraph_mode\ + .requires_piecewise_compilation(): + assert self.compilation_config.level == \ + CompilationLevel.PIECEWISE, \ + "Compilation level should be CompilationLevel.PIECEWISE "\ + "when cudagraph_mode piecewise cudagraphs is used, "\ + f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + + # final migrate the deprecated flags + self.compilation_config.use_cudagraph = self.compilation_config.\ + cudagraph_mode!= CUDAGraphMode.NONE + self.compilation_config.full_cuda_graph = self.compilation_config.\ + cudagraph_mode.has_full_cudagraphs() + + if self.parallel_config.enable_dbo: + a2a_backend = envs.VLLM_ALL2ALL_BACKEND + assert a2a_backend in \ + ["deepep_low_latency", "deepep_high_throughput"], \ + "Microbatching currently only supports the deepep_low_latency and "\ + f"deepep_high_throughput all2all backend. {a2a_backend} is not "\ + "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\ + "variable to deepep_low_latency or deepep_high_throughput and "\ + "install the DeepEP kernels." + + if not self.model_config.disable_cascade_attn: + self.model_config.disable_cascade_attn = True + logger.warning_once( + "Disabling cascade attention when DBO is enabled.") + + if not self.instance_id: + self.instance_id = random_uuid()[:5] + + if (envs.VLLM_USE_V1 + and not self.scheduler_config.disable_hybrid_kv_cache_manager): + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not current_platform.support_hybrid_kv_cache(): + # Hybrid KV cache manager is not supported on non-GPU platforms. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_transfer_config is not None: + # Hybrid KV cache manager is not compatible with KV transfer. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.model_config is not None and \ + self.model_config.attention_chunk_size is not None: + if self.speculative_config is not None and \ + self.speculative_config.use_eagle(): + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + elif \ + not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + logger.warning( + "There is a latency regression when using chunked local" + " attention with the hybrid KV cache manager. Disabling" + " it, by default. To enable it, set the environment " + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." + ) + # Hybrid KV cache manager is not yet supported with chunked + # local attention. 
+ self.scheduler_config.disable_hybrid_kv_cache_manager = True + + if self.compilation_config.debug_dump_path: + self.compilation_config.debug_dump_path = \ + self.compilation_config.debug_dump_path.absolute().expanduser() + if envs.VLLM_DEBUG_DUMP_PATH is not None: + env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser() + if self.compilation_config.debug_dump_path: + logger.warning( + "Config-specified debug dump path is overridden" + " by VLLM_DEBUG_DUMP_PATH to %s", env_path) + self.compilation_config.debug_dump_path = env_path + + def update_sizes_for_sequence_parallelism(self, + possible_sizes: list) -> list: + # remove the sizes that not multiple of tp_size when + # enable sequence parallelism + removed_sizes = [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + if removed_sizes: + logger.warning( + "Batch sizes %s are removed because they are not " + "multiple of tp_size %d when " + "sequence parallelism is enabled", removed_sizes, + self.parallel_config.tensor_parallel_size) + + return [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] + + def _set_cudagraph_sizes(self): + """ + vLLM defines the default candidate list of batch sizes for CUDA graph + capture as: + + ```python + max_graph_size = min(max_num_seqs * 2, 512) + # 1, 2, 4, then multiples of 8 up to max_graph_size + cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size] + + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in descending order). + + These sizes are used to capture and reuse CUDA graphs for + performance-critical paths (e.g., decoding). Capturing enables + significantly faster kernel dispatch by avoiding Python overhead. The + list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on + most GPUs), which controls the total allowed number of tokens in a + batch. Since each sequence may have a variable number of tokens, the + maximum usable batch size will depend on actual sequence lengths. + + Example: + With `max_num_batched_tokens = 8192`, and typical sequences + averaging ~32 tokens, most practical batch sizes fall below 256. + However, the system will still allow capture sizes up to 512 if + shape and memory permit. + + Note: + If users explicitly specify cudagraph capture sizes in the + compilation config, those will override this default logic. + At runtime: + + - If batch size <= one of the `cudagraph_capture_sizes`, the closest + padded CUDA graph will be used. + - If batch size > largest `cudagraph_capture_sizes`, cudagraph will + not be used. 
+ """ + + # calculate the default `batch_size_capture_list` + batch_size_capture_list = [] + if self.model_config is not None and \ + not self.model_config.enforce_eager: + cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes + if len(cuda_graph_sizes) == 1: + batch_size_capture_list = [1, 2, 4] + [ + i for i in range(8, cuda_graph_sizes[0] + 1, 8) + ] + elif len(cuda_graph_sizes) > 1: + batch_size_capture_list = sorted(cuda_graph_sizes) + else: + raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + max_num_tokens = self.scheduler_config.max_num_batched_tokens + batch_size_capture_list = [ + size for size in batch_size_capture_list + if size <= max_num_tokens + ] + + self.compilation_config.init_with_cudagraph_sizes( + batch_size_capture_list) + + def recalculate_max_model_len(self, max_model_len: int): + # Can only be called in try_verify_and_update_config + model_config = self.model_config + max_model_len = model_config.get_and_verify_max_len(max_model_len) + self.model_config.max_model_len = max_model_len + self.scheduler_config.max_model_len = max_model_len + + def try_verify_and_update_config(self): + if self.model_config is None: + return + + # Avoid running try_verify_and_update_config multiple times + if getattr(self.model_config, "config_updated", False): + return + self.model_config.config_updated = True + + architecture = self.model_config.architecture + if architecture is None: + return + + from vllm.model_executor.models.config import ( + MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) + cls = MODELS_CONFIG_MAP.get(architecture, None) + if cls is not None: + cls.verify_and_update_config(self) + + if self.model_config.is_hybrid: + HybridAttentionMambaModelConfig.verify_and_update_config(self) + + if self.model_config.convert_type == "classify": + # Maybe convert ForCausalLM into ForSequenceClassification model. + from vllm.model_executor.models.adapters import ( + SequenceClassificationConfig) + SequenceClassificationConfig.verify_and_update_config(self) + + if hasattr(self.model_config, "model_weights") and is_runai_obj_uri( + self.model_config.model_weights): + if self.load_config.load_format == "auto": + logger.info("Detected Run:ai model config. " + "Overriding `load_format` to 'runai_streamer'") + self.load_config.load_format = "runai_streamer" + elif self.load_config.load_format != "runai_streamer": + raise ValueError(f"To load a model from S3, 'load_format' " + f"must be 'runai_streamer', " + f"but got '{self.load_config.load_format}'. " + f"Model: {self.model_config.model}") + + def compile_debug_dump_path(self) -> Optional[Path]: + """Returns a rank-aware path for dumping + torch.compile debug information. 
+ """ + if self.compilation_config.debug_dump_path is None: + return None + tp_rank = self.parallel_config.rank + dp_rank = self.parallel_config.data_parallel_rank + data_parallel_size = self.parallel_config.data_parallel_size + append_path = f"rank_{tp_rank}" if data_parallel_size == 1 \ + else f"rank_{tp_rank}_dp_{dp_rank}" + path = self.compilation_config.debug_dump_path / append_path + return path + + def __str__(self): + return ( + f"model={self.model_config.model!r}, " + f"speculative_config={self.speculative_config!r}, " + f"tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " + f"tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}, " + f"download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa + f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f"device_config={self.device_config.device}, " + f"structured_outputs_config={self.structured_outputs_config!r}, " + f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"pooler_config={self.model_config.pooler_config!r}, " + f"compilation_config={self.compilation_config!r}") + + +_current_vllm_config: Optional[VllmConfig] = None +_current_prefix: Optional[str] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig, + check_compile=False, + prefix: Optional[str] = None): + """ + Temporarily set the current vLLM config. + Used during model initialization. + We save the current vLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the vLLM config to determine how to dispatch. + """ + global _current_vllm_config, _current_prefix + old_vllm_config = _current_vllm_config + old_prefix = _current_prefix + from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + _current_prefix = prefix + yield + except Exception: + raise + else: + if check_compile: + vllm_config.compilation_config.custom_op_log_check() + + if check_compile and \ + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. 
Please open an issue on GitHub" + " if you want it to be supported.", + vllm_config.model_config.model) + finally: + _current_vllm_config = old_vllm_config + _current_prefix = old_prefix + # Clear the compilation config cache when context changes + get_cached_compilation_config.cache_clear() + + +@lru_cache(maxsize=1) +def get_cached_compilation_config(): + """Cache config to avoid repeated calls to get_current_vllm_config()""" + return get_current_vllm_config().compilation_config + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current vLLM config is not set.") + return VllmConfig() + return _current_vllm_config + + +T = TypeVar("T") + + +def get_layers_from_vllm_config( + vllm_config: VllmConfig, + layer_type: type[T], + layer_names: Optional[list[str]] = None) -> dict[str, T]: + """ + Get layers from the vLLM config. + + Args: + vllm_config: The vLLM config. + layer_type: The type of the layer to get. + layer_names: The names of the layers to get. If None, return all layers. + """ + + if layer_names is None: + layer_names = list( + vllm_config.compilation_config.static_forward_context.keys()) + + forward_context = vllm_config.compilation_config.static_forward_context + + return { + layer_name: forward_context[layer_name] + for layer_name in layer_names + if isinstance(forward_context[layer_name], layer_type) + } diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 410cbef4f6bc..319133777992 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -29,8 +29,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index bf5141fa4894..eb7600af3371 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -9,9 +9,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.platforms import current_platform from vllm.scalar_type import scalar_types diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index d05c0c0d5473..81e51f4a4358 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -7,9 +7,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization 
import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS, BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 29584188630f..7b7011cb06d3 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -13,9 +13,8 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, set_weight_attrs) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 2922aef32939..4a189ab4a171 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -9,9 +9,8 @@ from packaging import version from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.model_executor.utils import set_weight_attrs diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 0335b9c46b4d..842ce92333c9 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -4,7 +4,7 @@ import enum from enum import Enum from fractions import Fraction -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -13,7 +13,6 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.linear import LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.gptq_utils import ( @@ -26,6 +25,11 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata from vllm.utils import is_list_of +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods +else: + QuantizationMethods = str + class GPTQConfig(QuantizationConfig): """Config class for GPTQ. 
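The `gptq.py` hunk above avoids a runtime import cycle by importing `QuantizationMethods` only under `TYPE_CHECKING` and aliasing it to `str` otherwise. A minimal sketch of that pattern, assuming a hypothetical `mypkg.registry` module (the names below are illustrative, not vLLM APIs):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, so it cannot introduce a
    # circular import at runtime.
    from mypkg.registry import BackendName
else:
    # At runtime the name only needs to exist for annotations; since the
    # real type is a Literal of strings, plain `str` is a safe stand-in.
    BackendName = str


def select_backend(name: "BackendName") -> str:
    # The quoted annotation stays valid regardless of which branch ran.
    return name.lower()


print(select_backend("GPTQ"))  # -> "gptq"
```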
diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 646229258648..c193dd85e32f 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -9,9 +9,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( BitBLASLinearKernel, MPLinearLayerConfig) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 967e46c24378..253675e25f34 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -43,7 +43,7 @@ def get_moe_quant_method( - config: QuantizationConfig, + config: "GPTQMarlinConfig", layer: torch.nn.Module, prefix: str, moe_method_cls: type, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index eba917d85411..6b9e3effc29d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -9,9 +9,8 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index c83b0b47a4b7..353942cdd591 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -14,11 +14,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, is_layer_skipped_awq) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, Fp8LinearMethod) from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py index 0eca3b4c024e..fe72910659e2 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -7,8 +7,7 
@@ from packaging import version from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index 38de4b54fb19..7f738d170db4 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -8,9 +8,8 @@ from torch.nn.parameter import Parameter from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import (QuantizationConfig, + QuantizationMethods) from vllm.model_executor.parameter import ModelWeightParameter ACTIVATION_SCHEMES = ["none", "dynamic"] diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index 41b833725b30..fd76af230620 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -4,21 +4,27 @@ from copy import deepcopy from fractions import Fraction from types import MappingProxyType -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import regex as re import torch -from vllm.config import QuantizationConfig from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, UnquantizedEmbeddingMethod) +if TYPE_CHECKING: + from ..gptq import GPTQConfig + from ..gptq_marlin import GPTQMarlinConfig +else: + GPTQConfig = object + GPTQMarlinConfig = object + # Match dynamic rules with module name (prefix) and override quantize # config if module (prefix) matches a rule -def override_config(config: QuantizationConfig, prefix: str): +def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str): weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) if isinstance(weight_bits, int): @@ -34,6 +40,7 @@ def override_config(config: QuantizationConfig, prefix: str): config.pack_factor = Fraction(32, config.weight_bits) # packed into int32 if config.get_name() == "gptq_marlin": + assert isinstance(config, GPTQMarlinConfig) is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) if isinstance(is_sym, bool): config.is_sym = is_sym @@ -45,6 +52,7 @@ def override_config(config: QuantizationConfig, prefix: str): config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] elif config.get_name() == "gptq": + assert isinstance(config, GPTQConfig) if config.weight_bits not in [2, 3, 4, 8]: raise ValueError( "Currently, only 2/3/4/8-bit weight quantization is " @@ -52,7 +60,7 @@ def override_config(config: QuantizationConfig, prefix: str): def get_dynamic_override( - config: QuantizationConfig, + config: Union[GPTQConfig, GPTQMarlinConfig], layer_name: str, key: Optional[str] = None, default_value: Union[int, bool, @@ -116,7 +124,7 @@ def is_layer_gptq_quantized( def get_linear_quant_method( - config: 
QuantizationConfig, + config: Union[GPTQConfig, GPTQMarlinConfig], layer: torch.nn.Module, prefix: str, linear_method_cls: type, diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index b13d863ebb74..419f8a5ae2c7 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -17,8 +17,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.ovis import AIMv2Config diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index e0d7af0b1c3e..82f35d889605 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -9,13 +9,14 @@ from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.processing_aria import AriaProcessor -from vllm.config import QuantizationConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 82cd4a26a1ba..6e470378cb60 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -45,8 +45,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 2c619396e6c0..893cc8a41455 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -41,8 +41,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 76a5745a4f51..489c0bb3d3af 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -42,8 +42,7 @@ ReplicatedLinear, RowParallelLinear) from 
vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index b434822bff0a..c864856db654 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -21,8 +21,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.sequence import IntermediateTensors diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index d28c97116790..085e740ce226 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -47,8 +47,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d40df9b43dd4..c95c63cd8534 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -16,8 +16,7 @@ from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils import supports_kw from .interfaces_base import VllmModel, is_pooling_model diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 0768edd08315..572eca344e0a 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -28,8 +28,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.torchao import TorchAOConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 5bd268291c7d..d810701c50b4 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -16,8 +16,7 @@ from 
vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 97e9c5785e72..f8a5a8f6081b 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -16,8 +16,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index a92890c9f7b5..45228aa0bb93 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -33,8 +33,7 @@ MiniMaxText01LinearAttention) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 2f9c6ddfc661..2e8e4a44102f 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -29,8 +29,7 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.aimv2 import AIMv2Model from vllm.model_executor.models.siglip import SiglipVisionModel from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 86ce7e9eab27..9c8adb617310 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -11,8 +11,7 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.ovis import (OvisImagePatchInputs, VisualEmbedding) from vllm.model_executor.models.siglip2navit import Siglip2NavitModel diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 
3ce67ce37a7a..7308fef092b5 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -40,8 +40,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 7d90d3a7ef12..18de4b576c49 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -14,13 +14,13 @@ from transformers.configuration_utils import PretrainedConfig from vllm.attention.layer import check_upstream_fa_availability -from vllm.config import QuantizationConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.platforms import _Backend diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 0fe723d59483..960813822139 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -23,8 +23,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 7beeeddf988f..1eecac7ed76b 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -26,8 +26,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader
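The remaining hunks are mechanical: call sites now import `QuantizationConfig` from the `vllm.model_executor.layers.quantization` package rather than from its `base_config` submodule. A minimal usage sketch under that assumption; `ExampleLayer` is illustrative and not part of vLLM:

```python
from typing import Optional

import torch

# Import path used throughout the hunks above; call sites no longer
# reach into base_config directly.
from vllm.model_executor.layers.quantization import QuantizationConfig


class ExampleLayer(torch.nn.Module):
    """Illustrative layer showing only the new import path."""

    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Model layers typically just store the config; the quant method
        # derived from it decides how weights are created and applied.
        self.quant_config = quant_config
```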