diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index ce078bce0b75..af65b6d38e02 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -54,6 +54,7 @@ def parse_args():
         "--method",
         type=str,
         default="eagle",
+        choices=["ngram", "eagle", "eagle3", "mtp"],
     )
     parser.add_argument("--num-spec-tokens", type=int, default=2)
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
             "prompt_lookup_max": args.prompt_lookup_max,
             "prompt_lookup_min": args.prompt_lookup_min,
         }
-    elif args.method.endswith("mtp"):
+    elif args.method == "mtp":
         speculative_config = {
-            "method": args.method,
+            "method": "mtp",
             "num_speculative_tokens": args.num_spec_tokens,
         }
     else:
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 66115f14c182..c4efd7548b81 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -15,6 +15,8 @@
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.platforms import current_platform
 
+MTP_SIMILARITY_RATE = 0.8
+
 
 def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
@@ -222,3 +224,66 @@ def test_eagle_correctness(
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
+    (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
+],
+                         ids=["mimo", "deepseek"])
+def test_mtp_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, int],
+    mm_enabled: bool,
+):
+    '''
+    The outputs of the original LLM and the speculative LLM should be the
+    same when using MTP speculative decoding.
+    model_setup: (method, model_name, tp_size)
+    '''
+    # Generate test prompts inside the function instead of using a fixture
+    test_prompts = get_test_prompts(mm_enabled)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_MLA_DISABLE", "1")
+
+        method, model_name, tp_size = model_setup
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size,
+                      trust_remote_code=True)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "num_speculative_tokens": 1,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 80% of the prompts to match exactly.
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py
new file mode 100644
index 000000000000..e4881859ece1
--- /dev/null
+++ b/tests/v1/spec_decode/test_mtp.py
@@ -0,0 +1,195 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest import mock
+
+import pytest
+import torch
+
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      get_attention_backend)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VllmConfig)
+from vllm.config.load import LoadConfig
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
+from vllm.v1.spec_decode.eagle import EagleProposer
+
+mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
+
+
+def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
+    """Create an MTP proposer with unified model configuration."""
+    model_config = ModelConfig(model=mimo_7b_dir,
+                               runner="generate",
+                               max_model_len=100,
+                               trust_remote_code=True)
+
+    speculative_config = SpeculativeConfig(
+        target_model_config=model_config,
+        target_parallel_config=ParallelConfig(),
+        model=mimo_7b_dir,
+        method="mtp",
+        num_speculative_tokens=num_speculative_tokens,
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
+
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
+
+
+@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
+@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
+                                mock_get_pp_group):
+    """Test MTP-specific model loading with unified model approach."""
+
+    # Setup mocks
+    mock_model = mock.MagicMock()
+    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+    mock_get_model.return_value = mock_model
+
+    target_attn_layers = {"target_attn_1": mock.MagicMock()}
+    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
+    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
+
+    mock_pp_group = mock.MagicMock()
+    mock_pp_group.world_size = 1
+    mock_get_pp_group.return_value = mock_pp_group
+
+    # Create target model
+    class _TargetModelStub(LlamaForCausalLM):
+        model: mock.MagicMock
+        lm_head: mock.MagicMock
+
+    target_model = mock.create_autospec(_TargetModelStub, instance=True)
+    target_model.model = mock.MagicMock()
+    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+
+    # Create MTP proposer
+    proposer = _create_mtp_proposer(num_speculative_tokens=4)
+    proposer.load_model(target_model)
+
+    # Verify MTP-specific behavior:
+    # Model is loaded
+    mock_get_model.assert_called_once()
+    # MTP shares lm_head with target model
+    assert proposer.model.lm_head == target_model.lm_head
+    # MTP shares embed_tokens with target model
+    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+
+
+@pytest.mark.parametrize("num_speculative_tokens",
+                         [1])
+def test_mtp_propose(num_speculative_tokens, monkeypatch):
+    """Test that MTP's forward method returns hidden states directly."""
+
+    device = torch.device(current_platform.device_type)
+    batch_size = 2
+    seq_lens = [5, 3]
+    total_tokens = sum(seq_lens)
+    vocab_size = 100
+
+    proposer = _create_mtp_proposer(num_speculative_tokens)
+    hidden_size = proposer.hidden_size
+
+    # Mock the MTP model to verify it returns hidden states directly
+    model_mock = mock.MagicMock()
+
+    # MTP returns hidden states directly
+    if num_speculative_tokens == 1:
+        model_mock.return_value = torch.zeros(total_tokens,
+                                              hidden_size,
+                                              device=device)
+    else:
+        # Multiple forward passes for multi-token speculation
+        forward_returns = []
+        for i in range(num_speculative_tokens):
+            if i == 0:
+                h_states = torch.zeros(total_tokens,
+                                       hidden_size,
+                                       device=device)
+            else:
+                h_states = torch.zeros(batch_size, hidden_size, device=device)
+            forward_returns.append(h_states)
+        model_mock.side_effect = forward_returns
+
+    # Mock compute_logits
+    def create_deterministic_logits(batch_size, vocab_size, token_offset):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        logits[:, token_offset] = 100.0
+        return logits
+
+    if num_speculative_tokens == 1:
+        model_mock.compute_logits.return_value = create_deterministic_logits(
+            batch_size, vocab_size, 42)
+    else:
+        logits_returns = [
+            create_deterministic_logits(batch_size, vocab_size, 42 + i)
+            for i in range(num_speculative_tokens)
+        ]
+        model_mock.compute_logits.side_effect = logits_returns
+
+    proposer.model = model_mock
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Prepare inputs
+    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
+    common_attn_metadata = create_common_attn_metadata(batch_spec,
+                                                       block_size=16,
+                                                       device=device)
+
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_lens[0], device=device),
+        torch.arange(seq_lens[1], device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    sampling_metadata = mock.MagicMock()
+
+    # Setup attention metadata
+    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    proposer.runner = mock.MagicMock()
+    proposer.attn_metadata_builder = attn_metadata_builder
+
+    # Run propose
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              last_token_indices=None,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+
+    # Verify the model was called correctly
+    assert model_mock.called
+    # Verify output shape
+    assert result.shape == (batch_size, num_speculative_tokens)
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 5f462442148f..8b80ce13f96e 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -32,7 +32,9 @@
 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
                             "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
-                            "longcat_flash_mtp"]
+                            "longcat_flash_mtp", "mtp"]
+MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
+                   "qwen3_next_mtp", "longcat_flash_mtp")
 
 
 @config
