diff --git a/vllm/config.py b/vllm/config.py index 7a3248f4087a..6aa2485e98d4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,14 +173,20 @@ def __init__(self, if self.enforce_eager is None: self.enforce_eager = False - if (not self.disable_sliding_window - and self.hf_text_config.model_type == "gemma2" - and self.hf_text_config.sliding_window is not None): + sliding_window = getattr(self.hf_text_config, "sliding_window", None) + has_interleaved_attention = (sliding_window is not None) and ( + isinstance(sliding_window, list) or + (self.hf_text_config.model_type in ["gemma2"])) + + if (not self.disable_sliding_window and has_interleaved_attention): + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + print_warning_once( - "Gemma 2 uses sliding window attention for every odd layer, " + f"{self.hf_text_config.model_type} has interleaved attention, " "which is currently not supported by vLLM. Disabling sliding " "window and capping the max length to the sliding window size " - f"({self.hf_text_config.sliding_window}).") + f"({sliding_window_len_min}).") self.disable_sliding_window = True self.max_model_len = _get_and_verify_max_len( @@ -422,7 +428,8 @@ def verify_with_parallel_config( "pipeline parallelism currently. Disabling it.") self.use_async_output_proc = False - def get_hf_config_sliding_window(self) -> Optional[int]: + def get_hf_config_sliding_window( + self) -> Union[Optional[int], List[Optional[int]]]: """Get the sliding window size, or None if disabled.""" # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in @@ -433,7 +440,7 @@ def get_hf_config_sliding_window(self) -> Optional[int]: return None return getattr(self.hf_text_config, "sliding_window", None) - def get_sliding_window(self) -> Optional[int]: + def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]: """Get the sliding window size, or None if disabled. """ # If user disables sliding window, return None. @@ -1680,7 +1687,7 @@ def _get_and_verify_max_len( hf_config: PretrainedConfig, max_model_len: Optional[int], disable_sliding_window: bool, - sliding_window_len: Optional[int], + sliding_window_len: Optional[Union[int, List[Optional[int]]]], spec_target_max_model_len: Optional[int] = None, ) -> int: """Get and verify the model's maximum length.""" @@ -1713,9 +1720,12 @@ def _get_and_verify_max_len( # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. if disable_sliding_window and sliding_window_len is not None: + + sliding_window_len_min = get_min_sliding_window(sliding_window_len) max_len_key = "sliding_window" \ - if sliding_window_len < derived_max_model_len else max_len_key - derived_max_model_len = min(derived_max_model_len, sliding_window_len) + if sliding_window_len_min < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, + sliding_window_len_min) # If none of the keys were found in the config, use a default and # log a warning. @@ -1803,6 +1813,14 @@ def _get_and_verify_max_len( return int(max_model_len) +def get_min_sliding_window( + sliding_window: Union[int, List[Optional[int]]]) -> int: + if isinstance(sliding_window, list): + return min(s for s in sliding_window if s is not None) + + return sliding_window + + def get_served_model_name(model: str, served_model_name: Optional[Union[str, List[str]]]): """