From 830f67cf7ad0ce83e2fcb86dcee1af3cabffad2f Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Sun, 21 Sep 2025 10:02:06 -0700
Subject: [PATCH 1/3] Use macro guard CUDA functions for back compatibility in
 grouped_topk_kernel.cu (#25346)

Summary:
cuda::std::isfinite is not available with earlier CUDA versions. We guard it
with macros and extract a device function for is_finite.

Test Plan: build with 12.4 and 12.8

Reviewed By: houseroad

Differential Revision: D82918389

Privacy Context Container: L1370295

Signed-off-by: Ming Yang
---
 csrc/moe/grouped_topk_kernels.cu | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu
index b5321f748e6b..7f2918a0a30e 100644
--- a/csrc/moe/grouped_topk_kernels.cu
+++ b/csrc/moe/grouped_topk_kernels.cu
@@ -418,6 +418,15 @@ __device__ inline T neg_inf() {
   return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
 }
 
+template <typename T>
+__device__ inline bool is_finite(const T val) {
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
+  return cuda::std::isfinite(val);
+#else
+  return isfinite(cuda_cast<float, T>(val));
+#endif
+}
+
 template <typename T>
 __device__ void topk_with_k2(T* output, T const* input,
                              cg::thread_block_tile<32> const& tile,
@@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel(
     // calculate group_idx
     int32_t target_num_min = WARP_SIZE - n_group + topk_group;
     // The check is necessary to avoid abnormal input
-    if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
+    if (lane_id < n_group && is_finite(group_scores[lane_id])) {
       value = group_scores[lane_id];
     }
 
@@ -569,8 +578,7 @@ __global__ void group_idx_and_topk_idx_kernel(
     for (int32_t i = lane_id; i < align_num_experts_per_group;
          i += WARP_SIZE) {
       T candidates =
-          (i < num_experts_per_group) &&
-          cuda::std::isfinite(scores_with_bias[offset + i])
+          (i < num_experts_per_group) && is_finite(scores_with_bias[offset + i])
               ? scores_with_bias[offset + i]
              : neg_inf<T>();
       queue.add(candidates, offset + i);
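
Note on the version guard in PATCH 1/3: nvcc encodes its version as
__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100, so the 120800
threshold corresponds to CUDA 12.8 (12 * 10000 + 8 * 100), while the CUDA 12.4
build from the test plan evaluates to 120400 and takes the fallback branch. A
minimal standalone sketch of the same pattern, hypothetical and not part of
the patch, using a plain static_cast instead of the kernel's cuda_cast helper:

    #include <cuda/std/cmath>

    // Sketch only: dispatch on toolkit version, falling back to the fp32
    // overload of ::isfinite when cuda::std::isfinite is unavailable.
    template <typename T>
    __device__ bool is_finite_sketch(const T val) {
    #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
      return cuda::std::isfinite(val);           // CUDA 12.8+: 120800
    #else
      return isfinite(static_cast<float>(val));  // e.g. CUDA 12.4: 120400
    #endif
    }
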
From 626782fab84ceb2ed1d5ece4316ae31cbbb41016 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Sun, 21 Sep 2025 22:34:45 +0530
Subject: [PATCH 2/3] feat: Enable engine-level arguments with speculators
 models (#25250)

Signed-off-by: Rahul Tuli
Co-authored-by: Claude
---
 .../speculators/test_eagle3.py    | 54 ++++++++++++------
 vllm/config/model.py              | 12 +----
 vllm/engine/arg_utils.py          | 35 +++++------
 vllm/transformers_utils/config.py | 46 +++++++++++++---
 .../configs/speculators/base.py   | 52 ++++++++++++------
 5 files changed, 121 insertions(+), 78 deletions(-)

diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py
index 45ddb2178722..368238b3a720 100644
--- a/tests/speculative_decoding/speculators/test_eagle3.py
+++ b/tests/speculative_decoding/speculators/test_eagle3.py
@@ -3,38 +3,52 @@
 import pytest
 import torch
 
+from vllm.config import SpeculativeConfig
 from vllm.model_executor.models.interfaces import supports_eagle3
 
 
-@pytest.mark.parametrize(
-    "model_path",
-    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
-def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
+@pytest.mark.parametrize("model_path", [
+    pytest.param(
+        "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
+        id="llama3-eagle3-speculator"),
+    pytest.param(
+        "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
+        id="qwen3-eagle3-speculator"),
+])
+def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
+                                  monkeypatch):
+    """
+    Test Eagle3 speculators models properly initialize speculative decoding.
+
+    This test verifies:
+    1. Eagle3 support is detected for the model
+    2. Speculative config is automatically initialized from embedded config
+    3. The draft model path is correctly set to the speculators model
+    4. Speculative tokens count is valid
+    5. Text generation works with speculative decoding enabled
+    """
     # Set environment variable for V1 engine serialization
     monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
     with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+        # Verify Eagle3 support is detected
         eagle3_supported = vllm_model.apply_model(supports_eagle3)
-        assert eagle3_supported
+        assert eagle3_supported, f"Eagle3 should be supported for {model_path}"
 
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens=20)
-        print(vllm_outputs)
-        assert vllm_outputs
+        vllm_config = vllm_model.llm.llm_engine.vllm_config
+        assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \
+            "Speculative config should be initialized for speculators model"
 
-
-@pytest.mark.parametrize(
-    "model_path",
-    [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
-def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
-    # Set environment variable for V1 engine serialization
-    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+        spec_config = vllm_config.speculative_config
+        assert spec_config.num_speculative_tokens > 0, \
+            (f"Expected positive speculative tokens, "
+             f"got {spec_config.num_speculative_tokens}")
 
-    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
-        eagle3_supported = vllm_model.apply_model(supports_eagle3)
-        assert eagle3_supported
+        assert spec_config.model == model_path, \
+            f"Draft model should be {model_path}, got {spec_config.model}"
 
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens=20)
-        print(vllm_outputs)
-        assert vllm_outputs
+        assert vllm_outputs, \
+            f"No outputs generated for speculators model {model_path}"
diff --git a/vllm/config/model.py b/vllm/config/model.py
index b53029dc8c3e..95fe52883db0 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -27,8 +27,7 @@
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    is_interleaved, maybe_override_with_speculators_target_model,
-    try_get_generation_config, try_get_safetensors_metadata,
+    is_interleaved, try_get_generation_config, try_get_safetensors_metadata,
     try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
                                                  is_runai_obj_uri)
@@ -416,15 +415,6 @@ def __post_init__(
         self.maybe_pull_model_tokenizer_for_runai(self.model,
                                                   self.tokenizer)
 
-        if self.runner != "draft":
-            # If we're not running the draft model, check for speculators config
-            # If speculators config, set model / tokenizer to be target model
-            self.model, self.tokenizer = maybe_override_with_speculators_target_model(  # noqa: E501
-                model=self.model,
-                tokenizer=self.tokenizer,
-                revision=self.revision,
-                trust_remote_code=self.trust_remote_code)
-
         if (backend := envs.VLLM_ATTENTION_BACKEND
             ) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
             raise ValueError(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index fef4177b3a33..7e00260caa39 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -41,7 +41,8 @@
 from vllm.ray.lazy_utils import is_ray_initialized
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
-from vllm.transformers_utils.config import get_model_path, is_interleaved
+from vllm.transformers_utils.config import (get_model_path, is_interleaved,
+                                            maybe_override_with_speculators)
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
@@ -1082,29 +1083,8 @@ def create_speculative_config(
         provided as a JSON string input via CLI arguments or directly as a
         dictionary from the engine.
         """
-
-        from vllm.transformers_utils.config import get_config
-        from vllm.transformers_utils.configs.speculators.base import (
-            SpeculatorsConfig)
-
         if self.speculative_config is None:
-            hf_config = get_config(
-                self.hf_config_path or target_model_config.model,
-                self.trust_remote_code, self.revision, self.code_revision,
-                self.config_format)
-
-            # if loading a SpeculatorsConfig, load the speculative_config
-            # details from the config directly
-            # no user input required / expected
-            if isinstance(hf_config, SpeculatorsConfig):
-                # We create one since we don't create one
-                self.speculative_config = {}
-                self.speculative_config[
-                    "num_speculative_tokens"] = hf_config.num_lookahead_tokens
-                self.speculative_config["model"] = target_model_config.model
-                self.speculative_config["method"] = hf_config.method
-            else:
-                return None
+            return None
 
         # Note(Shangming): These parameters are not obtained from the cli arg
         # '--speculative-config' and must be passed in when creating the engine
@@ -1139,6 +1119,15 @@ def create_engine_config(
         device_config = DeviceConfig(
             device=cast(Device, current_platform.device_type))
+
+        (self.model, self.tokenizer,
+         self.speculative_config) = maybe_override_with_speculators(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            revision=self.revision,
+            trust_remote_code=self.trust_remote_code,
+            vllm_speculative_config=self.speculative_config,
+        )
         model_config = self.create_model_config()
 
         # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 52e2c18a7784..9eed46678866 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     return config
 
 
-def maybe_override_with_speculators_target_model(
+def maybe_override_with_speculators(
     model: str,
     tokenizer: str,
     trust_remote_code: bool,
    revision: Optional[str] = None,
+    vllm_speculative_config: Optional[dict[str, Any]] = None,
     **kwargs,
-) -> tuple[str, str]:
+) -> tuple[str, str, Optional[dict[str, Any]]]:
     """
-    If running a speculators config, override running model with target model
+    Resolve model configuration when speculators are detected.
+
+    Checks if the provided model is a speculators model and if so, extracts
+    the target model configuration and builds the speculative config.
+
+    Args:
+        model: Model name or path
+        tokenizer: Tokenizer name or path
+        trust_remote_code: Whether to trust remote code
+        revision: Model revision
+        vllm_speculative_config: Existing vLLM speculative config
+
+    Returns:
+        Tuple of (resolved_model, resolved_tokenizer, speculative_config)
     """
     is_gguf = check_gguf_file(model)
     if is_gguf:
@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
         token=_get_hf_token(),
         **kwargs,
     )
-    spec_config = config_dict.get("speculators_config", None)
-    # Return the target model
-    if spec_config is not None:
-        model = tokenizer = spec_config["verifier"]["name_or_path"]
-    return model, tokenizer
+    speculators_config = config_dict.get("speculators_config")
+
+    if speculators_config is None:
+        # No speculators config found, return original values
+        return model, tokenizer, vllm_speculative_config
+
+    # Speculators format detected - process overrides
+    from vllm.transformers_utils.configs.speculators.base import (
+        SpeculatorsConfig)
+
+    vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
+        config_dict=config_dict)
+
+    # Set the draft model to the speculators model
+    vllm_speculative_config["model"] = model
+
+    # Override model and tokenizer with the verifier model from config
+    verifier_model = speculators_config["verifier"]["name_or_path"]
+    model = tokenizer = verifier_model
+
+    return model, tokenizer, vllm_speculative_config
 
 
 def get_config(
diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py
index d7c16e180c70..53128b4eecb0 100644
--- a/vllm/transformers_utils/configs/speculators/base.py
+++ b/vllm/transformers_utils/configs/speculators/base.py
@@ -24,6 +24,12 @@ def from_pretrained(
         config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
                                              **kwargs)
+        vllm_config = cls.extract_vllm_speculative_config(config_dict)
+        return cls(**vllm_config)
+
+    @classmethod
+    def extract_vllm_speculative_config(
+            cls, config_dict: dict[str, Any]) -> dict[str, Any]:
         speculators_model_type = config_dict.get("speculators_model_type")
         if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
             raise ValueError(
@@ -34,11 +40,12 @@ def from_pretrained(
         # TODO: @dsikka - use speculators pydantic model to validate
         cls.validate_speculators_config(config_dict=config_dict)
         # Convert from speculators config -> format that can be ingested by vLLM
-        vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
+        vllm_config = cls.build_vllm_speculative_config(
+            config_dict=config_dict)
         # Apply anything specific to the supported algorithm
         algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
         algo_updater(config_dict=config_dict, vllm_config=vllm_config)
-        return cls(**vllm_config)
+        return vllm_config
 
     @classmethod
     def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
@@ -60,32 +67,45 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
                 "'transformer_layer_config' must be a dictionary if provided")
 
     @classmethod
-    def convert_speculators_to_vllm(
+    def build_vllm_speculative_config(
             cls, config_dict: dict[str, Any]) -> dict[str, Any]:
         """
-        Convert speculators config format to vLLM format.
-
-        This method handles the translation of field names and structure
-        between speculators and vLLM formats.
-
+        Build vLLM-compatible speculative configuration from speculators format.
+
+        This method extracts and transforms speculative configuration from the
+        speculators format into the structure expected by vLLM.
+
+        Args:
+            config_dict: Configuration dictionary in speculators format
+
         Returns:
-            Dictionary with vLLM-compatible configuration
+            Dictionary with vLLM-compatible speculative configuration
         """
-        # Currently we only support one proposal method
+        # Extract speculators configuration
         spec_config = config_dict["speculators_config"]
-        first_method = spec_config.get("proposal_methods")[0]
-        num_lookahead_tokens = first_method.get("speculative_tokens")
 
-        if num_lookahead_tokens is None:
+        # Currently we only support one proposal method
+        proposal_methods = spec_config.get("proposal_methods")
+        if not proposal_methods:
+            raise ValueError("No proposal methods found in speculators config")
+
+        first_method = proposal_methods[0]
+        num_speculative_tokens = first_method.get("speculative_tokens")
+
+        if num_speculative_tokens is None:
             raise ValueError(
                 "Missing 'speculative_tokens' in proposal method. "
                 f"Got: {first_method}")
 
-        # Build base vLLM config
+        # Build base vLLM speculative configuration
         vllm_config = {
             "method": config_dict.get("speculators_model_type"),
-            "num_lookahead_tokens": num_lookahead_tokens,
+            "num_speculative_tokens": num_speculative_tokens,
             "target_model": spec_config.get("verifier")["name_or_path"]
         }
-        vllm_config.update(config_dict["transformer_layer_config"])
+
+        # Merge transformer layer configuration if present
+        transformer_config = config_dict.get("transformer_layer_config", {})
+        vllm_config.update(transformer_config)
+
         return vllm_config
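
Note on the flow introduced in PATCH 2/3: maybe_override_with_speculators() now
returns the verifier (target) model, its tokenizer, and a vLLM speculative
config built from the checkpoint's embedded speculators_config, and
create_engine_config() applies those overrides before create_model_config().
A rough sketch of the data shapes involved, with made-up values; the field
names follow the diff above, and any algorithm-specific tweaks applied by the
registered algo_updater are ignored:

    # Hypothetical config.json of a speculators-format checkpoint.
    config_dict = {
        "speculators_model_type": "eagle3",
        "speculators_config": {
            "proposal_methods": [{"speculative_tokens": 5}],
            "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
        },
        "transformer_layer_config": {"hidden_size": 4096},
    }

    # Roughly what SpeculatorsConfig.build_vllm_speculative_config() yields;
    # maybe_override_with_speculators() then sets "model" to the speculators
    # checkpoint path and swaps model/tokenizer to the verifier name_or_path.
    expected = {
        "method": "eagle3",
        "num_speculative_tokens": 5,
        "target_model": "meta-llama/Llama-3.1-8B-Instruct",
        "hidden_size": 4096,
    }
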
From ccabe13f006b3273ae59a2a520c1e003a848ebd5 Mon Sep 17 00:00:00 2001
From: Ming Yang
Date: Sun, 21 Sep 2025 10:24:29 -0700
Subject: [PATCH 3/3] fix clang-format

Signed-off-by: Ming Yang
---
 csrc/moe/grouped_topk_kernels.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu
index 7f2918a0a30e..c93f9d54d780 100644
--- a/csrc/moe/grouped_topk_kernels.cu
+++ b/csrc/moe/grouped_topk_kernels.cu
@@ -577,10 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel(
     int32_t offset = i_group * num_experts_per_group;
     for (int32_t i = lane_id; i < align_num_experts_per_group;
          i += WARP_SIZE) {
-      T candidates =
-          (i < num_experts_per_group) && is_finite(scores_with_bias[offset + i])
-              ? scores_with_bias[offset + i]
-              : neg_inf<T>();
+      T candidates = (i < num_experts_per_group) &&
+                             is_finite(scores_with_bias[offset + i])
+                         ? scores_with_bias[offset + i]
+                         : neg_inf<T>();
       queue.add(candidates, offset + i);
     }
     if (group_scores[i_group] == topk_group_value) {
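
With all three patches applied, a speculators-format checkpoint can be passed
to vLLM like any other model while engine-level arguments keep working; the
speculative config is derived from the embedded speculators_config and the
verifier becomes the running model. A minimal usage sketch, assuming the
offline LLM entrypoint and one of the checkpoints exercised by the new test
(not a transcript from the PR):

    from vllm import LLM

    llm = LLM(
        model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
        dtype="bfloat16",  # ordinary engine-level argument; no manual
                           # --speculative-config needed for this checkpoint
    )
    outputs = llm.generate(["The capital of France is"])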