36 changes: 28 additions & 8 deletions tests/lora/test_lora_manager.py
@@ -8,11 +8,12 @@
from safetensors.torch import load_file
from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
@@ -435,10 +436,19 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)

model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config,
lora_config=lora_config)

vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2,
dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)

worker_adapter_manager.max_num_seqs = 4
worker_adapter_manager.max_num_batched_tokens = 2
Comment on lines +440 to +450
Contributor

critical

The vllm_config is not correctly initialized for the test. The ModelConfig within it doesn't have the hf_config from the dummy_model, which will cause vllm_config.model_config.get_vocab_size() to return 0 inside LRUCacheWorkerLoRAManager. This leads to incorrect behavior, especially when calculating target_embedding_padding.

To fix this, you should associate the dummy_model.config with the vllm_config's model_config. Also, the manual setting of max_num_seqs and max_num_batched_tokens on worker_adapter_manager is redundant as these are already set during initialization from the vllm_config.

Suggested change
-model_config = ModelConfig(max_model_len=16)
-vllm_config = VllmConfig(model_config=model_config,
-                         lora_config=lora_config)
-vllm_config.scheduler_config.max_num_seqs = 4
-vllm_config.scheduler_config.max_num_batched_tokens = 2
-worker_adapter_manager = LRUCacheWorkerLoRAManager(
-    vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
-worker_adapter_manager.max_num_seqs = 4
-worker_adapter_manager.max_num_batched_tokens = 2
+model_config = ModelConfig(max_model_len=16)
+vllm_config = VllmConfig(model_config=model_config,
+                         lora_config=lora_config)
+# Manually set hf_config for the test since ModelConfig doesn't take it
+# in __init__ and we are not loading from a real model path.
+vllm_config.model_config.hf_config = dummy_model.config
+vllm_config.scheduler_config.max_num_seqs = 4
+vllm_config.scheduler_config.max_num_batched_tokens = 2
+worker_adapter_manager = LRUCacheWorkerLoRAManager(
+    vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
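As a quick sanity check after applying this suggestion (a hypothetical snippet, not part of the PR; it assumes the dummy model's config is an HF-style config exposing a vocab_size attribute), the vocab size seen by the manager should now come from the dummy model instead of 0:

# Hypothetical follow-up assertions for the test (not in this PR).
# Once hf_config is attached, get_vocab_size() should reflect dummy_model.config,
# and the manager picks it up in __init__.
assert vllm_config.model_config.get_vocab_size() == dummy_model.config.vocab_size
assert worker_adapter_manager.vocab_size == dummy_model.config.vocab_size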


worker_adapter_manager.create_lora_manager(dummy_model)

mapping = LoRAMapping([], [])
@@ -517,10 +527,20 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = WorkerLoRAManager(
4, 2, dummy_model_gate_up.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)

model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config,
lora_config=lora_config)

vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2

worker_adapter_manager = WorkerLoRAManager(vllm_config, device,
EMBEDDING_MODULES,
EMBEDDING_PADDING_MODULES)
worker_adapter_manager.vocab_size = (
dummy_model_gate_up.unpadded_vocab_size -
lora_config.lora_extra_vocab_size)
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

dummy_lora_files = f"{tmp_path}/lora_adapter"
2 changes: 1 addition & 1 deletion tests/lora/utils.py
@@ -9,7 +9,7 @@
import torch
from safetensors.torch import save_file

from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights


class DummyLoRAManager:
File renamed without changes.
2 changes: 1 addition & 1 deletion vllm/lora/models.py
@@ -14,7 +14,7 @@
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
23 changes: 12 additions & 11 deletions vllm/lora/worker_manager.py
@@ -6,7 +6,7 @@

import torch

from vllm.config.lora import LoRAConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
@@ -27,25 +27,26 @@ class WorkerLoRAManager:

def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
vllm_config: VllmConfig,
device: torch.device,
embedding_modules: dict[str, str],
embedding_padding_modules: list[str],
lora_model_cls: type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.lora_config = lora_config
self.max_position_embeddings = max_position_embeddings
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.lora_config = vllm_config.lora_config

# Use get_text_config() in case of multimodal models
text_config = vllm_config.model_config.hf_config.get_text_config()

self.max_position_embeddings = text_config.max_position_embeddings
self.device = device
# Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager
5 changes: 2 additions & 3 deletions vllm/v1/worker/cpu_model_runner.py
@@ -107,9 +107,8 @@ def load_model(self, eep_scale_up: bool = False) -> None:
self.model = get_model(vllm_config=self.vllm_config)

if self.lora_config:
self.model = self.load_lora_model(self.model, self.model_config,
self.scheduler_config,
self.lora_config, self.device)
self.model = self.load_lora_model(self.model, self.vllm_config,
self.device)

def get_model(self) -> nn.Module:
return self.model
5 changes: 1 addition & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2552,10 +2552,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config)
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
self.scheduler_config,
self.lora_config,
self.model = self.load_lora_model(self.model, self.vllm_config,
self.device)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")
15 changes: 3 additions & 12 deletions vllm/v1/worker/lora_model_runner_mixin.py
@@ -11,7 +11,7 @@
import torch
import torch.nn as nn

from vllm.config import ModelConfig, SchedulerConfig
from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
@@ -31,9 +31,7 @@ class LoRAModelRunnerMixin:

LORA_WARMUP_RANK = 8

def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
scheduler_config: SchedulerConfig,
lora_config: LoRAConfig,
def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig,
device: torch.device) -> nn.Module:

if not supports_lora(model):
@@ -44,19 +42,12 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")

# Use get_text_config() in case of multimodal models
text_config = model_config.hf_config.get_text_config()

# Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager(
scheduler_config.max_num_seqs,
scheduler_config.max_num_batched_tokens,
model_config.get_vocab_size(),
lora_config,
vllm_config,
device,
model.embedding_modules,
model.embedding_padding_modules,
max_position_embeddings=text_config.max_position_embeddings,
)
return self.lora_manager.create_lora_manager(model)

4 changes: 1 addition & 3 deletions vllm/v1/worker/tpu_model_runner.py
@@ -1178,9 +1178,7 @@ def load_model(self) -> None:
"or sharding the weights on more chips. "
f"See the detailed error: {e}") from e
if self.lora_config is not None:
model = self.load_lora_model(model, self.model_config,
self.scheduler_config,
self.lora_config, self.device)
model = self.load_lora_model(model, self.vllm_config, self.device)
replace_set_lora(model)

# Sync all pending XLA execution during model initialization and weight
11 changes: 2 additions & 9 deletions vllm/worker/model_runner.py
@@ -1078,20 +1078,13 @@ def load_model(self) -> None:
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")

# Use get_text_config() in case of multimodal models
text_config = self.model_config.hf_config.get_text_config()

self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.vllm_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=text_config.
max_position_embeddings,
)

self.model = self.lora_manager.create_lora_manager(self.model)
time_after_load = time.perf_counter()
