From 3e11fc637a54520a50245692def9a1eb1a7beb01 Mon Sep 17 00:00:00 2001 From: Went-Liang Date: Fri, 25 Oct 2024 20:23:33 +0800 Subject: [PATCH 01/12] Support math-shepherd-mistral-7b-prm model Signed-off-by: Went-Liang --- vllm/model_executor/layers/pooler.py | 19 +++++++++++++++++++ vllm/model_executor/models/llama.py | 9 +++++++++ 2 files changed, 28 insertions(+) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 0a1df9cb699a..d6411773643e 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,3 +1,4 @@ +import os from enum import IntEnum import torch @@ -13,6 +14,7 @@ class PoolingType(IntEnum): LAST = 0 ALL = 1 CLS = 2 + STEP = 3 class Pooler(nn.Module): @@ -37,6 +39,9 @@ def __init__(self, self.pooling_type = pooling_type self.normalize = normalize self.softmax = softmax + returned_token_ids = os.environ.get('RETURNED_TOKEN_IDS', '648,387') + self.returned_token_ids = list(map(int, returned_token_ids.split(","))) + self.step_tag_id = int(os.environ.get('STEP_TOKEN_ID', -1)) def forward( self, @@ -62,6 +67,20 @@ def forward( for prompt_len in prompt_lens: pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len + elif self.pooling_type == PoolingType.STEP: + logits = hidden_states[:, self.returned_token_ids].softmax(dim=-1) + offset = 0 + pooled_data = [] + for prompt_len, seq_data_i in zip( + prompt_lens, pooling_metadata.seq_data.values()): + if self.step_tag_id == -1: + pooled_data.append(logits[offset:offset + prompt_len]) + else: + step_idxs = torch.tensor( + seq_data_i.prompt_token_ids) == self.step_tag_id + pooled_data.append(logits[offset:offset + + prompt_len][step_idxs]) + offset += prompt_len else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 98c53bdaae81..8e19d78be19c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -543,6 +543,7 @@ def __init__( self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self._pooler = Pooler(pooling_type=PoolingType.STEP, normalize=False) def forward( self, @@ -565,6 +566,14 @@ def compute_logits( sampling_metadata) return logits + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + logits = self.compute_logits(hidden_states, None) + return self._pooler(logits, pooling_metadata) + def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) From 3a3acad526e04de6643c443fa09a6aa4e795ab97 Mon Sep 17 00:00:00 2001 From: Went-Liang Date: Mon, 28 Oct 2024 21:28:25 +0800 Subject: [PATCH 02/12] Configure pooler through PoolerConfig Signed-off-by: Went-Liang --- vllm/config.py | 111 +++++++++++++++------ vllm/engine/arg_utils.py | 48 +++++++++ vllm/engine/llm_engine.py | 9 +- vllm/entrypoints/llm.py | 10 ++ vllm/model_executor/layers/pooler.py | 26 +++-- vllm/model_executor/model_loader/loader.py | 15 ++- vllm/model_executor/models/llama.py | 11 +- 7 files changed, 183 insertions(+), 47 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3814e41aeb92..808bbc1c5cb8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -112,38 +112,58 @@ class ModelConfig: Defaults to 'auto' which defaults to 'hf'. 
mm_processor_kwargs: Arguments to be forwarded to the model's processor for multi-modal data, e.g., image processor. + pooling_type: Used to configure the pooling method in the embedding + model. + pooling_norm: Used to determine whether to normalize the pooled + data in the embedding model. + pooling_softmax: Used to determine whether to softmax the pooled + data in the embedding model. + pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates + that the score corresponding to the pooling_step_tag_id in the + generated sentence should be returned. Otherwise, it returns + the scores for all tokens. + pooling_returned_token_ids: pooling_returned_token_ids represents a + list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of good_token and bad_token in the + math-shepherd-mistral-7b-prm model. """ - def __init__(self, - model: str, - task: Union[TaskOption, _Task], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[dict] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - override_neuron_config: Optional[Dict[str, Any]] = None, - config_format: ConfigFormat = ConfigFormat.AUTO, - chat_template_text_format: str = "string", - mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + model: str, + task: Union[TaskOption, _Task], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + override_neuron_config: Optional[Dict[str, Any]] = None, + config_format: ConfigFormat = ConfigFormat.AUTO, + chat_template_text_format: str = "string", + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + pooling_type: Optional[str] = None, + pooling_norm: bool = False, + pooling_softmax: bool = False, + pooling_step_tag_id: int = -1, + pooling_returned_token_ids: Optional[List[int]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -224,6 +244,13 @@ def __init__(self, supported_tasks, task = 
self._resolve_task(task, self.hf_config) self.supported_tasks = supported_tasks self.task: Final = task + self.pooler_config = self._init_pooler_config( + pooling_type, + pooling_norm, + pooling_softmax, + pooling_step_tag_id, + pooling_returned_token_ids, + ) self._verify_quantization() self._verify_cuda_graph() @@ -242,6 +269,19 @@ def _init_multimodal_config( return None + def _init_pooler_config( + self, pooling_type, pooling_norm, pooling_softmax, + pooling_step_tag_id, + pooling_returned_token_ids) -> Optional["PoolerConfig"]: + if self.task == "embedding": + return PoolerConfig( + pooling_type=pooling_type, + pooling_norm=pooling_norm, + pooling_softmax=pooling_softmax, + pooling_step_tag_id=pooling_step_tag_id, + pooling_returned_token_ids=pooling_returned_token_ids) + return None + def _init_attention_free(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) return ModelRegistry.is_attention_free_model(architectures) @@ -1647,6 +1687,17 @@ class MultiModalConfig: # TODO: Add configs to init vision tower or not. +@dataclass +class PoolerConfig: + """Controls the behavior of pooler in embedding model""" + + pooling_type: Optional[str] = None + pooling_norm: bool = False + pooling_softmax: bool = False + pooling_step_tag_id: int = -1 + pooling_returned_token_ids: Optional[List[int]] = None + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 38687809a31f..fc9cb6f65718 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -184,6 +184,13 @@ class EngineArgs: mm_processor_kwargs: Optional[Dict[str, Any]] = None scheduling_policy: Literal["fcfs", "priority"] = "fcfs" + # Pooling configuration. + pooling_type: Optional[str] = None + pooling_norm: bool = False + pooling_softmax: bool = False + pooling_step_tag_id: int = -1 + pooling_returned_token_ids: Optional[List[int]] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -850,6 +857,42 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'priority (lower value means earlier handling) and time of ' 'arrival deciding any ties).') + parser.add_argument( + '--pooling-type', + choices=['LAST', 'ALL', 'CLS', 'STEP'], + default="LAST", + help='Used to configure the pooling method in the embedding model.' + ) + + parser.add_argument('--pooling-norm', + action='store_true', + help="Used to determine whether to normalize " + "the pooled data in the embedding model.") + + parser.add_argument('--pooling-softmax', + action='store_true', + help="Used to determine whether to normalize " + "the pooled data in the embedding model.") + + parser.add_argument( + '--pooling-step-tag-id', + type=int, + default=-1, + help="When pooling-step-tag-id is not -1, it indicates " + "that the score corresponding to the step-tag-ids in the " + "generated sentence should be returned. 
Otherwise, it " + "returns the scores for all tokens.") + + parser.add_argument( + '--pooling-returned-token-ids', + nargs='+', + type=int, + default=None, + help="pooling-returned-token-ids represents a list of " + "indices for the vocabulary dimensions to be extracted, " + "such as the token IDs of good_token and bad_token in " + "the math-shepherd-mistral-7b-prm model.") + return parser @classmethod @@ -891,6 +934,11 @@ def create_model_config(self) -> ModelConfig: override_neuron_config=self.override_neuron_config, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, + pooling_type=self.pooling_type, + pooling_norm=self.pooling_norm, + pooling_softmax=self.pooling_softmax, + pooling_step_tag_id=self.pooling_step_tag_id, + pooling_returned_token_ids=self.pooling_returned_token_ids, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fde768ed5165..16522f0b4ae7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -257,7 +257,9 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "chat_template_text_format=%s, mm_processor_kwargs=%s)", + "chat_template_text_format=%s, mm_processor_kwargs=%s, " + "pooling_type=%s, pooling_norm=%s, pooling_softmax=%s, " + "pooling_step_tag_id=%s, pooling_returned_token_ids=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -294,6 +296,11 @@ def __init__( use_cached_outputs, model_config.chat_template_text_format, model_config.mm_processor_kwargs, + model_config.pooler_config.pooling_type, + model_config.pooler_config.pooling_norm, + model_config.pooler_config.pooling_softmax, + model_config.pooler_config.pooling_step_tag_id, + model_config.pooler_config.pooling_returned_token_ids, ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index db97fe0a0285..ca690bd65692 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -159,6 +159,11 @@ def __init__( mm_processor_kwargs: Optional[Dict[str, Any]] = None, # After positional args are removed, move this right below `model` task: TaskOption = "auto", + pooling_type: Optional[str] = None, + pooling_norm: bool = False, + pooling_softmax: bool = False, + pooling_step_tag_id: int = -1, + pooling_returned_token_ids: Optional[List[int]] = None, **kwargs, ) -> None: ''' @@ -193,6 +198,11 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, mm_processor_kwargs=mm_processor_kwargs, + pooling_type=pooling_type, + pooling_norm=pooling_norm, + pooling_softmax=pooling_softmax, + pooling_step_tag_id=pooling_step_tag_id, + pooling_returned_token_ids=pooling_returned_token_ids, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d6411773643e..4f2c04ec7f95 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,5 +1,5 @@ -import os from enum import IntEnum +from typing import List, Optional import torch import torch.nn as nn @@ -30,18 +30,21 @@ class Pooler(nn.Module): normalize: Whether to normalize the pooled data. 
""" - def __init__(self, - pooling_type: PoolingType, - normalize: bool, - softmax: bool = False): + def __init__( + self, + pooling_type: PoolingType, + normalize: bool, + softmax: bool = False, + step_tag_id: int = -1, + returned_token_ids: Optional[List[int]] = None, + ): super().__init__() self.pooling_type = pooling_type self.normalize = normalize self.softmax = softmax - returned_token_ids = os.environ.get('RETURNED_TOKEN_IDS', '648,387') - self.returned_token_ids = list(map(int, returned_token_ids.split(","))) - self.step_tag_id = int(os.environ.get('STEP_TOKEN_ID', -1)) + self.step_tag_id = step_tag_id + self.returned_token_ids = returned_token_ids def forward( self, @@ -68,7 +71,12 @@ def forward( pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len elif self.pooling_type == PoolingType.STEP: - logits = hidden_states[:, self.returned_token_ids].softmax(dim=-1) + if self.returned_token_ids is not None and len( + self.returned_token_ids) > 0: + logits = hidden_states[:, + self.returned_token_ids].softmax(dim=-1) + else: + logits = hidden_states.softmax(dim=-1) offset = 0 pooled_data = [] for prompt_len, seq_data_i in zip( diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 3ae8a51859f7..939927efd9db 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -23,7 +23,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, PoolerConfig, SchedulerConfig) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE @@ -122,7 +122,8 @@ def _get_model_initialization_kwargs( model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: + scheduler_config: Optional[SchedulerConfig] = None, + pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs: Dict[str, Any] = {} @@ -143,7 +144,8 @@ def _get_model_initialization_kwargs( if has_inner_state(model_class) and scheduler_config: extra_kwargs["scheduler_config"] = scheduler_config - + if pooler_config: + extra_kwargs["pooler_config"] = pooler_config return extra_kwargs @@ -155,10 +157,12 @@ def build_model(model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], scheduler_config: Optional[SchedulerConfig], - prefix: Optional[str] = None) -> nn.Module: + prefix: Optional[str] = None, + pooler_config: Optional[PoolerConfig]) -> nn.Module: extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, multimodal_config, - scheduler_config) + scheduler_config, + pooler_config) if prefix: extra_kwargs["prefix"] = prefix @@ -185,6 +189,7 @@ def _initialize_model( lora_config=lora_config, multimodal_config=model_config.multimodal_config, scheduler_config=scheduler_config, + pooler_config=model_config.pooler_config, ) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8e19d78be19c..ac98a2121e9e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from 
vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -502,6 +502,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, prefix: str = "", + pooler_config: Optional[PoolerConfig] = None, ) -> None: super().__init__() @@ -543,7 +544,13 @@ def __init__( self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self._pooler = Pooler(pooling_type=PoolingType.STEP, normalize=False) + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type], + normalize=pooler_config.pooling_norm, + softmax=pooler_config.pooling_softmax, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) def forward( self, From 059d2821f04ccb2a1d6eccee9f8ca8df19626806 Mon Sep 17 00:00:00 2001 From: Went-Liang Date: Tue, 29 Oct 2024 11:28:40 +0800 Subject: [PATCH 03/12] Add type annotations and make None the default for pooler arguments Signed-off-by: Went-Liang --- vllm/config.py | 14 +++++++++----- vllm/engine/arg_utils.py | 6 +++--- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/pooler.py | 6 +++--- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 808bbc1c5cb8..786100e23361 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -162,7 +162,7 @@ def __init__( pooling_type: Optional[str] = None, pooling_norm: bool = False, pooling_softmax: bool = False, - pooling_step_tag_id: int = -1, + pooling_step_tag_id: Optional[int] = None, pooling_returned_token_ids: Optional[List[int]] = None) -> None: self.model = model self.tokenizer = tokenizer @@ -270,9 +270,13 @@ def _init_multimodal_config( return None def _init_pooler_config( - self, pooling_type, pooling_norm, pooling_softmax, - pooling_step_tag_id, - pooling_returned_token_ids) -> Optional["PoolerConfig"]: + self, + pooling_type: Optional[str] = None, + pooling_norm: bool = False, + pooling_softmax: bool = False, + pooling_step_tag_id: Optional[int] = None, + pooling_returned_token_ids: Optional[List[int]] = None + ) -> Optional["PoolerConfig"]: if self.task == "embedding": return PoolerConfig( pooling_type=pooling_type, @@ -1694,7 +1698,7 @@ class PoolerConfig: pooling_type: Optional[str] = None pooling_norm: bool = False pooling_softmax: bool = False - pooling_step_tag_id: int = -1 + pooling_step_tag_id: Optional[int] = None pooling_returned_token_ids: Optional[List[int]] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fc9cb6f65718..9df6c5048bf7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -188,7 +188,7 @@ class EngineArgs: pooling_type: Optional[str] = None pooling_norm: bool = False pooling_softmax: bool = False - pooling_step_tag_id: int = -1 + pooling_step_tag_id: Optional[int] = None pooling_returned_token_ids: Optional[List[int]] = None def __post_init__(self): @@ -860,7 +860,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--pooling-type', choices=['LAST', 'ALL', 'CLS', 'STEP'], - default="LAST", + default=None, help='Used to configure the pooling method in the embedding model.' 
) @@ -877,7 +877,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--pooling-step-tag-id', type=int, - default=-1, + default=None, help="When pooling-step-tag-id is not -1, it indicates " "that the score corresponding to the step-tag-ids in the " "generated sentence should be returned. Otherwise, it " diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ca690bd65692..4354a84f18e1 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -162,7 +162,7 @@ def __init__( pooling_type: Optional[str] = None, pooling_norm: bool = False, pooling_softmax: bool = False, - pooling_step_tag_id: int = -1, + pooling_step_tag_id: Optional[int] = None, pooling_returned_token_ids: Optional[List[int]] = None, **kwargs, ) -> None: diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 4f2c04ec7f95..13f3bfe6d1a9 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -33,9 +33,9 @@ class Pooler(nn.Module): def __init__( self, pooling_type: PoolingType, - normalize: bool, + normalize: bool = False, softmax: bool = False, - step_tag_id: int = -1, + step_tag_id: Optional[int] = None, returned_token_ids: Optional[List[int]] = None, ): super().__init__() @@ -81,7 +81,7 @@ def forward( pooled_data = [] for prompt_len, seq_data_i in zip( prompt_lens, pooling_metadata.seq_data.values()): - if self.step_tag_id == -1: + if self.step_tag_id is None: pooled_data.append(logits[offset:offset + prompt_len]) else: step_idxs = torch.tensor( From 9e44cfc0a73d1fa2365857188b2468d457024119 Mon Sep 17 00:00:00 2001 From: Went-Liang Date: Tue, 29 Oct 2024 16:06:20 +0800 Subject: [PATCH 04/12] Make None the default for boolean arguments Signed-off-by: Went-Liang --- vllm/engine/arg_utils.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9df6c5048bf7..8de20eadadbb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -865,13 +865,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument('--pooling-norm', + default=None, action='store_true', help="Used to determine whether to normalize " "the pooled data in the embedding model.") + parser.add_argument('--no-pooling-norm', + default=None, + action='store_false', + dest='pooling_norm', + help="Used to determine whether to normalize " + "the pooled data in the embedding model.") + parser.add_argument('--pooling-softmax', + default=None, action='store_true', - help="Used to determine whether to normalize " + help="Used to determine whether to softmax " + "the pooled data in the embedding model.") + + parser.add_argument('--no-pooling-softmax', + default=None, + action='store_false', + dest='pooling_softmax', + help="Used to determine whether to softmax " "the pooled data in the embedding model.") parser.add_argument( From 2fac40f7d051d5ba164f4995c730252c6e7c8173 Mon Sep 17 00:00:00 2001 From: Went-Liang Date: Tue, 29 Oct 2024 16:53:04 +0800 Subject: [PATCH 05/12] Update the type annotations and update each embedding model to use the PoolerConfig Signed-off-by: Went-Liang --- vllm/config.py | 12 ++++++------ vllm/engine/arg_utils.py | 4 ++-- vllm/entrypoints/llm.py | 4 ++-- vllm/model_executor/layers/pooler.py | 6 +++--- vllm/model_executor/models/bert.py | 12 ++++++++++-- vllm/model_executor/models/gemma2.py | 13 ++++++++++--- vllm/model_executor/models/llama.py | 17 
+++++++++++++---- vllm/model_executor/models/llava_next.py | 15 +++++++++++---- vllm/model_executor/models/phi3v.py | 16 ++++++++++++---- vllm/model_executor/models/qwen2_cls.py | 14 ++++++++++---- vllm/model_executor/models/qwen2_rm.py | 13 ++++++++++--- 11 files changed, 89 insertions(+), 37 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 786100e23361..e9559c40dbdf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -160,8 +160,8 @@ def __init__( chat_template_text_format: str = "string", mm_processor_kwargs: Optional[Dict[str, Any]] = None, pooling_type: Optional[str] = None, - pooling_norm: bool = False, - pooling_softmax: bool = False, + pooling_norm: Optional[bool] = None, + pooling_softmax: Optional[bool] = None, pooling_step_tag_id: Optional[int] = None, pooling_returned_token_ids: Optional[List[int]] = None) -> None: self.model = model @@ -272,8 +272,8 @@ def _init_multimodal_config( def _init_pooler_config( self, pooling_type: Optional[str] = None, - pooling_norm: bool = False, - pooling_softmax: bool = False, + pooling_norm: Optional[bool] = None, + pooling_softmax: Optional[bool] = None, pooling_step_tag_id: Optional[int] = None, pooling_returned_token_ids: Optional[List[int]] = None ) -> Optional["PoolerConfig"]: @@ -1696,8 +1696,8 @@ class PoolerConfig: """Controls the behavior of pooler in embedding model""" pooling_type: Optional[str] = None - pooling_norm: bool = False - pooling_softmax: bool = False + pooling_norm: Optional[bool] = None + pooling_softmax: Optional[bool] = None pooling_step_tag_id: Optional[int] = None pooling_returned_token_ids: Optional[List[int]] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8de20eadadbb..de886c98e51b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -186,8 +186,8 @@ class EngineArgs: # Pooling configuration. 
pooling_type: Optional[str] = None - pooling_norm: bool = False - pooling_softmax: bool = False + pooling_norm: Optional[bool] = None + pooling_softmax: Optional[bool] = None pooling_step_tag_id: Optional[int] = None pooling_returned_token_ids: Optional[List[int]] = None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4354a84f18e1..083b67c2f8e7 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -160,8 +160,8 @@ def __init__( # After positional args are removed, move this right below `model` task: TaskOption = "auto", pooling_type: Optional[str] = None, - pooling_norm: bool = False, - pooling_softmax: bool = False, + pooling_norm: Optional[bool] = None, + pooling_softmax: Optional[bool] = None, pooling_step_tag_id: Optional[int] = None, pooling_returned_token_ids: Optional[List[int]] = None, **kwargs, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 13f3bfe6d1a9..fe6519a5c680 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -32,9 +32,9 @@ class Pooler(nn.Module): def __init__( self, - pooling_type: PoolingType, - normalize: bool = False, - softmax: bool = False, + pooling_type: Optional[PoolingType] = None, + normalize: Optional[bool] = None, + softmax: Optional[bool] = None, step_tag_id: Optional[int] = None, returned_token_ids: Optional[List[int]] = None, ): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 4c0a0e303e65..18f454b7e453 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -6,7 +6,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.backends.xformers import XFormersImpl -from vllm.config import CacheConfig +from vllm.config import CacheConfig, PoolerConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -387,10 +387,18 @@ def __init__( config: BertConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + pooler_config: Optional[PoolerConfig] = None, ) -> None: super().__init__() self.model = BertModel(config, cache_config, quant_config) - self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True) + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else PoolingType.CLS, + normalize=pooler_config.pooling_norm or True, + softmax=pooler_config.pooling_softmax or False, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) def forward( self, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index d79248f93f5a..868f30468af2 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -22,7 +22,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul @@ -473,13 +473,20 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): def __init__( self, + pooler_config: 
Optional[PoolerConfig] = None, **kwargs, ) -> None: super().__init__() self.model = Gemma2Model(**kwargs) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else PoolingType.LAST, + normalize=pooler_config.pooling_norm or True, + softmax=pooler_config.pooling_softmax or False, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ac98a2121e9e..5f9ad0aa4856 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -545,9 +545,10 @@ def __init__( self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type], - normalize=pooler_config.pooling_norm, - softmax=pooler_config.pooling_softmax, + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else PoolingType.STEP, + normalize=pooler_config.pooling_norm or False, + softmax=pooler_config.pooling_softmax or False, step_tag_id=pooler_config.pooling_step_tag_id, returned_token_ids=pooler_config.pooling_returned_token_ids, ) @@ -646,12 +647,20 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP): def __init__( self, + pooler_config: Optional[PoolerConfig] = None, **kwargs, ) -> None: super().__init__() self.model = LlamaModel(**kwargs) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else PoolingType.LAST, + normalize=pooler_config.pooling_norm or True, + softmax=pooler_config.pooling_softmax or False, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index f85129b20691..cc866d2a4350 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -11,7 +11,7 @@ from typing_extensions import NotRequired from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig @@ -285,7 +285,8 @@ def __init__(self, config: LlavaNextConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + pooler_config: Optional[PoolerConfig] = None) -> None: super().__init__() self.config = config @@ -312,8 +313,14 @@ def __init__(self, # The same model class supports both language generation and embedding # because the architecture name is the same - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not 
None else PoolingType.LAST, + normalize=pooler_config.pooling_norm or True, + softmax=pooler_config.pooling_softmax or False, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 0962d3d3847c..da0c49c8bfcc 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -26,7 +26,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, MultiModalConfig +from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, + PoolerConfig) from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, token_inputs) from vllm.logger import init_logger @@ -530,7 +531,8 @@ def __init__(self, config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + pooler_config: Optional[PoolerConfig] = None) -> None: super().__init__() self.config = config @@ -556,8 +558,14 @@ def __init__(self, # The same model class supports both language generation and embedding # because the architecture name is the same - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else PoolingType.LAST, + normalize=pooler_config.pooling_norm or True, + softmax=pooler_config.pooling_softmax or False, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index e10c6dbbb647..0a67227ba19e 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -12,7 +12,7 @@ from transformers import Qwen2Config from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization.base_config import ( @@ -53,6 +53,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + pooler_config: Optional[PoolerConfig] = None, ) -> None: # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None @@ -77,9 +78,14 @@ def __init__( self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config) - self._pooler = Pooler(pooling_type=PoolingType.LAST, - normalize=False, - softmax=True) + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else PoolingType.LAST, + normalize=pooler_config.pooling_norm or False, + softmax=pooler_config.pooling_softmax or True, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) def forward( self, diff 
--git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index ee0eeb9db380..bb68ea2abe55 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -11,7 +11,7 @@ from transformers import Qwen2Config from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import Pooler, PoolingType @@ -64,6 +64,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + pooler_config: Optional[PoolerConfig] = None, ) -> None: # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None @@ -93,8 +94,14 @@ def __init__( RowParallelLinear(config.hidden_size, 1, quant_config=quant_config), ) - self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) - + self._pooler = Pooler( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else PoolingType.ALL, + normalize=pooler_config.pooling_norm or False, + softmax=pooler_config.pooling_softmax or False, + step_tag_id=pooler_config.pooling_step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids, + ) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) From 7b9346492d64d5e16c2277736075b8b7e86e0b87 Mon Sep 17 00:00:00 2001 From: Went-Liang Date: Tue, 29 Oct 2024 17:41:19 +0800 Subject: [PATCH 06/12] Add a factory method to Pooler Signed-off-by: Went-Liang --- vllm/model_executor/layers/pooler.py | 21 +++++++++++++++++++ vllm/model_executor/models/bert.py | 13 +++++------- vllm/model_executor/models/gemma2.py | 13 +++++------- vllm/model_executor/models/llama.py | 26 +++++++++--------------- vllm/model_executor/models/llava_next.py | 13 +++++------- vllm/model_executor/models/phi3v.py | 13 +++++------- vllm/model_executor/models/qwen2_cls.py | 13 +++++------- vllm/model_executor/models/qwen2_rm.py | 13 +++++------- 8 files changed, 61 insertions(+), 64 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index fe6519a5c680..3c58ae67794f 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from vllm.config import PoolerConfig from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput @@ -46,6 +47,26 @@ def __init__( self.step_tag_id = step_tag_id self.returned_token_ids = returned_token_ids + @classmethod + def from_config_with_defaults( + cls, + pooler_config: PoolerConfig, + pooling_type: Optional[PoolingType] = None, + normalize: Optional[bool] = None, + softmax: Optional[bool] = None, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[List[int]] = None, + ) -> "Pooler": + return cls( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type, + normalize=pooler_config.pooling_norm or normalize, + softmax=pooler_config.pooling_softmax or softmax, + step_tag_id=pooler_config.pooling_step_tag_id or step_tag_id, + returned_token_ids=pooler_config.pooling_returned_token_ids + or returned_token_ids, + ) + def forward( self, 
hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 18f454b7e453..bfed2929d57d 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -391,14 +391,11 @@ def __init__( ) -> None: super().__init__() self.model = BertModel(config, cache_config, quant_config) - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.CLS, - normalize=pooler_config.pooling_norm or True, - softmax=pooler_config.pooling_softmax or False, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.CLS, + normalize=True, + softmax=False) def forward( self, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 868f30468af2..693f32160a28 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -479,14 +479,11 @@ def __init__( super().__init__() self.model = Gemma2Model(**kwargs) - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.LAST, - normalize=pooler_config.pooling_norm or True, - softmax=pooler_config.pooling_softmax or False, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5f9ad0aa4856..8a9e5203972b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -544,14 +544,11 @@ def __init__( self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.STEP, - normalize=pooler_config.pooling_norm or False, - softmax=pooler_config.pooling_softmax or False, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.STEP, + normalize=False, + softmax=False) def forward( self, @@ -653,14 +650,11 @@ def __init__( super().__init__() self.model = LlamaModel(**kwargs) - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.LAST, - normalize=pooler_config.pooling_norm or True, - softmax=pooler_config.pooling_softmax or False, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index cc866d2a4350..e8540d85ff56 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ 
-313,14 +313,11 @@ def __init__(self, # The same model class supports both language generation and embedding # because the architecture name is the same - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.LAST, - normalize=pooler_config.pooling_norm or True, - softmax=pooler_config.pooling_softmax or False, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index da0c49c8bfcc..0fc4556831fd 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -558,14 +558,11 @@ def __init__(self, # The same model class supports both language generation and embedding # because the architecture name is the same - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.LAST, - normalize=pooler_config.pooling_norm or True, - softmax=pooler_config.pooling_softmax or False, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 0a67227ba19e..2d6f3e90f761 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -78,14 +78,11 @@ def __init__( self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config) - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.LAST, - normalize=pooler_config.pooling_norm or False, - softmax=pooler_config.pooling_softmax or True, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=True) def forward( self, diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index bb68ea2abe55..901b1daaa14a 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -94,14 +94,11 @@ def __init__( RowParallelLinear(config.hidden_size, 1, quant_config=quant_config), ) - self._pooler = Pooler( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else PoolingType.ALL, - normalize=pooler_config.pooling_norm or False, - softmax=pooler_config.pooling_softmax or False, - step_tag_id=pooler_config.pooling_step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids, - ) + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.ALL, + normalize=False, + softmax=False) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) From 084958c21288c800ffc840ed5610174bcee8ce00 Mon Sep 17 00:00:00 
2001
From: Went-Liang
Date: Tue, 29 Oct 2024 18:35:22 +0800
Subject: [PATCH 07/12] Check None explicitly

Signed-off-by: Went-Liang
---
 vllm/model_executor/layers/pooler.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 3c58ae67794f..f9f627b79632 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -33,9 +33,9 @@ class Pooler(nn.Module):
 
     def __init__(
         self,
-        pooling_type: Optional[PoolingType] = None,
-        normalize: Optional[bool] = None,
-        softmax: Optional[bool] = None,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
         step_tag_id: Optional[int] = None,
         returned_token_ids: Optional[List[int]] = None,
     ):
@@ -51,20 +51,24 @@ def __init__(
     def from_config_with_defaults(
         cls,
         pooler_config: PoolerConfig,
-        pooling_type: Optional[PoolingType] = None,
-        normalize: Optional[bool] = None,
-        softmax: Optional[bool] = None,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
         step_tag_id: Optional[int] = None,
         returned_token_ids: Optional[List[int]] = None,
     ) -> "Pooler":
         return cls(
             pooling_type=PoolingType[pooler_config.pooling_type]
             if pooler_config.pooling_type is not None else pooling_type,
-            normalize=pooler_config.pooling_norm or normalize,
-            softmax=pooler_config.pooling_softmax or softmax,
-            step_tag_id=pooler_config.pooling_step_tag_id or step_tag_id,
+            normalize=pooler_config.pooling_norm
+            if pooler_config.pooling_norm is not None else normalize,
+            softmax=pooler_config.pooling_softmax
+            if pooler_config.pooling_softmax is not None else softmax,
+            step_tag_id=pooler_config.pooling_step_tag_id
+            if pooler_config.pooling_step_tag_id is not None else step_tag_id,
             returned_token_ids=pooler_config.pooling_returned_token_ids
-            or returned_token_ids,
+            if pooler_config.pooling_returned_token_ids is not None else
+            returned_token_ids,
         )
 
     def forward(

From 88ab376aad5ba6377b3d1b3034d0ee41f036beb1 Mon Sep 17 00:00:00 2001
From: Went-Liang
Date: Tue, 29 Oct 2024 19:17:42 +0800
Subject: [PATCH 08/12] Fix bug in PoolerConfig initialization

Signed-off-by: Went-Liang
---
 vllm/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index e9559c40dbdf..f794dc33f6f8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -284,7 +284,7 @@ def _init_pooler_config(
                 pooling_softmax=pooling_softmax,
                 pooling_step_tag_id=pooling_step_tag_id,
                 pooling_returned_token_ids=pooling_returned_token_ids)
-        return None
+        return PoolerConfig()
 
     def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)

From a1b2ee81fd10db17660a2316ae57bba577143aa3 Mon Sep 17 00:00:00 2001
From: Went-Liang
Date: Tue, 29 Oct 2024 19:42:46 +0800
Subject: [PATCH 09/12] Fix pooler_config logging

Signed-off-by: Went-Liang
---
 vllm/config.py            | 2 +-
 vllm/engine/llm_engine.py | 9 ++-------
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index f794dc33f6f8..e9559c40dbdf 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -284,7 +284,7 @@ def _init_pooler_config(
                 pooling_softmax=pooling_softmax,
                 pooling_step_tag_id=pooling_step_tag_id,
                 pooling_returned_token_ids=pooling_returned_token_ids)
-        return PoolerConfig()
+        return None
 
     def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 16522f0b4ae7..3fd34fadee1c 100644 
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -258,8 +258,7 @@ def __init__(
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
             "chat_template_text_format=%s, mm_processor_kwargs=%s, "
-            "pooling_type=%s, pooling_norm=%s, pooling_softmax=%s, "
-            "pooling_step_tag_id=%s, pooling_returned_token_ids=%s)",
+            "pooler_config=%r)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -296,11 +295,7 @@ def __init__(
             use_cached_outputs,
             model_config.chat_template_text_format,
             model_config.mm_processor_kwargs,
-            model_config.pooler_config.pooling_type,
-            model_config.pooler_config.pooling_norm,
-            model_config.pooler_config.pooling_softmax,
-            model_config.pooler_config.pooling_step_tag_id,
-            model_config.pooler_config.pooling_returned_token_ids,
+            model_config.pooler_config,
         )
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config

From b4b34dc6533156462f6118e16282213ffdc10352 Mon Sep 17 00:00:00 2001
From: Went-Liang
Date: Tue, 29 Oct 2024 22:49:34 +0800
Subject: [PATCH 10/12] Fix from_config_with_defaults when pooler_config is
 None

Signed-off-by: Went-Liang
---
 vllm/model_executor/layers/pooler.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index f9f627b79632..1c9772b41cbe 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -56,7 +56,9 @@ def from_config_with_defaults(
         softmax: bool,
         step_tag_id: Optional[int] = None,
         returned_token_ids: Optional[List[int]] = None,
-    ) -> "Pooler":
+    ) -> Optional["Pooler"]:
+        if pooler_config is None:
+            return None
         return cls(
             pooling_type=PoolingType[pooler_config.pooling_type]
             if pooler_config.pooling_type is not None else pooling_type,

From cf050e4904a307172874b23fe1d99fae9944e197 Mon Sep 17 00:00:00 2001
From: Went-Liang
Date: Wed, 30 Oct 2024 00:16:20 +0800
Subject: [PATCH 11/12] Fix build_model by defaulting pooler_config to None

Signed-off-by: Went-Liang
---
 vllm/model_executor/model_loader/loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 939927efd9db..79703bb7ded7 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -158,7 +158,7 @@ def build_model(model_class: Type[nn.Module],
                 multimodal_config: Optional[MultiModalConfig],
                 scheduler_config: Optional[SchedulerConfig],
                 prefix: Optional[str] = None,
-                pooler_config: Optional[PoolerConfig]) -> nn.Module:
+                pooler_config: Optional[PoolerConfig] = None) -> nn.Module:
     extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                     multimodal_config,
                                                     scheduler_config,

From d1b0f5b27911b922fea4100c36f269d547c6b4eb Mon Sep 17 00:00:00 2001
From: Went-Liang
Date: Wed, 30 Oct 2024 11:50:31 +0800
Subject: [PATCH 12/12] Add LlamaForCausalLM-based models to _EMBEDDING_MODELS

Signed-off-by: Went-Liang
---
 vllm/model_executor/models/registry.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 30dfff31f7e4..f50ceaccb1bb 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -100,11 +100,27 @@
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
     "Qwen2ForSequenceClassification": (
         "qwen2_cls", "Qwen2ForSequenceClassification"),
+    "LlamaForCausalLM": ("llama", 
"LlamaForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), } +def add_embedding_models(base_models, embedding_models): + with_pooler_method_models = {} + embedding_models_name = embedding_models.keys() + for name, (path, arch) in base_models.items(): + if arch in embedding_models_name: + with_pooler_method_models[name] = (path, arch) + return with_pooler_method_models + +_EMBEDDING_MODELS = { + **add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS), + **_EMBEDDING_MODELS, +} + _MULTIMODAL_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),