5 changes: 2 additions & 3 deletions tests/tpu/test_compilation.py
@@ -4,7 +4,7 @@

import depyf

from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationLevel

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
@@ -34,8 +34,7 @@
# all the control
llm = LLM(model="google/gemma-2b",
enforce_eager=True,
compilation_config=CompilationConfig(
level=CompilationLevel.DYNAMO_AS_IS))
compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
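
For reference, a minimal sketch of the before-and-after usage this test change reflects (illustrative only; the model name and options are taken from the diff above):

from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel

# Before this change, the test constructed a CompilationConfig object directly.
llm = LLM(model="google/gemma-2b",
          enforce_eager=True,
          compilation_config=CompilationConfig(
              level=CompilationLevel.DYNAMO_AS_IS))

# After this change, a plain dict is accepted and converted internally.
llm = LLM(model="google/gemma-2b",
          enforce_eager=True,
          compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
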
15 changes: 14 additions & 1 deletion vllm/entrypoints/llm.py
@@ -1,4 +1,5 @@
import itertools
import json
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
@@ -9,6 +10,7 @@
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
@@ -107,13 +109,16 @@ class LLM:
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
compilation_config: Either an integer or a dictionary. If it is an integer,
it is used as the level of compilation optimization. If it is a dictionary,
it can specify the full compilation configuration.
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)

Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
"""
""" # noqa
Collaborator:
nit: is it because of the line-too-long error? If so, can you please fix L113?

Member Author:
This is a triple quote; adding # noqa ignores the line length limit for all lines inside the triple quote.

Collaborator (@WoosukKwon, Nov 22, 2024):
Yes, please don't use noqa in this case. Please adhere to the line limit as much as possible.
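
For illustration, a sketch of the two options being discussed (not part of the diff; the wording is taken from the docstring above):

# Approach described by the author: append "# noqa" after the closing triple
# quote, with the intent of silencing the line-length check for every line
# inside the docstring.
"""
compilation_config: Either an integer or a dictionary. If it is an integer, ...
"""  # noqa

# Approach preferred by the reviewer: wrap the text so each line fits within
# the limit and no noqa is needed.
"""
compilation_config: Either an integer or a dictionary. If it is an
    integer, it is used as the level of compilation optimization. If it
    is a dictionary, it can specify the full compilation configuration.
"""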


DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@@ -166,6 +171,7 @@ def __init__(
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
**kwargs,
) -> None:
'''
@@ -178,6 +184,12 @@
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

if compilation_config is not None:
compilation_config_instance = CompilationConfig.from_cli(
json.dumps(compilation_config))
else:
compilation_config_instance = None

engine_args = EngineArgs(
model=model,
task=task,
@@ -202,6 +214,7 @@
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
compilation_config=compilation_config_instance,
**kwargs,
)
# Logic to switch between engines is done at runtime instead of import
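
Taken together, a hedged sketch of how the new keyword argument is meant to be used, based on the docstring and the conversion logic above (values are illustrative, not prescribed by the PR):

from vllm import LLM

# Dict form: internally serialized with json.dumps and parsed by
# CompilationConfig.from_cli, so it can carry the full compilation
# configuration.
llm = LLM(model="google/gemma-2b",
          compilation_config={"level": 3})

# Int form (per the docstring above): the integer is used as the level of
# compilation optimization.
llm = LLM(model="google/gemma-2b",
          compilation_config=3)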