8 changes: 2 additions & 6 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -37,10 +37,6 @@ def main(
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
-    num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
@@ -103,7 +99,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 query,
                 key_cache,
                 value_cache,
-                head_mapping,
+                num_kv_heads,
                 scale,
                 block_tables,
                 context_lens,
@@ -120,7 +116,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 query,
                 key_cache,
                 value_cache,
-                head_mapping,
+                num_kv_heads,
                 scale,
                 block_tables,
                 context_lens,
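Note: the deleted `head_mapping` tensor and the scalar `num_kv_heads` now passed in its place carry the same information. `repeat_interleave` built exactly the lookup table that integer division by `num_queries_per_kv` recovers, so the kernel can derive the KV head on the fly. A minimal sketch of the equivalence (the head counts are illustrative, not values from this PR):

```python
import torch

num_query_heads, num_kv_heads = 8, 2  # assumed GQA sizes for illustration
num_queries_per_kv = num_query_heads // num_kv_heads

# Old scheme: an explicit per-query-head lookup table, [0, 0, 0, 0, 1, 1, 1, 1].
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)

# New scheme: derive the KV head index arithmetically.
derived = torch.arange(num_query_heads, dtype=torch.int32) // num_queries_per_kv

assert torch.equal(head_mapping, derived)
```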
31 changes: 15 additions & 16 deletions csrc/attention/attention_kernels.cu
@@ -89,7 +89,7 @@ __device__ void paged_attention_kernel(
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,   // [num_heads]
+  const int num_kv_heads,
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
@@ -132,7 +132,8 @@ __device__ void paged_attention_kernel(
 
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
-  const int kv_head_idx = head_mapping[head_idx];
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
@@ -401,7 +402,7 @@ __global__ void paged_attention_v1_kernel(
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,   // [num_heads]
+  const int num_kv_heads,
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
@@ -412,7 +413,7 @@
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
 }

@@ -430,7 +431,7 @@ __global__ void paged_attention_v2_kernel(
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,   // [num_heads]
+  const int num_kv_heads,
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
@@ -440,7 +441,7 @@
   const int kv_block_stride,
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride);
 }
@@ -556,7 +557,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -576,7 +577,7 @@ void paged_attention_v1_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -602,7 +603,6 @@
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -651,7 +651,7 @@ void paged_attention_v1_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
@@ -681,7 +681,7 @@ void paged_attention_v1(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
@@ -708,7 +708,7 @@
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -739,7 +739,7 @@ void paged_attention_v2_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -768,7 +768,6 @@ void paged_attention_v2_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -823,7 +822,7 @@ void paged_attention_v2_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
@@ -856,7 +855,7 @@ void paged_attention_v2(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
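In the device kernel, the new arithmetic replaces a per-thread-block global-memory read (`head_mapping[head_idx]`) with two integer operations on values the kernel already has: `head_idx` comes from `blockIdx.x` and `num_heads` from `gridDim.x`. The truncating division assumes `num_heads` is evenly divisible by `num_kv_heads`, which the Python callers assert. A Python mirror of the new index math, as a sketch (the helper name is illustrative):

```python
def kv_head_index(head_idx: int, num_heads: int, num_kv_heads: int) -> int:
    # Mirrors the kernel: num_queries_per_kv = num_heads / num_kv_heads;
    # kv_head_idx = head_idx / num_queries_per_kv (integer division).
    num_queries_per_kv = num_heads // num_kv_heads
    return head_idx // num_queries_per_kv

# GQA: query heads 0-3 share KV head 0, heads 4-7 share KV head 1.
assert [kv_head_index(h, 8, 2) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
# MQA: with a single KV head, every query head maps to head 0.
assert all(kv_head_index(h, 8, 1) == 0 for h in range(8))
# MHA: with num_kv_heads == num_heads, the mapping is the identity.
assert [kv_head_index(h, 4, 4) for h in range(4)] == [0, 1, 2, 3]
```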
4 changes: 2 additions & 2 deletions csrc/ops.h
@@ -5,7 +5,7 @@ void paged_attention_v1(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -21,7 +21,7 @@ void paged_attention_v2(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
7 changes: 2 additions & 5 deletions tests/kernels/test_attention.py
@@ -131,9 +131,6 @@ def test_paged_attention(
 
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
@@ -170,7 +167,7 @@ def test_paged_attention(
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         block_tables,
         context_lens,
@@ -202,7 +199,7 @@ def test_paged_attention(
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         block_tables,
         context_lens,
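The test keeps `num_queries_per_kv` even though it no longer builds `head_mapping`: a pure-PyTorch reference implementation still needs to expand the KV heads so each query head is compared against its shared key/value head. A sketch of that expansion (the helper name and shapes are illustrative, not the test's own code):

```python
import torch

def expand_kv_heads(kv: torch.Tensor, num_queries_per_kv: int) -> torch.Tensor:
    # kv: [num_tokens, num_kv_heads, head_size]
    # returns: [num_tokens, num_kv_heads * num_queries_per_kv, head_size]
    return torch.repeat_interleave(kv, num_queries_per_kv, dim=1)

key = torch.randn(16, 2, 64)        # 16 tokens, 2 KV heads, head_size 64
expanded = expand_kv_heads(key, 4)  # now matches 8 query heads
assert expanded.shape == (16, 8, 64)
```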
13 changes: 5 additions & 8 deletions vllm/model_executor/layers/attention.py
@@ -54,9 +54,6 @@ def __init__(
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        self.head_mapping = torch.repeat_interleave(
-            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
-            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -77,7 +74,7 @@ def forward(
         Args:
             query: shape = [batch_size, seq_len, num_heads * head_size]
             key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -172,7 +169,7 @@ def forward(
                 key_cache,
                 value_cache,
                 input_metadata,
-                self.head_mapping,
+                self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
             )
@@ -217,7 +214,7 @@ def _paged_attention(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     input_metadata: InputMetadata,
-    head_mapping: torch.Tensor,
+    num_kv_heads: int,
     scale: float,
     alibi_slopes: Optional[torch.Tensor],
 ) -> torch.Tensor:
@@ -244,7 +241,7 @@
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         input_metadata.block_tables,
         input_metadata.context_lens,
@@ -274,7 +271,7 @@ def _paged_attention(
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         input_metadata.block_tables,
         input_metadata.context_lens,
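A small side effect of the `__init__` deletion: constructing the attention layer no longer allocates a CUDA tensor eagerly (the old `repeat_interleave` ran on `device="cuda"` at construction time); only host-side integer arithmetic remains. A sketch of the surviving logic (sizes are illustrative):

```python
num_heads, num_kv_heads = 32, 4  # assumed sizes for illustration
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads  # a plain int, no GPU needed
```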