tests/compile/piecewise/test_simple.py (14 changes: 9 additions & 5 deletions)

@@ -12,10 +12,9 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.utils import direct_register_custom_op
 
-os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
-
 global_counter = 0
 
 # create a library to hold the custom op
@@ -48,7 +47,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 @support_torch_compile
 class SillyModel(nn.Module):
 
-    def __init__(self) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -74,11 +77,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 def test_simple_piecewise_compile():
 
-    model = SillyModel()
-
     directory = os.path.dirname(__file__)
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
+
+    model = SillyModel(vllm_config=VllmConfig(), prefix='')
 
     inputs = torch.randn(100).cuda()
 
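Note on the change above: SillyModel keeps the @support_torch_compile decorator, but construction now goes through the keyword-only vllm_config/prefix contract, and VLLM_TORCH_COMPILE_LEVEL is set inside the test because the decorator now reads it when an instance is created, not when the class is defined. A minimal sketch of the calling convention, assuming only what the diff shows (MyToyModel is a hypothetical example, not part of this PR):

    import torch
    import torch.nn as nn

    from vllm.compilation.decorators import support_torch_compile
    from vllm.config import VllmConfig


    @support_torch_compile
    class MyToyModel(nn.Module):

        # keyword-only vllm_config and prefix, mirroring SillyModel above
        def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs) -> None:
            super().__init__()
            self.scale = 2.0

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.scale


    # a bare VllmConfig() is enough for unit tests; see the vllm/config.py hunks below
    model = MyToyModel(vllm_config=VllmConfig(), prefix='')
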
tests/compile/piecewise/test_toy_llama.py (21 changes: 14 additions & 7 deletions)

@@ -19,6 +19,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.plugins import set_compilation_config
 from vllm.utils import direct_register_custom_op
 
@@ -195,9 +196,15 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
-    def __init__(self, config: LlamaConfig) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 config: LlamaConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
         self.embedding_tokens = nn.Embedding(
             num_embeddings=config.vocab_size,
@@ -265,10 +272,9 @@ def run_model(llama_config,
             CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)
 
-    cls = LlamaModel
-    if use_compile:
-        cls = support_torch_compile(LlamaModel)
-    model = cls(llama_config).eval().cuda()
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda()
 
     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -357,7 +363,6 @@ def test_toy_llama():
 def benchmark():
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
     from triton.testing import do_bench
-    cls = support_torch_compile(LlamaModel)
 
     # similar to llama 3.1-8B
     llama_config = LlamaConfig(hidden_size=4096,
@@ -390,7 +395,9 @@ def benchmark():
     else:
         set_compilation_config(None)
 
-    model = cls(llama_config).eval().cuda().to(torch.bfloat16)
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda().to(torch.bfloat16)
 
     B = 256  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
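The removal of "cls = support_torch_compile(LlamaModel)" in run_model and benchmark is the point of this file's change: the class is decorated once at definition time, and whether a given instance actually compiles is decided from VLLM_TORCH_COMPILE_LEVEL when it is constructed. A sketch of that toggle, reusing the LlamaModel and llama_config names from this test file (assumed to be in scope):

    import os

    from vllm.compilation.levels import CompilationLevel
    from vllm.config import VllmConfig

    # instance built under PIECEWISE: the decorator initializes the compile wrapper
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
    compiled_model = LlamaModel(config=llama_config,
                                vllm_config=VllmConfig(),
                                prefix="").eval().cuda()

    # instance built under NO_COMPILATION: do_not_compile is set and calls fall
    # through to the eager forward
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.NO_COMPILATION)
    eager_model = LlamaModel(config=llama_config,
                             vllm_config=VllmConfig(),
                             prefix="").eval().cuda()
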
vllm/compilation/decorators.py (27 changes: 14 additions & 13 deletions)

@@ -6,6 +6,7 @@
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
     """
     A decorator to add support for compiling the forward method of a class.
     """
-
-    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
-    # will handle the compilation, so we don't need to do anything here.
-    if envs.VLLM_TORCH_COMPILE_LEVEL in [
-            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
-    ] or not supports_dynamo():
+    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+        # support decorating multiple times
         return cls
 
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
-    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
-        # support decorating multiple times
-        cls.__bases__ = cls.__bases__ + (
-            TorchCompileWrapperWithCustomDispatcher, )
+    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
 
     old_init = cls.__init__  # type: ignore
 
-    def __init__(self, *args, **kwargs):
-        old_init(self, *args, **kwargs)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
+        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # will handle the compilation, so we don't need to do anything here.
+        self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
+            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
+        ] or not supports_dynamo()
+        if self.do_not_compile:
+            return
         TorchCompileWrapperWithCustomDispatcher.__init__(self)
 
     cls.__init__ = __init__  # type: ignore
@@ -138,7 +139,7 @@ def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
-        if torch.compiler.is_compiling():
+        if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
         # the first compilation needs to have dynamic shapes marked
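Taken together, the three hunks above move the compile/no-compile decision from decoration time to construction time: the wrapper base class is always mixed in, __init__ records do_not_compile per instance from VLLM_TORCH_COMPILE_LEVEL, and __call__ short-circuits to the eager forward for instances that opted out. A condensed sketch of that flow (simplified and not the actual implementation; the dynamic-shape marking and compiled dispatch logic of the real decorator are omitted):

    import torch

    import vllm.envs as envs
    from vllm.compilation.levels import CompilationLevel
    from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
    from vllm.utils import supports_dynamo


    def _support_torch_compile_sketch(cls):
        # always mix in the wrapper base; decorating twice is a no-op
        if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
            return cls
        cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

        old_init = cls.__init__

        def __init__(self, *, vllm_config, prefix='', **kwargs):
            old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
            # per-instance decision, read when the model is constructed
            self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
                CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
            ] or not supports_dynamo()
            if not self.do_not_compile:
                TorchCompileWrapperWithCustomDispatcher.__init__(self)

        def __call__(self, *args, **kwargs):
            # uncompiled instances, and calls already inside torch.compile,
            # go straight to the eager forward
            if self.do_not_compile or torch.compiler.is_compiling():
                return self.forward(*args, **kwargs)
            ...  # compiled dispatch path, unchanged by this PR

        cls.__init__ = __init__
        cls.__call__ = __call__
        return cls
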
vllm/config.py (30 changes: 18 additions & 12 deletions)

@@ -2038,12 +2038,15 @@ class VllmConfig:
     simplifies passing around the distinct configurations in the codebase.
     """
 
-    model_config: ModelConfig
-    cache_config: CacheConfig
-    parallel_config: ParallelConfig
-    scheduler_config: SchedulerConfig
-    device_config: DeviceConfig
-    load_config: LoadConfig
+    model_config: ModelConfig = field(default=None, init=True)  # type: ignore
+    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
+    parallel_config: ParallelConfig = field(default=None,
+                                            init=True)  # type: ignore
+    scheduler_config: SchedulerConfig = field(default=None,
+                                              init=True)  # type: ignore
+    device_config: DeviceConfig = field(default=None,
+                                        init=True)  # type: ignore
+    load_config: LoadConfig = field(default=None, init=True)  # type: ignore
     lora_config: Optional[LoRAConfig] = None
     speculative_config: Optional[SpeculativeConfig] = None
     decoding_config: Optional[DecodingConfig] = None
@@ -2088,11 +2091,14 @@ def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
-        self.model_config.verify_async_output_proc(self.parallel_config,
-                                                   self.speculative_config,
-                                                   self.device_config)
-        self.model_config.verify_with_parallel_config(self.parallel_config)
-        self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.model_config is not None:
+            self.model_config.verify_async_output_proc(self.parallel_config,
+                                                       self.speculative_config,
+                                                       self.device_config)
+            self.model_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.cache_config is not None:
+            self.cache_config.verify_with_parallel_config(self.parallel_config)
 
         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
@@ -2146,4 +2152,4 @@ def __str__(self):
                  self.scheduler_config.num_scheduler_steps,
                  self.cache_config.enable_prefix_caching,
                  self.model_config.use_async_output_proc,
-                 self.model_config.mm_processor_kwargs)
\ No newline at end of file
+                 self.model_config.mm_processor_kwargs)
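The config change is what lets the tests above write VllmConfig() with no arguments: the core sub-configs now default to None, and __post_init__ only runs the cross-checks for the ones that are present. A small sketch of the test-friendly construction this enables (engine code still builds fully populated configs as before):

    from vllm.config import VllmConfig

    # placeholder config for unit tests: every core field defaults to None, and
    # __post_init__ skips the model/cache consistency checks for missing fields
    cfg = VllmConfig()
    assert cfg.model_config is None
    assert cfg.cache_config is None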