10 changes: 7 additions & 3 deletions tests/test_config.py
@@ -200,8 +200,10 @@ def test_rope_customization():
trust_remote_code=False,
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
rope_theta=TEST_ROPE_THETA,
hf_overrides={
"rope_scaling": TEST_ROPE_SCALING,
"rope_theta": TEST_ROPE_THETA,
},
)
assert getattr(llama_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
@@ -232,7 +234,9 @@ def test_rope_customization():
trust_remote_code=False,
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
hf_overrides={
"rope_scaling": TEST_ROPE_SCALING,
},
)
assert getattr(longchat_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
30 changes: 22 additions & 8 deletions vllm/config.py
@@ -1,5 +1,6 @@
import enum
import json
import warnings
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
@@ -74,9 +75,6 @@ class ModelConfig:
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
rope_scaling: Dictionary containing the scaling configuration for the
RoPE embeddings. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
@@ -116,6 +114,7 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
hf_overrides: Arguments to be forwarded to the HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
pooling_type: Used to configure the pooling method in the embedding
@@ -146,7 +145,7 @@ def __init__(
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
@@ -164,6 +163,7 @@ def __init__(
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
hf_overrides: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
@@ -178,8 +178,22 @@ def __init__(
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta

if hf_overrides is None:
hf_overrides = {}
if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
hf_overrides.update(hf_override)
msg = ("`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None:
hf_override = {"rope_theta": rope_theta}
hf_overrides.update(hf_override)
msg = ("`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
@@ -193,8 +207,8 @@ def __init__(
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta,
config_format)
code_revision, config_format,
**hf_overrides)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
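The deprecation shim added above folds the legacy `rope_scaling` and `rope_theta` arguments into `hf_overrides` before the HuggingFace config is loaded. Below is a minimal standalone sketch of that folding logic; `merge_rope_overrides` and its exact wording are illustrative, not part of the PR.

import warnings
from typing import Any, Dict, Optional

def merge_rope_overrides(
        hf_overrides: Optional[Dict[str, Any]],
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_theta: Optional[float] = None) -> Dict[str, Any]:
    """Fold the deprecated rope_* arguments into hf_overrides (sketch)."""
    overrides = dict(hf_overrides or {})
    for key, value in (("rope_scaling", rope_scaling),
                       ("rope_theta", rope_theta)):
        if value is not None:
            overrides[key] = value
            warnings.warn(
                f"`{key}` will be removed in a future release. "
                f"Please use `hf_overrides={{{key!r}: ...}}` instead.",
                DeprecationWarning,
                stacklevel=2)
    return overrides

# Both call styles produce the same overrides dict:
assert merge_rope_overrides({}, rope_theta=1000000.0) == \
       merge_rope_overrides({"rope_theta": 1000000.0})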
14 changes: 11 additions & 3 deletions vllm/engine/arg_utils.py
@@ -128,8 +128,9 @@ class EngineArgs:
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
rope_scaling: Optional[Dict[str, Any]] = None
rope_theta: Optional[float] = None
hf_overrides: Optional[Dict[str, Any]] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: Optional[bool] = None
@@ -140,8 +141,9 @@ class EngineArgs:
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
@@ -187,7 +189,6 @@ class EngineArgs:
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

# Pooling configuration.
@@ -512,6 +513,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help='RoPE theta. Use with `rope_scaling`. In '
'some cases, changing the RoPE theta improves the '
'performance of the scaled model.')
parser.add_argument('--hf-overrides',
type=json.loads,
default=EngineArgs.hf_overrides,
help='Extra arguments for the HuggingFace config. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
@@ -940,6 +947,7 @@ def create_model_config(self) -> ModelConfig:
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
hf_overrides=self.hf_overrides,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
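For reference, the new `--hf-overrides` flag takes a JSON object on the command line and is parsed with `json.loads`. A hedged sketch of that parsing behavior with a plain `argparse.ArgumentParser` (vLLM's actual parser is `FlexibleArgumentParser`; the override value below is illustrative only):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf-overrides",
    type=json.loads,  # the raw CLI string is parsed into a dict
    default=None,
    help="Extra arguments for the HuggingFace config, as a JSON string.")

# Equivalent to: --hf-overrides '{"rope_theta": 1000000.0}'
args = parser.parse_args(["--hf-overrides", '{"rope_theta": 1000000.0}'])
assert args.hf_overrides == {"rope_theta": 1000000.0}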
5 changes: 1 addition & 4 deletions vllm/engine/llm_engine.py
@@ -248,8 +248,7 @@ def __init__(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"override_neuron_config=%s, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
@@ -271,8 +270,6 @@ def __init__(
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
7 changes: 6 additions & 1 deletion vllm/entrypoints/llm.py
@@ -98,7 +98,10 @@ class LLM:
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See ParallelConfig
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_overrides: Arguments to be forwarded to the HuggingFace config.
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)

@@ -153,6 +156,7 @@ def __init__(
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_overrides: Optional[dict] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
@@ -194,6 +198,7 @@ def __init__(
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
pooling_type=pooling_type,
pooling_norm=pooling_norm,
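With `hf_overrides` now exposed on `LLM`, RoPE parameters can be overridden at construction time instead of through the deprecated keyword arguments. An illustrative call follows; the model name and override value are placeholders, not anything mandated by the PR.

from vllm import LLM

# hf_overrides is merged into the HuggingFace config before the model loads.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",        # placeholder model
    hf_overrides={"rope_theta": 1000000.0},  # illustrative value
)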
55 changes: 25 additions & 30 deletions vllm/transformers_utils/config.py
@@ -146,9 +146,8 @@ def get_config(
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
token: Optional[str] = None,
**kwargs,
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
@@ -159,46 +158,51 @@ def get_config(
model = Path(model).parent

if config_format == ConfigFormat.AUTO:
if is_gguf or file_or_path_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=kwargs.get("token")):
if is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision, token=token):
config_format = ConfigFormat.HF
elif file_or_path_exists(model,
MISTRAL_CONFIG_NAME,
revision=revision,
token=kwargs.get("token")):
token=token):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=kwargs.get("token"))
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)

raise ValueError(f"No supported config format found in {model}")

if config_format == ConfigFormat.HF:
config_dict, _ = PretrainedConfig.get_config_dict(
model, revision=revision, code_revision=code_revision, **kwargs)
model,
revision=revision,
code_revision=code_revision,
token=token,
**kwargs,
)

# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
config = config_class.from_pretrained(
model,
revision=revision,
code_revision=code_revision,
token=token,
**kwargs,
)
else:
try:
config = AutoConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
token=token,
**kwargs,
)
except ValueError as e:
@@ -216,7 +220,7 @@ def get_config(
raise e

elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, token=kwargs.get("token"))
config = load_params_config(model, revision, token=token, **kwargs)
else:
raise ValueError(f"Unsupported config format: {config_format}")

@@ -228,19 +232,6 @@ def get_config(
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})

for key, value in [
("rope_scaling", rope_scaling),
("rope_theta", rope_theta),
]:
if value is not None:
logger.info(
"Updating %s from %r to %r",
key,
getattr(config, key, None),
value,
)
config.update({key: value})

patch_rope_scaling(config)

return config
@@ -462,13 +453,15 @@ def _reduce_modelconfig(mc: ModelConfig):

def load_params_config(model: Union[str, Path],
revision: Optional[str],
token: Optional[str] = None) -> PretrainedConfig:
token: Optional[str] = None,
**kwargs) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format

config_file_name = "params.json"

config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
assert isinstance(config_dict, dict)

config_mapping = {
"dim": "hidden_size",
@@ -512,6 +505,8 @@ def recurse_elems(elem: Any):
config_dict["architectures"] = ["PixtralForConditionalGeneration"]
config_dict["model_type"] = "pixtral"

config_dict.update(kwargs)

config = recurse_elems(config_dict)
return config

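After this refactor, `get_config` no longer applies `rope_scaling`/`rope_theta` by hand; the overrides ride along as keyword arguments and are applied by the HuggingFace loaders (or merged via `config_dict.update(kwargs)` for Mistral-format `params.json`). A hedged illustration of the underlying `transformers` behavior, with a placeholder model name and an illustrative override value:

from transformers import AutoConfig

# Keyword arguments matching config attributes override the loaded values.
config = AutoConfig.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model
    rope_theta=1000000.0,        # illustrative override
)
assert config.rope_theta == 1000000.0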
5 changes: 1 addition & 4 deletions vllm/v1/engine/llm_engine.py
@@ -74,8 +74,7 @@ def __init__(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"override_neuron_config=%s, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
@@ -94,8 +93,6 @@ def __init__(
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,