From 1f61d605d7929156ce05f77517b552669e6cdbfd Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 21 Nov 2024 12:59:10 -0800
Subject: [PATCH 1/3] align with cli usage

Signed-off-by: youkaichao
---
 tests/tpu/test_compilation.py |  5 ++---
 vllm/entrypoints/llm.py       | 12 +++++++++++-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py
index 65bee85e7a1e..b7124ebc1b0f 100644
--- a/tests/tpu/test_compilation.py
+++ b/tests/tpu/test_compilation.py
@@ -4,7 +4,7 @@
 
 import depyf
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationLevel
 
 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
@@ -34,8 +34,7 @@
     # all the control
     llm = LLM(model="google/gemma-2b",
               enforce_eager=True,
-              compilation_config=CompilationConfig(
-                  level=CompilationLevel.DYNAMO_AS_IS))
+              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
     outputs = llm.generate(prompts, sampling_params)
     for output, answer in zip(outputs, answers):
         prompt = output.prompt
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 86b0b6893f1d..0b8bc45f3280 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -5,10 +5,12 @@
                     Union, cast, overload)
 
 from tqdm import tqdm
+import json
 
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
+from vllm.config import CompilationConfig
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -107,13 +109,16 @@ class LLM:
         hf_overrides: If a dictionary, contains arguments to be forwarded to
             the HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
+        compilation_config: Either an integer or a dictionary. If it is an integer,
+            it is used as the level of compilation optimization. If it is a dictionary,
+            it can specify the full compilation configuration.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
-    """
+    """ # noqa
 
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
@@ -166,6 +171,7 @@ def __init__(
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
         override_pooler_config: Optional[PoolerConfig] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -178,6 +184,9 @@ def __init__(
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if compilation_config is not None:
+            compilation_config_instance = CompilationConfig.from_cli(
+                json.dumps(compilation_config))
         engine_args = EngineArgs(
             model=model,
             task=task,
@@ -202,6 +211,7 @@ def __init__(
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
+            compilation_config=compilation_config_instance,
             **kwargs,
         )
         # Logic to switch between engines is done at runtime instead of import

From 69e3c82bfae4298fdbd1105901de13080283217d Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 21 Nov 2024 13:02:01 -0800
Subject: [PATCH 2/3] fix order

Signed-off-by: youkaichao
---
 vllm/entrypoints/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 0b8bc45f3280..f5c2c82f7160 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,11 +1,11 @@
 import itertools
+import json
 import warnings
 from contextlib import contextmanager
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
                     Union, cast, overload)
 
 from tqdm import tqdm
-import json
 
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,

From 322ff68ca04f8501c223cb8bb833554d3ece9b0d Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 21 Nov 2024 13:27:48 -0800
Subject: [PATCH 3/3] fix

Signed-off-by: youkaichao
---
 vllm/entrypoints/llm.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index f5c2c82f7160..2446a64a02eb 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -187,6 +187,9 @@ def __init__(
         if compilation_config is not None:
             compilation_config_instance = CompilationConfig.from_cli(
                 json.dumps(compilation_config))
+        else:
+            compilation_config_instance = None
+
         engine_args = EngineArgs(
             model=model,
             task=task,
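
Usage sketch (illustrative, not part of the patch): after this series,
`compilation_config` can be passed straight to `LLM`, mirroring the
corresponding CLI option. The dict form below follows the updated TPU test;
the integer form follows the new docstring. The model name and sampling
settings are only examples.

    from vllm import LLM, SamplingParams
    from vllm.config import CompilationLevel

    # dict form: spell out the compilation configuration
    llm = LLM(model="google/gemma-2b",
              enforce_eager=True,
              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})

    # integer form: just the optimization level, per the new docstring
    # llm = LLM(model="google/gemma-2b", compilation_config=3)

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)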