diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index ced0afc50cb9..2697fc166ad9 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -1223,3 +1223,145 @@ def test_get_masked_input_and_mask(): torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) assert torch.equal(modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4]) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_lm_head_logits_processor_zero_vocab_padding(dist_init, num_loras, + device, stage) -> None: + """ + Test LogitsProcessorWithLoRA with lora_extra_vocab_size = 0. + + This is a regression test to ensure that unembed LoRA works correctly + when no vocab padding is used (lora_extra_vocab_size = 0). + """ + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + + torch.set_default_device(device) + max_loras = 8 + vocab_size = 32000 + hidden_size = 1024 + + # Create LoRAConfig with lora_extra_vocab_size = 0 + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_extra_vocab_size=0, # No vocab padding + lora_dtype=torch.float16) + + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) + + def _pretest(): + # Note: vocab_size + 0 (no extra vocab) + linear = ParallelLMHead(vocab_size, + hidden_size, + vocab_size, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + logits_processor = LogitsProcessor(vocab_size, vocab_size) + lora_logits_processor = LogitsProcessorWithLoRA( + logits_processor, hidden_size, linear.weight.dtype, + linear.weight.device, None) + lora_logits_processor.create_lora_weights(max_loras, lora_config) + + return linear, logits_processor, lora_logits_processor + + for i in range(NUM_RANDOM_SEEDS): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, logits_processor, lora_logits_processor = _pretest() + lora_logits_processor.set_mapping(punica_wrapper) + + # Populate LoRAs without embeddings tensor (since extra_vocab_size = 0) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_logits_processor, + layer_weights=linear.weight, + generate_embeddings_tensor=0, # No embeddings tensor + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, + input_size=(1, hidden_size), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + 0, # lora_extra_vocab_size = 0 + ) + + # Test with LoRA + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=linear, + embedding_bias=None) + + # Compute expected results + expected_results: list[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = logits_processor._get_logits(hidden_states=input_, + lm_head=linear, + embedding_bias=None) + # Apply LoRA transformation + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + # Verify results match + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + 
atol=atol) + + # Test resetting LoRA weights + for slot_idx in range(max_loras): + lora_logits_processor.reset_lora(slot_idx) + + # Test without any active LoRAs + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras, + input_size=(1, hidden_size), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + 0, # lora_extra_vocab_size = 0 + ) + + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=linear, + embedding_bias=None) + expected_result = logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=linear, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) diff --git a/tests/lora/test_qwen3_unembed.py b/tests/lora/test_qwen3_unembed.py new file mode 100644 index 000000000000..7c60d847ff67 --- /dev/null +++ b/tests/lora/test_qwen3_unembed.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for Qwen3 unembed LoRA support, including: +1. Qwen3 with unembed lora (with extra vocab size) +2. Unembed lora with no vocab padding (extra_vocab_size = 0) +""" + +import pytest + +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + +from ..utils import create_new_process_for_each_test + +MODEL_PATH = "Qwen/Qwen3-0.6B" + +# LoRA adapters for Qwen3 +LORA_QWEN3_ALICE = "charent/self_cognition_Alice" +LORA_QWEN3_BOB = "charent/self_cognition_Bob" + + +def format_chatml_messages(prompt: str): + """Format prompt for Qwen3 models""" + return [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": prompt + }, + ] + + +@create_new_process_for_each_test() +@pytest.mark.parametrize( + "max_lora_rank,lora_extra_vocab_size", + [ + ( + 8, 0 + ), # Qwen3 uses tied weights, only lora_extra_vocab_size=0 is supported + (16, 0), # Test with different rank + ]) +def test_qwen3_unembed_lora(max_lora_rank: int, lora_extra_vocab_size: int): + """ + Test Qwen3 with unembed LoRA adapters. + + This test verifies: + 1. Qwen3 models can load and use LoRA adapters with lm_head (unembed) + 2. The system handles extra_vocab_size = 0 correctly (no vocab padding) + 3. 
Multiple LoRA adapters can be used simultaneously + """ + # Initialize LLM with LoRA support + llm = LLM( + model=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_lora_rank=max_lora_rank, + lora_extra_vocab_size=lora_extra_vocab_size, + max_model_len=512, + gpu_memory_utilization=0.5, + enforce_eager=True, + ) + + # Test prompts + prompts = [ + "What is GitHub?", + "Hi, tell me about you", + "Hello, my name is", + ] + + # Sampling parameters + sampling_params = SamplingParams( + temperature=0, + max_tokens=64, + ) + + # Test with base model (no LoRA) + print("Testing base model without LoRA...") + base_outputs = llm.generate(prompts, sampling_params) + assert len(base_outputs) == len(prompts) + for output in base_outputs: + assert output.outputs[0].text # Should generate some text + + # Test with first LoRA adapter (Alice) + print(f"Testing with LoRA adapter (Alice) - " + f"extra_vocab_size={lora_extra_vocab_size}...") + lora_request_alice = LoRARequest("alice", 1, LORA_QWEN3_ALICE) + + # Format messages for chat template + formatted_prompts = [format_chatml_messages(p) for p in prompts] + + lora_outputs_alice = llm.chat( + formatted_prompts, + sampling_params, + chat_template_kwargs={"enable_thinking": False}, + lora_request=lora_request_alice, + use_tqdm=False, + ) + assert len(lora_outputs_alice) == len(prompts) + + # Verify outputs are different from base model + for base_out, lora_out in zip(base_outputs, lora_outputs_alice): + base_text = base_out.outputs[0].text + lora_text = lora_out.outputs[0].text + assert lora_text # Should generate some text + print(f"Base: {base_text[:50]}...") + print(f"LoRA: {lora_text[:50]}...") + + # Test with second LoRA adapter (Bob) + print(f"Testing with second LoRA adapter (Bob) - " + f"extra_vocab_size={lora_extra_vocab_size}...") + lora_request_bob = LoRARequest("bob", 2, LORA_QWEN3_BOB) + + lora_outputs_bob = llm.chat( + formatted_prompts, + sampling_params, + chat_template_kwargs={"enable_thinking": False}, + lora_request=lora_request_bob, + use_tqdm=False, + ) + assert len(lora_outputs_bob) == len(prompts) + + for output in lora_outputs_bob: + assert output.outputs[0].text # Should generate some text + + # Test switching between LoRA adapters + print("Testing switching between LoRA adapters...") + mixed_requests = [lora_request_alice, lora_request_bob, lora_request_alice] + + for i, (prompt, + lora_req) in enumerate(zip(formatted_prompts, mixed_requests)): + output = llm.chat( + [prompt], + sampling_params, + chat_template_kwargs={"enable_thinking": False}, + lora_request=lora_req, + use_tqdm=False, + ) + assert len(output) == 1 + assert output[0].outputs[0].text + print(f"Prompt {i} with {lora_req.lora_name}: " + f"{output[0].outputs[0].text[:50]}...") + + print(f"Test passed with extra_vocab_size={lora_extra_vocab_size}") + + +@create_new_process_for_each_test() +def test_qwen3_unembed_lora_zero_vocab_padding(): + """ + Specific test for unembed LoRA with extra_vocab_size = 0. + + This is a regression test to ensure that the changes to support + no vocab padding don't break the basic LoRA functionality. + """ + # Initialize LLM with LoRA support and NO extra vocab size + llm = LLM( + model=MODEL_PATH, + enable_lora=True, + max_loras=2, + max_lora_rank=8, + lora_extra_vocab_size=0, # No vocab padding + max_model_len=256, + gpu_memory_utilization=0.5, + enforce_eager=True, + ) + + # Simple test prompt + prompt = "What is Python?" 
+    formatted_prompt = format_chatml_messages(prompt)
+
+    # Sampling parameters
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=32,
+    )
+
+    # Test with LoRA adapter
+    lora_request = LoRARequest("alice", 1, LORA_QWEN3_ALICE)
+
+    outputs = llm.chat(
+        [formatted_prompt],
+        sampling_params,
+        chat_template_kwargs={"enable_thinking": False},
+        lora_request=lora_request,
+        use_tqdm=False,
+    )
+
+    assert len(outputs) == 1
+    assert outputs[0].outputs[0].text  # Should generate some text
+
+    print(f"Output: {outputs[0].outputs[0].text}")
+    print("Test passed with extra_vocab_size=0")
+
+
+@create_new_process_for_each_test()
+@pytest.mark.parametrize("lora_extra_vocab_size", [256, 512])
+def test_qwen3_unembed_lora_untied_weights(lora_extra_vocab_size: int):
+    """
+    Test Qwen3 with unembed LoRA and extra vocab when tie_word_embeddings=False.
+
+    This test verifies that when tie_word_embeddings is disabled,
+    we can use lora_extra_vocab_size > 0.
+    """
+    import os
+    import tempfile
+
+    from transformers import AutoConfig
+
+    # Load the base config
+    config = AutoConfig.from_pretrained(MODEL_PATH)
+
+    # Modify to disable tied weights
+    config.tie_word_embeddings = False
+
+    # Save to a temporary directory
+    with tempfile.TemporaryDirectory() as tmpdir:
+        config.save_pretrained(tmpdir)
+
+        # Copy the model weights to the temp directory
+        # (we only need config.json to be modified)
+        import shutil
+
+        from huggingface_hub import snapshot_download
+
+        # Download the original model
+        cache_dir = snapshot_download(MODEL_PATH)
+
+        # Copy all files except config.json
+        for filename in os.listdir(cache_dir):
+            if filename != "config.json":
+                src = os.path.join(cache_dir, filename)
+                dst = os.path.join(tmpdir, filename)
+                if os.path.isfile(src):
+                    shutil.copy2(src, dst)
+
+        # Initialize LLM with modified config
+        llm = LLM(
+            model=tmpdir,
+            enable_lora=True,
+            max_loras=2,
+            max_lora_rank=8,
+            lora_extra_vocab_size=lora_extra_vocab_size,
+            max_model_len=256,
+            gpu_memory_utilization=0.5,
+            enforce_eager=True,
+        )
+
+        # Test prompts
+        prompts = ["What is GitHub?", "Hello, my name is"]
+
+        # Sampling parameters
+        sampling_params = SamplingParams(
+            temperature=0,
+            max_tokens=32,
+        )
+
+        # Test with base model (no LoRA)
+        print("Testing base model without LoRA...")
+        base_outputs = llm.generate(prompts, sampling_params)
+        assert len(base_outputs) == len(prompts)
+        for output in base_outputs:
+            assert output.outputs[0].text
+
+        # Test with LoRA adapter
+        print(f"Testing with LoRA adapter - "
+              f"extra_vocab_size={lora_extra_vocab_size}...")
+        lora_request = LoRARequest("alice", 1, LORA_QWEN3_ALICE)
+
+        formatted_prompts = [format_chatml_messages(p) for p in prompts]
+
+        lora_outputs = llm.chat(
+            formatted_prompts,
+            sampling_params,
+            chat_template_kwargs={"enable_thinking": False},
+            lora_request=lora_request,
+            use_tqdm=False,
+        )
+        assert len(lora_outputs) == len(prompts)
+        for output in lora_outputs:
+            assert output.outputs[0].text
+
+        print(f"Test passed with extra_vocab_size={lora_extra_vocab_size} "
+              f"and tie_word_embeddings=False")
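Why tied word embeddings rule out extra LoRA vocabulary can be stated purely in terms of shapes. A standalone sketch follows; it is not part of the patch and the sizes are illustrative only.

    # Sketch: tied vs. untied lm_head and where extra LoRA vocab rows can live.
    import torch

    vocab_size, hidden_size, extra = 151_936, 1_024, 256  # illustrative sizes

    # Untied: lm_head is its own matrix, so it can be allocated with `extra`
    # additional rows holding unembeddings for LoRA-added tokens.
    untied_lm_head = torch.empty(vocab_size + extra, hidden_size)
    assert untied_lm_head.shape[0] == vocab_size + extra

    # Tied: lm_head *is* embed_tokens, so its row count is fixed by the input
    # embedding table and there is no room for extra unembedding rows, which
    # is why lora_extra_vocab_size must stay 0 for tied checkpoints.
    embed_tokens = torch.empty(vocab_size, hidden_size)
    tied_lm_head = embed_tokens
    assert tied_lm_head.shape[0] == vocab_size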
+
+
+@create_new_process_for_each_test()
+def test_qwen3_moe_unembed_lora():
+    """
+    Test Qwen3 MoE with unembed LoRA adapters.
+
+    This test verifies that Qwen3 MoE models can also use unembed LoRA.
+    """
+    # TODO: exercise a smaller Qwen3 MoE checkpoint once one is available.
+    # For now the test is skipped because the model needs more resources.
+    pytest.skip("Qwen3 MoE model requires more resources, test separately")
+
+
+if __name__ == "__main__":
+    # Run tests manually for debugging
+    test_qwen3_unembed_lora(8, 0)
+    test_qwen3_unembed_lora(16, 0)
+    test_qwen3_unembed_lora_zero_vocab_padding()
diff --git a/vllm/config/lora.py b/vllm/config/lora.py
index 3fe28f5dad4f..61a022dc0556 100644
--- a/vllm/config/lora.py
+++ b/vllm/config/lora.py
@@ -102,7 +102,7 @@ def __post_init__(self):
         # Setting the maximum rank to 512 should be able to satisfy the vast
         # majority of applications.
         possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
-        possible_lora_extra_vocab_size = (256, 512)
+        possible_lora_extra_vocab_size = (0, 256, 512)
         if self.max_lora_rank not in possible_max_ranks:
             raise ValueError(
                 f"max_lora_rank ({self.max_lora_rank}) must be one of "
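With 0 added to the allowed values, a LoRAConfig that disables vocab padding now passes validation. A minimal sketch, mirroring the keyword arguments used in the unit test above (other fields keep their defaults; the rejection branch assumes the same style of ValueError check shown for max_lora_rank):

    # Sketch: lora_extra_vocab_size=0 is accepted, values outside
    # (0, 256, 512) are still rejected.
    import torch

    from vllm.config import LoRAConfig

    LoRAConfig(max_loras=8,
               max_lora_rank=8,
               lora_extra_vocab_size=0,  # no extra vocab, no padding
               lora_dtype=torch.float16)

    try:
        LoRAConfig(max_loras=8, max_lora_rank=8, lora_extra_vocab_size=128)
    except ValueError:
        pass  # 128 is not one of (0, 256, 512)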
diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py
index b8fbad3a4af0..10ab4eb746de 100644
--- a/vllm/lora/layers/logits_processor.py
+++ b/vllm/lora/layers/logits_processor.py
@@ -145,7 +145,8 @@ def set_lora(
         self.lora_b_stacked[index,
                             0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                 lora_b, non_blocking=True)
-        if embeddings_tensor is not None:
+        if embeddings_tensor is not None and self.embeddings_tensors.shape[
+                1] > 0:
             self.embeddings_tensors[
                 index,
                 :embeddings_tensor.shape[0],
@@ -159,7 +160,13 @@ def _get_logits(
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
-        logits = lm_head.quant_method.apply(lm_head, hidden_states)
+        # When LoRA is enabled, lm_head is wrapped, so we need to unwrap it
+        if hasattr(lm_head, 'base_layer'):
+            actual_lm_head = lm_head.base_layer
+        else:
+            actual_lm_head = lm_head
+        logits = actual_lm_head.quant_method.apply(actual_lm_head,
+                                                   hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
 
@@ -188,37 +195,40 @@ def _get_logits(
         # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
         logits = logits[:, self.sharded_to_full_mapping_gpu]
 
-        lora_logits = torch.empty(
-            self.embeddings_tensors.shape[0] + 1,
-            self.embeddings_tensors.shape[1],
-            hidden_states.shape[0],
-            dtype=self.embeddings_tensors.dtype,
-            device=self.embeddings_tensors.device,
-        )
-        torch.matmul(self.embeddings_tensors,
-                     hidden_states.T,
-                     out=lora_logits[:-1])
-
-        neg_inf, pos_inf = current_platform.get_infinity_values(
-            lora_logits.dtype)
-
-        lora_logits[-1] = neg_inf
-        lora_logits = lora_logits.mT
-        indices_padded = self.punica_wrapper.sampler_indices_padded
-
-        if current_platform.is_tpu() or current_platform.is_xpu():
-            indices_padded = indices_padded[:logits.size(0)]
-
-        lora_logits = (lora_logits.reshape(
-            lora_logits.shape[0] * lora_logits.shape[1],
-            lora_logits.shape[2],
-        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
-                                                      posinf=pos_inf,
-                                                      neginf=neg_inf))
-
-        logits[:,
-               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
-               lora_logits.shape[1]] = lora_logits
+        if self.embeddings_tensors.shape[1] > 0:  # extra_vocab_size > 0
+            lora_logits = torch.empty(
+                self.embeddings_tensors.shape[0] + 1,
+                self.embeddings_tensors.shape[1],
+                hidden_states.shape[0],
+                dtype=self.embeddings_tensors.dtype,
+                device=self.embeddings_tensors.device,
+            )
+            torch.matmul(self.embeddings_tensors,
+                         hidden_states.T,
+                         out=lora_logits[:-1])
+
+            neg_inf, pos_inf = current_platform.get_infinity_values(
+                lora_logits.dtype)
+
+            lora_logits[-1] = neg_inf
+            lora_logits = lora_logits.mT
+            indices_padded = self.punica_wrapper.sampler_indices_padded
+
+            if current_platform.is_tpu() or current_platform.is_xpu():
+                indices_padded = indices_padded[:logits.size(0)]
+
+            lora_logits = (lora_logits.reshape(
+                lora_logits.shape[0] * lora_logits.shape[1],
+                lora_logits.shape[2],
+            ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
+                                                          posinf=pos_inf,
+                                                          neginf=neg_inf))
+
+            # Only assign if lora_logits has data (not empty after index_select)
+            if lora_logits.shape[1] > 0:
+                logits[:, self.base_layer.
+                       org_vocab_size:self.base_layer.org_vocab_size +
+                       lora_logits.shape[1]] = lora_logits
 
         lora_output: Optional[
             torch.Tensor] = self.punica_wrapper.add_lora_logits(
diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py
index ca01c7e17fff..cda5bda7f431 100644
--- a/vllm/lora/layers/vocal_parallel_embedding.py
+++ b/vllm/lora/layers/vocal_parallel_embedding.py
@@ -102,7 +102,8 @@ def set_lora(
         self.lora_b_stacked[index,
                             0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                 lora_b, non_blocking=True)
-        if embeddings_tensor is not None:
+        if embeddings_tensor is not None and self.embeddings_tensors.shape[
+                1] > 0:
             self.embeddings_tensors[
                 index,
                 :embeddings_tensor.shape[0],
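Two short sketches of what the guarded paths above amount to, written against plain torch rather than the vLLM classes (the helper name is hypothetical): the unwrapping branch reaches through the LoRA wrapper to the underlying ParallelLMHead, and the shape[1] > 0 checks skip work that is a no-op when no extra vocabulary exists.

    import torch

    # Equivalent of the hasattr branch: LogitsProcessorWithLoRA wraps the
    # original lm_head and stores it as `base_layer`, which is where
    # quant_method lives.
    def unwrap_lm_head(lm_head):
        return getattr(lm_head, "base_layer", lm_head)

    # With lora_extra_vocab_size=0 the stacked extra-vocab embeddings are
    # zero-width, so the extra-logits block has nothing to contribute.
    embeddings_tensors = torch.zeros(8, 0, 1024)  # [max_loras, 0, hidden]
    hidden_states = torch.randn(4, 1024)          # [num_tokens, hidden]
    extra_logits = torch.matmul(embeddings_tensors, hidden_states.T)
    assert extra_logits.numel() == 0              # nothing to write into logits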
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index cc64cc78affa..3d04afee5749 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -554,17 +554,27 @@ def create_dummy_lora(
         if module_name not in self.packed_modules:
             assert embedding_modules is not None
             if parts[-1] in embedding_modules:
-                input_dim = (module.base_layer.org_vocab_size +
-                             self.lora_config.lora_extra_vocab_size if
-                             hasattr(module.base_layer, "org_vocab_size")
-                             else module.base_layer.weight.shape[1])
-                output_dim = module.base_layer.embedding_dim if hasattr(
-                    module.base_layer,
-                    "embedding_dim") else module.base_layer.weight.shape[0]
-                embeddings_tensor_dim = (module.base_layer.embedding_dim if
-                                         hasattr(module.base_layer,
-                                                 "embedding_dim") else
-                                         module.base_layer.weight.shape[1])
+                # Special-case lm_head: wrapped by LogitsProcessorWithLoRA.
+                # LoRA input dim is hidden_size, output dim is vocab size.
+                # LogitsProcessorWithLoRA handles extra vocab size directly.
+                if parts[-1] == "lm_head":
+                    input_dim = module.lora_a_stacked[0].shape[-1]
+                    output_dim = module.lora_b_stacked[0].shape[-2]
+                    embeddings_tensor_dim = input_dim
+                else:
+                    input_dim = (module.base_layer.org_vocab_size +
+                                 self.lora_config.lora_extra_vocab_size
+                                 if hasattr(module.base_layer,
+                                            "org_vocab_size") else
+                                 module.base_layer.weight.shape[1])
+                    output_dim = (module.base_layer.embedding_dim
+                                  if hasattr(module.base_layer,
+                                             "embedding_dim") else
+                                  module.base_layer.weight.shape[0])
+                    embeddings_tensor_dim = (
+                        module.base_layer.embedding_dim if hasattr(
+                            module.base_layer, "embedding_dim") else
+                        module.base_layer.weight.shape[1])
                 lora = LoRALayerWeights.create_dummy_lora_weights(
                     module_name,
                     input_dim,
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index ae72fd30c399..3690b3e002f3 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -40,7 +40,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
@@ -273,6 +274,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         ],
     }
 
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+        "unembed_tokens": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head", "unembed_tokens"]
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -286,19 +294,30 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = Qwen3Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
 
+        # Calculate unpadded vocab size (with LoRA extra vocab if applicable)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config and not config.tie_word_embeddings:
+            # Only add extra vocab if weights are not tied
+            # (tied weights can't support extra vocab)
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
         if get_pp_group().is_last_rank:
             if config.tie_word_embeddings:
                 self.lm_head = self.model.embed_tokens
             else:
-                self.lm_head = ParallelLMHead(config.vocab_size,
-                                              config.hidden_size,
-                                              quant_config=quant_config,
-                                              prefix=maybe_prefix(
-                                                  prefix, "lm_head"))
+                self.lm_head = ParallelLMHead(
+                    self.unpadded_vocab_size,
+                    config.hidden_size,
+                    org_num_embeddings=config.vocab_size,
+                    padding_size=(DEFAULT_VOCAB_PADDING_SIZE if not lora_config
+                                  else lora_config.lora_vocab_padding_size),
+                    quant_config=quant_config,
+                    prefix=maybe_prefix(prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
 
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
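The sizing rule added to Qwen3ForCausalLM can be restated on its own. The sketch below is illustrative only (the vocab numbers are assumptions, not read from the model config):

    # Sketch of the unpadded vocab sizing introduced above.
    def unpadded_vocab_size(vocab_size: int, tie_word_embeddings: bool,
                            lora_enabled: bool,
                            lora_extra_vocab_size: int) -> int:
        # Tied checkpoints reuse embed_tokens as lm_head, so there are no
        # spare rows for LoRA-added tokens; only untied models grow the vocab.
        if lora_enabled and not tie_word_embeddings:
            return vocab_size + lora_extra_vocab_size
        return vocab_size

    assert unpadded_vocab_size(151_936, True, True, 256) == 151_936
    assert unpadded_vocab_size(151_936, False, True, 256) == 152_192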
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 61f1abad72b6..6da70b777421 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -582,14 +582,23 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
         ],
     }
 
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+        "unembed_tokens": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head", "unembed_tokens"]
+
     fall_back_to_pt_during_load = False
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_text_config
         quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
         self.config = config
         self.quant_config = quant_config
+        self.lora_config = lora_config
         self.model = Qwen3MoeModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
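As a closing note on the lm_head sizing used by both Qwen3 models: the padding_size argument only controls how far the weight matrix is rounded up. A back-of-the-envelope sketch, assuming the usual defaults of 64 for DEFAULT_VOCAB_PADDING_SIZE and 256 for lora_vocab_padding_size, and an illustrative vocab size:

    # Padding arithmetic sketch; the constants and vocab size are assumptions,
    # not values read from this diff.
    def pad_vocab(n: int, multiple: int) -> int:
        return -(-n // multiple) * multiple  # round up to a multiple

    vocab = 151_936
    assert pad_vocab(vocab, 64) == 151_936   # already a multiple of 64
    assert pad_vocab(vocab, 256) == 152_064  # rounded up for the LoRA kernels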