115 changes: 85 additions & 30 deletions vllm/config.py
@@ -112,38 +112,58 @@ class ModelConfig:
Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
pooling_type: The pooling method to use in the embedding model.
pooling_norm: Whether to normalize the pooled data in the
embedding model.
pooling_softmax: Whether to apply softmax to the pooled data in
the embedding model.
pooling_step_tag_id: If set, only the scores corresponding to
pooling_step_tag_id in the generated sentence are returned;
otherwise, the scores for all tokens are returned.
pooling_returned_token_ids: A list of indices into the vocabulary
dimension to extract, such as the token IDs of good_token and
bad_token in the math-shepherd-mistral-7b-prm model.
"""

def __init__(self,
model: str,
task: Union[TaskOption, _Task],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
def __init__(
self,
model: str,
task: Union[TaskOption, _Task],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -224,6 +244,13 @@ def __init__(self,
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self.pooler_config = self._init_pooler_config(
pooling_type,
pooling_norm,
pooling_softmax,
pooling_step_tag_id,
pooling_returned_token_ids,
)

self._verify_quantization()
self._verify_cuda_graph()
@@ -242,6 +269,23 @@ def _init_multimodal_config(

return None

def _init_pooler_config(
self,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None
) -> Optional["PoolerConfig"]:
if self.task == "embedding":
return PoolerConfig(
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids)
return None

def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures)
@@ -1647,6 +1691,17 @@ class MultiModalConfig:
# TODO: Add configs to init vision tower or not.


@dataclass
class PoolerConfig:
"""Controls the behavior of pooler in embedding model"""

pooling_type: Optional[str] = None
pooling_norm: Optional[bool] = None
pooling_softmax: Optional[bool] = None
pooling_step_tag_id: Optional[int] = None
pooling_returned_token_ids: Optional[List[int]] = None


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
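
The new overrides resolve in _init_pooler_config: a PoolerConfig is built only
for the embedding task, and any field left as None falls through to the model's
own default later on. A minimal standalone sketch of that resolution, assuming
plain Python rather than vLLM's actual module layout:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PoolerConfig:
    """Mirror of the dataclass added in vllm/config.py."""
    pooling_type: Optional[str] = None
    pooling_norm: Optional[bool] = None
    pooling_softmax: Optional[bool] = None
    pooling_step_tag_id: Optional[int] = None
    pooling_returned_token_ids: Optional[List[int]] = None

def init_pooler_config(task: str, **overrides) -> Optional[PoolerConfig]:
    # Only embedding models get a pooler config; other tasks return None.
    if task == "embedding":
        return PoolerConfig(**overrides)
    return None

config = init_pooler_config("embedding", pooling_type="STEP")
assert config is not None
assert config.pooling_norm is None  # unset -> model default wins downstream
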
64 changes: 64 additions & 0 deletions vllm/engine/arg_utils.py
@@ -184,6 +184,13 @@ class EngineArgs:
mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

# Pooling configuration.
pooling_type: Optional[str] = None
pooling_norm: Optional[bool] = None
pooling_softmax: Optional[bool] = None
pooling_step_tag_id: Optional[int] = None
pooling_returned_token_ids: Optional[List[int]] = None

def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@@ -850,6 +857,58 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).')

parser.add_argument(
'--pooling-type',
choices=['LAST', 'ALL', 'CLS', 'STEP'],
default=None,
help='The pooling method to use in the embedding model.'
)

parser.add_argument('--pooling-norm',
default=None,
action='store_true',
help="Used to determine whether to normalize "
"the pooled data in the embedding model.")

parser.add_argument('--no-pooling-norm',
default=None,
action='store_false',
dest='pooling_norm',
help="Used to determine whether to normalize "
"the pooled data in the embedding model.")

parser.add_argument('--pooling-softmax',
default=None,
action='store_true',
help="Used to determine whether to softmax "
"the pooled data in the embedding model.")

parser.add_argument('--no-pooling-softmax',
default=None,
action='store_false',
dest='pooling_softmax',
help="Used to determine whether to softmax "
"the pooled data in the embedding model.")

parser.add_argument(
'--pooling-step-tag-id',
type=int,
default=None,
help="When pooling-step-tag-id is not -1, it indicates "
"that the score corresponding to the step-tag-ids in the "
"generated sentence should be returned. Otherwise, it "
"returns the scores for all tokens.")

parser.add_argument(
'--pooling-returned-token-ids',
nargs='+',
type=int,
default=None,
help="pooling-returned-token-ids represents a list of "
"indices for the vocabulary dimensions to be extracted, "
"such as the token IDs of good_token and bad_token in "
"the math-shepherd-mistral-7b-prm model.")

return parser

@classmethod
@@ -891,6 +950,11 @@ def create_model_config(self) -> ModelConfig:
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
pooling_type=self.pooling_type,
pooling_norm=self.pooling_norm,
pooling_softmax=self.pooling_softmax,
pooling_step_tag_id=self.pooling_step_tag_id,
pooling_returned_token_ids=self.pooling_returned_token_ids,
)

def create_load_config(self) -> LoadConfig:
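
The paired flags above use a tri-state pattern: default=None on both the
positive and the negative flag distinguishes "not given" (defer to the model's
default) from an explicit True or False. A small sketch of the same pattern
with plain argparse:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pooling-norm', action='store_true', default=None)
parser.add_argument('--no-pooling-norm', action='store_false',
                    default=None, dest='pooling_norm')

print(parser.parse_args([]).pooling_norm)                     # None
print(parser.parse_args(['--pooling-norm']).pooling_norm)     # True
print(parser.parse_args(['--no-pooling-norm']).pooling_norm)  # False
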
4 changes: 3 additions & 1 deletion vllm/engine/llm_engine.py
@@ -257,7 +257,8 @@ def __init__(
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s)",
"chat_template_text_format=%s, mm_processor_kwargs=%s, "
"pooler_config=%r)",
VLLM_VERSION,
model_config.model,
speculative_config,
@@ -294,6 +295,7 @@ def __init__(
use_cached_outputs,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,
model_config.pooler_config,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
10 changes: 10 additions & 0 deletions vllm/entrypoints/llm.py
@@ -159,6 +159,11 @@ def __init__(
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None,
**kwargs,
) -> None:
'''
@@ -193,6 +198,11 @@ def __init__(
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
mm_processor_kwargs=mm_processor_kwargs,
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
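
Taken together, these keyword arguments make a process-reward-model setup
expressible directly through the LLM entrypoint. A hypothetical invocation,
where the checkpoint path and the step-tag / returned token IDs are
illustrative placeholders, not values taken from this diff:

from vllm import LLM

llm = LLM(
    model="peiyi9979/math-shepherd-mistral-7b-prm",  # assumed checkpoint name
    task="embedding",
    pooling_type="STEP",
    pooling_norm=False,
    pooling_step_tag_id=12902,               # placeholder step-tag token id
    pooling_returned_token_ids=[648, 387],   # placeholder good/bad token ids
)
outputs = llm.encode("Step 1: 2 + 2 = 4 ки")  # scores at step-tag positions
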
62 changes: 58 additions & 4 deletions vllm/model_executor/layers/pooler.py
@@ -1,8 +1,10 @@
from enum import IntEnum
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.config import PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
@@ -13,6 +15,7 @@ class PoolingType(IntEnum):
LAST = 0
ALL = 1
CLS = 2
STEP = 3


class Pooler(nn.Module):
@@ -28,15 +31,47 @@ class Pooler(nn.Module):
normalize: Whether to normalize the pooled data.
"""

def __init__(self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool = False):
def __init__(
self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
):
super().__init__()

self.pooling_type = pooling_type
self.normalize = normalize
self.softmax = softmax
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids

@classmethod
def from_config_with_defaults(
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> Optional["Pooler"]:
if pooler_config is None:
return None
return cls(
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type,
normalize=pooler_config.pooling_norm
if pooler_config.pooling_norm is not None else normalize,
softmax=pooler_config.pooling_softmax
if pooler_config.pooling_softmax is not None else softmax,
step_tag_id=pooler_config.pooling_step_tag_id
if pooler_config.pooling_step_tag_id is not None else step_tag_id,
returned_token_ids=pooler_config.pooling_returned_token_ids
if pooler_config.pooling_returned_token_ids is not None else
returned_token_ids,
)

def forward(
self,
@@ -62,6 +97,25 @@ def forward(
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
elif self.pooling_type == PoolingType.STEP:
if self.returned_token_ids is not None and len(
self.returned_token_ids) > 0:
logits = hidden_states[:,
self.returned_token_ids].softmax(dim=-1)
else:
logits = hidden_states.softmax(dim=-1)
offset = 0
pooled_data = []
for prompt_len, seq_data_i in zip(
prompt_lens, pooling_metadata.seq_data.values()):
if self.step_tag_id is None:
pooled_data.append(logits[offset:offset + prompt_len])
else:
step_idxs = torch.tensor(
seq_data_i.prompt_token_ids) == self.step_tag_id
pooled_data.append(logits[offset:offset +
prompt_len][step_idxs])
offset += prompt_len
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")

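
The STEP branch can be read in isolation: softmax over the (optionally
restricted) vocabulary dimensions, then keep only the rows whose prompt token
matches step_tag_id. A toy reproduction with made-up shapes and token IDs:

import torch

hidden_states = torch.randn(5, 10)       # 5 prompt tokens, vocab size 10
returned_token_ids = [3, 7]              # e.g. ids of good_token / bad_token
step_tag_id = 42
prompt_token_ids = [1, 42, 9, 42, 2]     # toy prompt; two step-tag positions

logits = hidden_states[:, returned_token_ids].softmax(dim=-1)
step_idxs = torch.tensor(prompt_token_ids) == step_tag_id
pooled = logits[step_idxs]               # keep scores at step-tag rows only
print(pooled.shape)                      # torch.Size([2, 2])
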