[Misc] Add get_name method to attention backends #4685

Merged: 3 commits, May 8, 2024
Changes from all commits
5 changes: 5 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -9,6 +9,11 @@
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
 
+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
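For readers unfamiliar with the pattern, stacking @staticmethod on top of @abstractmethod makes get_name a class-level accessor that every concrete backend must override and that can be called without instantiating the backend. A minimal, self-contained sketch of the contract follows; the MyBackend class is hypothetical and not part of this PR:

from abc import ABC, abstractmethod


class AttentionBackend(ABC):
    """Abstract class for attention backends (reduced to the new method)."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError


class MyBackend(AttentionBackend):
    """Hypothetical backend, shown only to illustrate the contract."""

    @staticmethod
    def get_name() -> str:
        return "my-backend"


# The name is reachable from the class itself, no instance required.
assert MyBackend.get_name() == "my-backend"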
4 changes: 4 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -19,6 +19,10 @@
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl
16 changes: 7 additions & 9 deletions vllm/attention/backends/flashinfer.py
@@ -1,16 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
-try:
-    import flashinfer
-    from flash_attn import flash_attn_varlen_func
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-except ImportError:
-    flashinfer = None
-    flash_attn_varlen_func = None
-    BatchDecodeWithPagedKVCacheWrapper = None
-
+import flashinfer
 import torch
+from flash_attn import flash_attn_varlen_func
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -20,6 +14,10 @@
 
 class FlashInferBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl
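A note on the removed guard above: the try/except existed because vllm/worker/model_runner.py imported FlashInferBackend at module level (that import is removed later in this PR), which pulled this file in even on installs without flashinfer. With the eager import gone, the module is presumably only loaded when the FlashInfer backend is actually selected, so plain imports suffice. For context, a minimal sketch of the optional-import pattern being dropped (standalone, not vLLM code):

# Guarded optional import: the module stays importable even when the
# optional dependency is missing, at the cost of a sentinel and later checks.
try:
    import flashinfer
except ImportError:
    flashinfer = None


def flashinfer_available() -> bool:
    # Callers must check availability before touching any flashinfer API.
    return flashinfer is not None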
4 changes: 4 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -17,6 +17,10 @@
 
 class ROCmFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl
4 changes: 4 additions & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -15,6 +15,10 @@
 
 class TorchSDPABackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
4 changes: 4 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -20,6 +20,10 @@
 
 class XFormersBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl
5 changes: 2 additions & 3 deletions vllm/worker/model_runner.py
@@ -9,7 +9,6 @@
 
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -395,7 +394,7 @@ def _prepare_prompt(
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             attn_metadata = self.attn_backend.make_metadata(
                 is_prompt=True,
                 use_cuda_graph=False,
@@ -556,7 +555,7 @@ def _prepare_decode(
             device=self.device,
         )
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             if not hasattr(self, "flashinfer_workspace_buffer"):
                 # Allocate 16MB workspace buffer
                 # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
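Taken together, the model runner now identifies the backend by the string returned from get_name() rather than by class identity, so it no longer needs a module-level import of FlashInferBackend (and, transitively, of flashinfer). A minimal sketch of that dispatch pattern under the same assumption; the SimpleBackend class and build_prefill_metadata function are hypothetical, not vLLM APIs:

# Name-based backend dispatch (hypothetical names, not vLLM API).
class SimpleBackend:
    @staticmethod
    def get_name() -> str:
        return "torch-sdpa"


def build_prefill_metadata(attn_backend) -> dict:
    # Branch on the backend's reported name instead of its class identity,
    # so optional backend classes never have to be imported here.
    if attn_backend.get_name() == "flashinfer":
        return {"backend": "flashinfer", "use_cuda_graph": False}
    return {"backend": attn_backend.get_name()}


print(build_prefill_metadata(SimpleBackend()))  # {'backend': 'torch-sdpa'}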