@@ -207,11 +209,16 @@ def __post_init__(self):
         # can not be detected, it will be considered as the "draft_model" by
         # default.
 
+        if self.method in MTP_MODEL_TYPES:
+            logger.warning("method `%s` is deprecated and replaced with mtp.",
+                           self.method)
+            self.method = "mtp"
+
         if self.model is None and self.num_speculative_tokens is not None:
-            # TODO(Shangming): Refactor mtp configuration logic when supporting
-            if (self.target_model_config
-                    and self.target_model_config.hf_text_config.model_type
-                    in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
+            if self.method == "mtp":
+                assert (
+                    self.target_model_config
+                    is not None), "target_model_config must be present for mtp"
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
                 # Align the quantization of draft model for cases such as
@@ -312,31 +319,13 @@ def __post_init__(self):
                       "mlp_speculator"):
                     self.method = "mlp_speculator"
                 elif (self.draft_model_config.hf_config.model_type
-                      in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
-                    self.method = "deepseek_mtp"
-                    if self.num_speculative_tokens > 1:
-                        logger.warning(
-                                "All Deepseek MTP models only have " \
-                                "one layer. Might need some code changes " \
-                                "to support multiple layers."
-                            )
-                elif (self.draft_model_config.hf_config.model_type ==
-                      "ernie_mtp"):
-                    self.method = "ernie_mtp"
+                      in MTP_MODEL_TYPES):
+                    self.method = "mtp"
                     if self.num_speculative_tokens > 1:
                         logger.warning(
-                                "All Ernie MTP models only have " \
-                                "one layer. Might need some code changes " \
-                                "to support multiple layers."
-                            )
-                elif (self.draft_model_config.hf_config.model_type ==
-                      "qwen3_next_mtp"):
-                    self.method = "qwen3_next_mtp"
-                    if self.num_speculative_tokens > 1:
-                        logger.warning(
-                                "All Qwen3Next MTP models only have " \
-                                "one layer. Might need some code changes " \
-                                "to support multiple layers."
+                                "Enabling num_speculative_tokens > 1 will run "
+                                "multiple forward passes on the same MTP "
+                                "layer, which may result in a lower acceptance rate."
                             )
                 elif (self.draft_model_config.hf_config.model_type
                       in ("longcat_flash_mtp")):
@@ -353,7 +342,7 @@ def __post_init__(self):
                     "Speculative decoding with draft model is not "
                     "supported yet. Please consider using other "
                     "speculative decoding methods such as ngram, medusa, "
-                    "eagle, or deepseek_mtp.")
+                    "eagle, or mtp.")
 
         # Replace hf_config for EAGLE draft_model
         if self.method in ("eagle", "eagle3"):
@@ -562,8 +551,7 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
-                               "qwen3_next_mtp", "longcat_flash_mtp")
+        return self.method in ("eagle", "eagle3", "mtp")
 
     def __repr__(self) -> str:
         method = self.method
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 7b5ed67d0adb..8757f4b8b7ba 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1481,7 +1481,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             raise NotImplementedError(
                 "Draft model speculative decoding is not supported yet. "
" "Please consider using other speculative decoding methods " - "such as ngram, medusa, eagle, or deepseek_mtp.") + "such as ngram, medusa, eagle, or mtp.") V1_BACKENDS = [ "FLASH_ATTN", diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 119f41d8580e..57da8346f497 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -222,8 +222,7 @@ def propose( hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp", - "longcat_flash_mtp"): + if self.method == "mtp": last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -352,8 +351,7 @@ def propose( hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) - if self.method in ("deepseek_mtp", "ernie_mtp", - "qwen3_next_mtp", "longcat_flash_mtp"): + if self.method == "mtp": last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states else: @@ -888,10 +886,10 @@ def dummy_run( def _get_attention_metadata_builder( self) -> list[AttentionMetadataBuilder]: """Find and return the attention metadata builders for EAGLE layers. - + Returns: The metadata builders for EAGLE layers. - + Raises: AssertionError: If no metadata builders are found for EAGLE layers. """