From b5b64ac151e44a183b3bba0994a138bff21fbd20 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 8 May 2024 15:45:49 +0000
Subject: [PATCH 1/3] [Misc] Add get_name method to attention backends

---
 vllm/attention/backends/abstract.py        |  5 +++++
 vllm/attention/backends/flash_attn.py      |  4 ++++
 vllm/attention/backends/flashinfer.py      | 16 +++++++---------
 vllm/attention/backends/rocm_flash_attn.py |  4 ++++
 vllm/attention/backends/torch_sdpa.py      |  4 ++++
 vllm/attention/backends/xformers.py        |  4 ++++
 vllm/worker/model_runner.py                | 10 ++++++----
 7 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index b2b6e7ac810..02a2fd603fa 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -9,6 +9,11 @@ class AttentionBackend(ABC):
     """Abstract class for attention backends."""
 
+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index da672d5df61..bee482c3431 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -19,6 +19,10 @@ class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 2851cbe2396..015f718aede 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -1,16 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
-try:
-    import flashinfer
-    from flash_attn import flash_attn_varlen_func
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-except ImportError:
-    flashinfer = None
-    flash_attn_varlen_func = None
-    BatchDecodeWithPagedKVCacheWrapper = None
-
+import flashinfer
 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flash_attn import flash_attn_varlen_func
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -20,6 +14,10 @@ class FlashInferBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c3b522e63b4..10c94f02ff0 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -17,6 +17,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 03825f6023f..c1c07abef0c 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -15,6 +15,10 @@ class TorchSDPABackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 4c7fa71a2c7..2a9150dea58 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -20,6 +20,10 @@ class XFormersBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index ab248596490..68653113e79 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -9,7 +9,6 @@
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -273,7 +272,10 @@ def _prepare_prompt(
                 # Prefix is not supported with sliding_window
                 context_len = len(computed_block_nums) * self.block_size
                 prompt_tokens = prompt_tokens[context_len:]
-                prefix_block_tables.append(computed_block_nums)
+                if self.attn_backend.get_name() == "flash-attn":
+                    block_table = seq_group_metadata.block_tables[seq_id]
+                else:
+                    block_table = computed_block_nums
             elif self.scheduler_config.chunked_prefill_enabled:
                 if seq_group_metadata.block_tables is not None:
                     # Prefill has chunked before.
@@ -395,7 +397,7 @@ def _prepare_prompt(
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             attn_metadata = self.attn_backend.make_metadata(
                 is_prompt=True,
                 use_cuda_graph=False,
@@ -556,7 +558,7 @@ def _prepare_decode(
             device=self.device,
         )
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             if not hasattr(self, "flashinfer_workspace_buffer"):
                 # Allocate 16MB workspace buffer
                 # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html

From 1db3ac996314cb4661b7840a2b581b1bf34ce172 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 8 May 2024 15:46:46 +0000
Subject: [PATCH 2/3] Fix

---
 vllm/worker/model_runner.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 68653113e79..c96f13c590f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -272,10 +272,7 @@ def _prepare_prompt(
                 # Prefix is not supported with sliding_window
                 context_len = len(computed_block_nums) * self.block_size
                 prompt_tokens = prompt_tokens[context_len:]
-                if self.attn_backend.get_name() == "flash-attn":
-                    block_table = seq_group_metadata.block_tables[seq_id]
-                else:
-                    block_table = computed_block_nums
+                prefix_block_tables.append(computed_block_nums)
             elif self.scheduler_config.chunked_prefill_enabled:
                 if seq_group_metadata.block_tables is not None:
                     # Prefill has chunked before.
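Taken together, patches 1 and 2 switch model_runner.py from class-identity checks (`self.attn_backend is FlashInferBackend`) to comparing the string returned by the new get_name() static method, so the worker no longer imports the FlashInfer backend (and, transitively, the optional flashinfer/flash_attn packages) just to ask which backend is active; note that patch 2 reverts the extra "flash-attn" branch in _prepare_prompt that patch 1 had introduced. A minimal, self-contained sketch of the resulting dispatch pattern, assuming nothing beyond what the diffs show (build_prompt_metadata is an illustrative stand-in, not vLLM's API):

    from abc import ABC, abstractmethod
    from typing import Type


    class AttentionBackend(ABC):
        """Abstract base: every backend reports a stable string name."""

        @staticmethod
        @abstractmethod
        def get_name() -> str:
            raise NotImplementedError


    class FlashInferBackend(AttentionBackend):

        @staticmethod
        def get_name() -> str:
            return "flashinfer"


    class XFormersBackend(AttentionBackend):

        @staticmethod
        def get_name() -> str:
            return "xformers"


    def build_prompt_metadata(attn_backend: Type[AttentionBackend]) -> str:
        # Dispatch on the backend's advertised name instead of its class; in vLLM
        # this is what lets model_runner.py drop its import of FlashInferBackend.
        if attn_backend.get_name() == "flashinfer":
            return "flashinfer-specific metadata"
        return "generic metadata"


    print(build_prompt_metadata(FlashInferBackend))  # -> flashinfer-specific metadata
    print(build_prompt_metadata(XFormersBackend))    # -> generic metadata

Comparing names keeps the optional backend modules out of the call site's imports while still letting backend-specific paths, such as the FlashInfer workspace buffer in _prepare_decode, be selected at runtime.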
From ac44d0184a9dd87ff5a1553e934b6de8c4d5f423 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 8 May 2024 15:47:04 +0000
Subject: [PATCH 3/3] isort

---
 vllm/attention/backends/flashinfer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 015f718aede..67b99ba2ead 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -3,8 +3,8 @@
 
 import flashinfer
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 from flash_attn import flash_attn_varlen_func
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
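The final patch only swaps two import lines so that `from flash_attn import flash_attn_varlen_func` comes before `from flashinfer import BatchDecodeWithPagedKVCacheWrapper`, matching plain lexicographic order of the module names, which is presumably why isort orders them this way. A quick, purely illustrative check:

    # After the common prefix "flash", "_" (0x5f) compares less than "i" (0x69),
    # so "flash_attn" sorts ahead of "flashinfer".
    print("flash_attn" < "flashinfer")                    # True
    print(sorted(["flashinfer", "flash_attn", "torch"]))  # ['flash_attn', 'flashinfer', 'torch']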