From d23b1e34c4e6ae7749d5630b232628b45e38b1e7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 22:01:07 +0800 Subject: [PATCH 01/13] finish for v1 Signed-off-by: youkaichao --- vllm/v1/worker/gpu_model_runner.py | 16 ++++++++-------- vllm/v1/worker/gpu_worker.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 87a1cd7f9e62..3fa7d77256c7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,11 +1,10 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast import numpy as np import torch import torch.distributed -import torch.nn as nn from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention @@ -823,10 +822,12 @@ def load_model(self) -> None: @torch.inference_mode() def _dummy_run( self, - model: nn.Module, num_tokens: int, - kv_caches: List[torch.Tensor], + kv_caches: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: + model = self.model + if kv_caches is None: + kv_caches = self.kv_caches if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -952,8 +953,7 @@ def profile_run(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) + hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) logits = self.model.compute_logits(hidden_states, None) logits = logits[:self.max_num_tokens] # TODO(woosuk): Consider the memory usage of the sampler. @@ -979,8 +979,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(self.model, num_tokens, self.kv_caches) - self._dummy_run(self.model, num_tokens, self.kv_caches) + self._dummy_run(num_tokens) + self._dummy_run(num_tokens) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4fb4197f1822..ab7d0a257a96 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -170,6 +170,17 @@ def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes + if x not in self.vllm_config.compilation_config.capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model() # Reset the seed to ensure that the random state is not affected by From e06077023c2f0ffc4a6ddd7ae0762039d6c7713c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 10:23:41 +0800 Subject: [PATCH 02/13] update Signed-off-by: youkaichao --- vllm/config.py | 45 +++++++++++++++++-------------------- vllm/worker/model_runner.py | 12 +++++++--- vllm/worker/worker.py | 11 +++++++++ 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4698a0502033..ff5a2707db98 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2703,10 +2703,10 @@ class CompilationConfig(BaseModel): - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for cudagraph sizes that are - in candidate_compile_sizes, using configurations - in inductor_compile_config. - - candidate_compile_sizes: sizes to compile for inductor. + is compiled. In addition, compile for cudagraph sizes, + max-tokens for chunked prefill, and additional_compile_sizes, + using configurations in inductor_compile_config. + - additional_compile_sizes: additional sizes to compile for inductor. - inductor_compile_config: additional configurations for inductor. - None: use default configurations. - inductor_passes: additional passes for inductor. It is a dictionary @@ -2734,7 +2734,7 @@ class CompilationConfig(BaseModel): splitting_ops: List[str] = Field(default=None) # type: ignore use_inductor: bool = True - candidate_compile_sizes: Optional[List[int]] = Field(default=None) + additional_compile_sizes: Optional[List[int]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -2909,31 +2909,23 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) - def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): + def init_with_specific_sizes(self, cudagraph_capture_sizes: List[int], + only_compile_sizes: List[int]) -> None: """To complete the initialization of config, - we need to know the cudagraph sizes.""" + we need to know the cudagraph sizes, and the + sizes we only compile but not capture cudagraph.""" if self.cudagraph_capture_sizes is None: - self.capture_sizes = sizes_to_specialize + self.capture_sizes = cudagraph_capture_sizes else: self.capture_sizes = self.cudagraph_capture_sizes logger.info(("cudagraph sizes specified by model runner" " %s is overridden by config %s"), - sizes_to_specialize, self.cudagraph_capture_sizes) + cudagraph_capture_sizes, self.cudagraph_capture_sizes) - if self.candidate_compile_sizes is None: - self.candidate_compile_sizes = [] - self.compile_sizes = [ - x for x in self.candidate_compile_sizes if x in self.capture_sizes - ] - ignored_sizes = [ - x for x in self.candidate_compile_sizes - if x not in self.capture_sizes - ] - if ignored_sizes: - logger.warning(("candidate_compile_sizes %s are ignored " - "because they are not cudagraph capture sizes."), - ignored_sizes) + if self.additional_compile_sizes is None: + self.additional_compile_sizes = [] + self.compile_sizes = self.capture_sizes + self.additional_compile_sizes + only_compile_sizes # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) @@ -3261,8 +3253,13 @@ def _set_cudagraph_sizes(self): batch_size_capture_list = [1, 2, 4 ] + [i for i in range(8, 513, 8)] - self.compilation_config.init_with_cudagraph_sizes( - batch_size_capture_list) + only_compile_sizes = [] + if self.scheduler_config is not None and \ + self.scheduler_config.chunked_prefill_enabled: + only_compile_sizes = [self.scheduler_config.max_num_batched_tokens] + + self.compilation_config.init_with_specific_sizes( + batch_size_capture_list, only_compile_sizes) def __str__(self): return ( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ae8b7f97c827..60401dcc33da 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1241,13 +1241,19 @@ def set_in_profile_run(self): @torch.inference_mode() def profile_run(self) -> None: + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + self._dummy_run(max_num_batched_tokens, max_num_seqs) + + def _dummy_run(self, + max_num_batched_tokens: int, + max_num_seqs: int = 1) -> None: with self.set_in_profile_run(): # Enable top-k sampling to reflect the accurate memory usage. sampling_params = \ SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = \ - self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 29d62ddda3dc..af8b9483916d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -288,6 +288,17 @@ def _init_cache_engine(self): self.gpu_cache) def _warm_up_model(self) -> None: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes + if x not in self.vllm_config.compilation_config.capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by From 83e708d0d9c611e67a57861b7370b043dd2ec402 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 14:49:21 +0800 Subject: [PATCH 03/13] fix Signed-off-by: youkaichao --- vllm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index ff5a2707db98..f3dc008bcfa0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2925,7 +2925,8 @@ def init_with_specific_sizes(self, cudagraph_capture_sizes: List[int], if self.additional_compile_sizes is None: self.additional_compile_sizes = [] - self.compile_sizes = self.capture_sizes + self.additional_compile_sizes + only_compile_sizes + self.compile_sizes = self.capture_sizes + \ + self.additional_compile_sizes + only_compile_sizes # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) From 1dd79dd1fb9d4a9f2240a1505462c9ab880d9fd5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 21 Jan 2025 13:13:09 +0800 Subject: [PATCH 04/13] fix Signed-off-by: youkaichao --- vllm/config.py | 38 +++++++++++++++++------------- vllm/v1/worker/gpu_model_runner.py | 1 + 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index f06440f20d79..e80b71800e36 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2703,10 +2703,11 @@ class CompilationConfig(BaseModel): - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for cudagraph sizes, - max-tokens for chunked prefill, and additional_compile_sizes, + is compiled. In addition, compile for backend_compile_sizes, using configurations in inductor_compile_config. - - additional_compile_sizes: additional sizes to compile for inductor. + - backend_compile_sizes: sizes to compile for inductor. In addition + to integers, it also supports "cudagraph" to + specify the sizes for cudagraph capture. - inductor_compile_config: additional configurations for inductor. - None: use default configurations. - inductor_passes: additional passes for inductor. It is a dictionary @@ -2734,7 +2735,8 @@ class CompilationConfig(BaseModel): splitting_ops: List[str] = Field(default=None) # type: ignore use_inductor: bool = True - additional_compile_sizes: Optional[List[int]] = Field(default=None) + backend_compile_sizes: Optional[List[Union[int, + str]]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -2909,8 +2911,8 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) - def init_with_specific_sizes(self, cudagraph_capture_sizes: List[int], - only_compile_sizes: List[int]) -> None: + def init_with_specific_sizes(self, + cudagraph_capture_sizes: List[int]) -> None: """To complete the initialization of config, we need to know the cudagraph sizes, and the sizes we only compile but not capture cudagraph.""" @@ -2923,10 +2925,19 @@ def init_with_specific_sizes(self, cudagraph_capture_sizes: List[int], " %s is overridden by config %s"), cudagraph_capture_sizes, self.cudagraph_capture_sizes) - if self.additional_compile_sizes is None: - self.additional_compile_sizes = [] - self.compile_sizes = self.capture_sizes + \ - self.additional_compile_sizes + only_compile_sizes + if self.backend_compile_sizes is None: + self.backend_compile_sizes = [] + + self.compile_sizes = [] + for x in self.backend_compile_sizes: + if isinstance(x, str): + assert x == "cudagraph", \ + "Unrecognized size type in backend_compile_sizes, " \ + f"expect cudagraph, got {x}" + self.compile_sizes.extend(self.capture_sizes) + else: + assert isinstance(x, int) + self.compile_sizes.append(x) # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) @@ -3254,13 +3265,8 @@ def _set_cudagraph_sizes(self): batch_size_capture_list = [1, 2, 4 ] + [i for i in range(8, 513, 8)] - only_compile_sizes = [] - if self.scheduler_config is not None and \ - self.scheduler_config.chunked_prefill_enabled: - only_compile_sizes = [self.scheduler_config.max_num_batched_tokens] - self.compilation_config.init_with_specific_sizes( - batch_size_capture_list, only_compile_sizes) + batch_size_capture_list) def __str__(self): return ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3fb8bf94f9ed..de5d084f5d01 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,6 +5,7 @@ import numpy as np import torch import torch.distributed +import torch.nn as nn from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention From e3416e30a6dd0dd9656ef724285c754c0de0038b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 21 Jan 2025 13:14:12 +0800 Subject: [PATCH 05/13] fix Signed-off-by: youkaichao --- vllm/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e80b71800e36..a4c566b718c7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2914,8 +2914,7 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: def init_with_specific_sizes(self, cudagraph_capture_sizes: List[int]) -> None: """To complete the initialization of config, - we need to know the cudagraph sizes, and the - sizes we only compile but not capture cudagraph.""" + we need to know the cudagraph sizes.""" if self.cudagraph_capture_sizes is None: self.capture_sizes = cudagraph_capture_sizes From a1b948f42ab6b86c0295c602a061f4e7db0f2920 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 21 Jan 2025 13:15:05 +0800 Subject: [PATCH 06/13] rename Signed-off-by: youkaichao --- vllm/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a4c566b718c7..577d9cb6d6c1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2911,8 +2911,8 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) - def init_with_specific_sizes(self, - cudagraph_capture_sizes: List[int]) -> None: + def init_with_cudagraph_sizes(self, + cudagraph_capture_sizes: List[int]) -> None: """To complete the initialization of config, we need to know the cudagraph sizes.""" @@ -3264,7 +3264,7 @@ def _set_cudagraph_sizes(self): batch_size_capture_list = [1, 2, 4 ] + [i for i in range(8, 513, 8)] - self.compilation_config.init_with_specific_sizes( + self.compilation_config.init_with_cudagraph_sizes( batch_size_capture_list) def __str__(self): From a29cffadb121ecb88aa1f3bd80804f3af18738fd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 21 Jan 2025 13:16:15 +0800 Subject: [PATCH 07/13] add logging Signed-off-by: youkaichao --- vllm/v1/worker/gpu_worker.py | 1 + vllm/worker/worker.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f0ff5635f8b4..af964c7a058b 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -180,6 +180,7 @@ def compile_or_warm_up_model(self) -> None: if x not in self.vllm_config.compilation_config.capture_sizes ] for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index af8b9483916d..04d0f3a85233 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -298,6 +298,7 @@ def _warm_up_model(self) -> None: if x not in self.vllm_config.compilation_config.capture_sizes ] for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model(self.gpu_cache) From 44d5e2d6a600c25300f4d6048d41fcd5894d99b3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 21 Jan 2025 19:48:44 +0800 Subject: [PATCH 08/13] user-facing: use capture_sizes Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 2 +- tests/compile/piecewise/test_toy_llama.py | 6 +++--- vllm/config.py | 10 ++++------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index aa11524812cd..f5ac94374022 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -80,7 +80,7 @@ def test_simple_piecewise_compile(): use_cudagraph=True, splitting_ops=["silly.attention"], cudagraph_copy_inputs=True, - cudagraph_capture_sizes=[1, 2], + capture_sizes=[1, 2], )) with set_current_vllm_config(vllm_config): model = SillyModel(vllm_config=vllm_config, prefix='') diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index d4ede4d2320a..656f3b5ddcfa 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -265,7 +265,7 @@ def run_model(llama_config, compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, - cudagraph_capture_sizes=[1, 2], + capture_sizes=[1, 2], ) if split_attn: compilation_config.splitting_ops = ["silly.attention"] @@ -389,12 +389,12 @@ def benchmark(): level=CompilationLevel.PIECEWISE, use_cudagraph=True, splitting_ops=["silly.attention"], - cudagraph_capture_sizes=cudagraph_sizes, + capture_sizes=cudagraph_sizes, ) else: compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, - cudagraph_capture_sizes=cudagraph_sizes, + capture_sizes=cudagraph_sizes, ) vllm_config = VllmConfig(compilation_config=compilation_config) diff --git a/vllm/config.py b/vllm/config.py index 577d9cb6d6c1..bb9806fd4cdf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2687,7 +2687,7 @@ class CompilationConfig(BaseModel): outside of compilation. TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. + - capture_sizes: sizes to capture cudagraph. - None (default): capture sizes are inferred from vllm config. - List[int]: capture sizes are specified as given. - cudagraph_num_of_warmups: number of warmup runs for cudagraph. @@ -2742,7 +2742,7 @@ class CompilationConfig(BaseModel): use_cudagraph: bool = False cudagraph_num_of_warmups: int = 0 - cudagraph_capture_sizes: Optional[List[int]] = None + capture_sizes: Optional[List[int]] = None cudagraph_copy_inputs: bool = False class PassConfig(BaseModel): @@ -2785,7 +2785,6 @@ def model_post_init(self, __context: Any) -> None: # not configurable, computed after init compile_sizes: List[int] = PrivateAttr - capture_sizes: List[int] = PrivateAttr max_capture_size: int = PrivateAttr # optimization: # Intuitively, bs_to_padded_graph_size should be Dict[int, int]. @@ -2916,13 +2915,12 @@ def init_with_cudagraph_sizes(self, """To complete the initialization of config, we need to know the cudagraph sizes.""" - if self.cudagraph_capture_sizes is None: + if self.capture_sizes is None: self.capture_sizes = cudagraph_capture_sizes else: - self.capture_sizes = self.cudagraph_capture_sizes logger.info(("cudagraph sizes specified by model runner" " %s is overridden by config %s"), - cudagraph_capture_sizes, self.cudagraph_capture_sizes) + cudagraph_capture_sizes, self.capture_sizes) if self.backend_compile_sizes is None: self.backend_compile_sizes = [] From 58e8422b0f940336eab813167169399b26d3eedf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 21 Jan 2025 19:51:56 +0800 Subject: [PATCH 09/13] user-facing: use compile_sizes Signed-off-by: youkaichao --- vllm/config.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index bb9806fd4cdf..497fb081b193 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2703,9 +2703,9 @@ class CompilationConfig(BaseModel): - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for backend_compile_sizes, + is compiled. In addition, compile for compile_sizes, using configurations in inductor_compile_config. - - backend_compile_sizes: sizes to compile for inductor. In addition + - compile_sizes: sizes to compile for inductor. In addition to integers, it also supports "cudagraph" to specify the sizes for cudagraph capture. - inductor_compile_config: additional configurations for inductor. @@ -2735,8 +2735,7 @@ class CompilationConfig(BaseModel): splitting_ops: List[str] = Field(default=None) # type: ignore use_inductor: bool = True - backend_compile_sizes: Optional[List[Union[int, - str]]] = Field(default=None) + compile_sizes: Optional[List[Union[int, str]]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -2784,7 +2783,6 @@ def model_post_init(self, __context: Any) -> None: pass_config: PassConfig = Field(default_factory=PassConfig) # not configurable, computed after init - compile_sizes: List[int] = PrivateAttr max_capture_size: int = PrivateAttr # optimization: # Intuitively, bs_to_padded_graph_size should be Dict[int, int]. @@ -2922,19 +2920,18 @@ def init_with_cudagraph_sizes(self, " %s is overridden by config %s"), cudagraph_capture_sizes, self.capture_sizes) - if self.backend_compile_sizes is None: - self.backend_compile_sizes = [] - - self.compile_sizes = [] - for x in self.backend_compile_sizes: - if isinstance(x, str): - assert x == "cudagraph", \ - "Unrecognized size type in backend_compile_sizes, " \ - f"expect cudagraph, got {x}" - self.compile_sizes.extend(self.capture_sizes) - else: - assert isinstance(x, int) - self.compile_sizes.append(x) + computed_compile_sizes = [] + if self.compile_sizes is not None: + for x in self.compile_sizes: + if isinstance(x, str): + assert x == "cudagraph", \ + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph', got {x}" + computed_compile_sizes.extend(self.capture_sizes) + else: + assert isinstance(x, int) + computed_compile_sizes.append(x) + self.compile_sizes = computed_compile_sizes # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) From 29b08b7838588755b566486ab3a67a2bae23f65c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 00:01:13 +0800 Subject: [PATCH 10/13] rename to cudagraph_capture_sizes Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 2 +- tests/compile/piecewise/test_toy_llama.py | 6 ++--- vllm/compilation/backends.py | 10 ++++---- vllm/config.py | 30 +++++++++++------------ vllm/engine/metrics.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 3 ++- vllm/v1/worker/gpu_worker.py | 4 +-- vllm/worker/model_runner.py | 15 ++++++------ vllm/worker/worker.py | 4 +-- 9 files changed, 40 insertions(+), 37 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index f5ac94374022..aa11524812cd 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -80,7 +80,7 @@ def test_simple_piecewise_compile(): use_cudagraph=True, splitting_ops=["silly.attention"], cudagraph_copy_inputs=True, - capture_sizes=[1, 2], + cudagraph_capture_sizes=[1, 2], )) with set_current_vllm_config(vllm_config): model = SillyModel(vllm_config=vllm_config, prefix='') diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 656f3b5ddcfa..d4ede4d2320a 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -265,7 +265,7 @@ def run_model(llama_config, compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, - capture_sizes=[1, 2], + cudagraph_capture_sizes=[1, 2], ) if split_attn: compilation_config.splitting_ops = ["silly.attention"] @@ -389,12 +389,12 @@ def benchmark(): level=CompilationLevel.PIECEWISE, use_cudagraph=True, splitting_ops=["silly.attention"], - capture_sizes=cudagraph_sizes, + cudagraph_capture_sizes=cudagraph_sizes, ) else: compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, - capture_sizes=cudagraph_sizes, + cudagraph_capture_sizes=cudagraph_sizes, ) vllm_config = VllmConfig(compilation_config=compilation_config) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 955c25f30051..193d3c466f1e 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -662,7 +662,7 @@ def copy_and_call(*args): class ConcreteSizeEntry: runtime_shape: int need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in capture_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes compiled: bool = False runnable: Callable = None # type: ignore @@ -709,8 +709,8 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.compile_sizes: Set[int] = set( self.compilation_config.compile_sizes) - self.capture_sizes: Set[int] = set( - self.compilation_config.capture_sizes + self.cudagraph_capture_sizes: Set[int] = set( + self.compilation_config.cudagraph_capture_sizes ) if self.compilation_config.use_cudagraph else set() self.first_run_finished = False @@ -728,11 +728,11 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, # to_be_compiled_sizes tracks the remaining sizes to compile, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.capture_sizes): + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.capture_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, ) def check_for_ending_compilation(self): diff --git a/vllm/config.py b/vllm/config.py index 497fb081b193..244b0e53140a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2687,7 +2687,7 @@ class CompilationConfig(BaseModel): outside of compilation. TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - - capture_sizes: sizes to capture cudagraph. + - cudagraph_capture_sizes: sizes to capture cudagraph. - None (default): capture sizes are inferred from vllm config. - List[int]: capture sizes are specified as given. - cudagraph_num_of_warmups: number of warmup runs for cudagraph. @@ -2741,7 +2741,7 @@ class CompilationConfig(BaseModel): use_cudagraph: bool = False cudagraph_num_of_warmups: int = 0 - capture_sizes: Optional[List[int]] = None + cudagraph_capture_sizes: Optional[List[int]] = None cudagraph_copy_inputs: bool = False class PassConfig(BaseModel): @@ -2913,12 +2913,12 @@ def init_with_cudagraph_sizes(self, """To complete the initialization of config, we need to know the cudagraph sizes.""" - if self.capture_sizes is None: - self.capture_sizes = cudagraph_capture_sizes + if self.cudagraph_capture_sizes is None: + self.cudagraph_capture_sizes = cudagraph_capture_sizes else: logger.info(("cudagraph sizes specified by model runner" " %s is overridden by config %s"), - cudagraph_capture_sizes, self.capture_sizes) + cudagraph_capture_sizes, self.cudagraph_capture_sizes) computed_compile_sizes = [] if self.compile_sizes is not None: @@ -2927,23 +2927,23 @@ def init_with_cudagraph_sizes(self, assert x == "cudagraph", \ "Unrecognized size type in compile_sizes, " \ f"expect 'cudagraph', got {x}" - computed_compile_sizes.extend(self.capture_sizes) + computed_compile_sizes.extend(self.cudagraph_capture_sizes) else: assert isinstance(x, int) computed_compile_sizes.append(x) self.compile_sizes = computed_compile_sizes # sort to make sure cudagraph capture sizes are in descending order - self.capture_sizes.sort(reverse=True) - self.max_capture_size = self.capture_sizes[ - 0] if self.capture_sizes else 0 + self.cudagraph_capture_sizes.sort(reverse=True) + self.max_capture_size = self.cudagraph_capture_sizes[ + 0] if self.cudagraph_capture_sizes else 0 # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = [ 0 for i in range(self.max_capture_size + 1) ] - for end, start in zip(self.capture_sizes, - self.capture_sizes[1:] + [0]): + for end, start in zip(self.cudagraph_capture_sizes, + self.cudagraph_capture_sizes[1:] + [0]): for bs in range(start, end): if bs == start: self.bs_to_padded_graph_size[bs] = start @@ -3214,14 +3214,14 @@ def _set_cudagraph_sizes(self): However, if users specify the cudagraph capture sizes through compilation config, we will use the specified sizes instead. - In the end, `vllm_config.compilation_config.capture_sizes` will be the - final sizes to capture cudagraph (in descending order). + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in descending order). During runtime, if batchsize is larger than - `vllm_config.compilation_config.capture_sizes`, + `vllm_config.compilation_config.cudagraph_capture_sizes`, no cudagraph will be used. If the batch size is no larger than - `vllm_config.compilation_config.capture_sizes`, + `vllm_config.compilation_config.cudagraph_capture_sizes`, we can quickly find the padded graph size for a given batch size by looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. """ diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index c8aec8dd3afa..f7ce21d0ae98 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -120,7 +120,8 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): labelnames=labelnames) buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] if not vllm_config.model_config.enforce_eager: - buckets = vllm_config.compilation_config.capture_sizes.copy() + buckets = vllm_config.compilation_config.\ + cudagraph_capture_sizes.copy() buckets.sort() self.histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index de5d084f5d01..8fd54d0772da 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -127,7 +127,8 @@ def __init__( # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( - reversed(self.vllm_config.compilation_config.capture_sizes)) + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) # Cache the device properties. self.device_properties = torch.cuda.get_device_properties(self.device) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index af964c7a058b..d2a3dc4b3a3c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -176,8 +176,8 @@ def compile_or_warm_up_model(self) -> None: warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [ - x for x in warmup_sizes - if x not in self.vllm_config.compilation_config.capture_sizes + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes ] for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ead5cb033bac..901e424d3285 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1485,13 +1485,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: for virtual_engine in range( self.parallel_config.pipeline_parallel_size): # Only rank 0 should print progress bar during capture - capture_sizes = ( - tqdm( - self.vllm_config.compilation_config.capture_sizes, - desc="Capturing CUDA graph shapes", - ) if get_tensor_model_parallel_rank() == 0 else - self.vllm_config.compilation_config.capture_sizes) - for batch_size in capture_sizes: + cudagraph_capture_sizes = (tqdm( + self.vllm_config.compilation_config. + cudagraph_capture_sizes, + desc="Capturing CUDA graph shapes", + ) if get_tensor_model_parallel_rank() == 0 else + self.vllm_config.compilation_config. + cudagraph_capture_sizes) + for batch_size in cudagraph_capture_sizes: attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( batch_size, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 04d0f3a85233..ff8eb31c6ca5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -294,8 +294,8 @@ def _warm_up_model(self) -> None: warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [ - x for x in warmup_sizes - if x not in self.vllm_config.compilation_config.capture_sizes + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes ] for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) From ce1ead5b4a40822f154db4fff0845043ca18f320 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 00:03:09 +0800 Subject: [PATCH 11/13] fix string Signed-off-by: youkaichao --- vllm/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 244b0e53140a..71143b642a1a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2706,7 +2706,7 @@ class CompilationConfig(BaseModel): is compiled. In addition, compile for compile_sizes, using configurations in inductor_compile_config. - compile_sizes: sizes to compile for inductor. In addition - to integers, it also supports "cudagraph" to + to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture. - inductor_compile_config: additional configurations for inductor. - None: use default configurations. @@ -2924,9 +2924,9 @@ def init_with_cudagraph_sizes(self, if self.compile_sizes is not None: for x in self.compile_sizes: if isinstance(x, str): - assert x == "cudagraph", \ + assert x == "cudagraph_capture_sizes", \ "Unrecognized size type in compile_sizes, " \ - f"expect 'cudagraph', got {x}" + f"expect 'cudagraph_capture_sizes', got {x}" computed_compile_sizes.extend(self.cudagraph_capture_sizes) else: assert isinstance(x, int) From 39b9245f77d4e834c6b950affed581cd79ae32f4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 00:05:05 +0800 Subject: [PATCH 12/13] de-duplicate sizes Signed-off-by: youkaichao --- vllm/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 71143b642a1a..6337dff989f2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2916,12 +2916,17 @@ def init_with_cudagraph_sizes(self, if self.cudagraph_capture_sizes is None: self.cudagraph_capture_sizes = cudagraph_capture_sizes else: + # de-duplicate the sizes provided by the config + self.cudagraph_capture_sizes = list( + set(self.cudagraph_capture_sizes)) logger.info(("cudagraph sizes specified by model runner" " %s is overridden by config %s"), cudagraph_capture_sizes, self.cudagraph_capture_sizes) computed_compile_sizes = [] if self.compile_sizes is not None: + # de-duplicate the sizes provided by the config + self.compile_sizes = list(set(self.compile_sizes)) for x in self.compile_sizes: if isinstance(x, str): assert x == "cudagraph_capture_sizes", \ From a91a568d7520c481ce1117b060fbc3aae39aefa0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 Jan 2025 00:12:05 +0800 Subject: [PATCH 13/13] fix format Signed-off-by: youkaichao --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 6337dff989f2..69990fa910b3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2936,7 +2936,7 @@ def init_with_cudagraph_sizes(self, else: assert isinstance(x, int) computed_compile_sizes.append(x) - self.compile_sizes = computed_compile_sizes + self.compile_sizes = computed_compile_sizes # type: ignore # sort to make sure cudagraph capture sizes are in descending order self.cudagraph_capture_sizes.sort(reverse=True)