From 8e62e90bb6bddaa6ac80c7e5d68b1f1487b9701f Mon Sep 17 00:00:00 2001
From: wangguoya
Date: Fri, 8 Dec 2023 20:54:32 -0800
Subject: [PATCH 1/2] Replace head_mapping params with num_kv_heads to attention kernel.

---
 .../kernels/benchmark_paged_attention.py |  8 ++---
 csrc/attention/attention_kernels.cu      | 31 +++++++++----------
 csrc/ops.h                               |  4 +--
 tests/kernels/test_attention.py          |  7 ++---
 vllm/model_executor/layers/attention.py  | 11 +++----
 5 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index 91fcf5340298..935393e9942c 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -37,10 +37,6 @@ def main(
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
-    num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
@@ -103,7 +99,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                     query,
                     key_cache,
                     value_cache,
-                    head_mapping,
+                    num_kv_heads,
                     scale,
                     block_tables,
                     context_lens,
@@ -120,7 +116,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                     query,
                     key_cache,
                     value_cache,
-                    head_mapping,
+                    num_kv_heads,
                     scale,
                     block_tables,
                     context_lens,
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 3676af1a378d..eff28d3dacd0 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -89,7 +89,7 @@ __device__ void paged_attention_kernel(
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,   // [num_heads]
+  const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
@@ -132,7 +132,8 @@
 
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
-  const int kv_head_idx = head_mapping[head_idx];
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
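Note: the two added lines above replace the explicit head_mapping lookup table with an index computation inside the kernel. A minimal sketch of the equivalence follows (plain PyTorch on CPU, with 8 and 2 as example head counts; any pair where num_query_heads is an exact multiple of num_kv_heads, as the callers already assert, behaves the same):

import torch

num_query_heads = 8   # example value, not taken from the patch
num_kv_heads = 2      # example value, not taken from the patch
num_queries_per_kv = num_query_heads // num_kv_heads

# Old scheme: materialize an explicit query-head -> kv-head lookup table.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)

# New scheme (what the kernel now computes per head_idx): integer division.
kv_head_idx = torch.arange(num_query_heads, dtype=torch.int32) // num_queries_per_kv

assert torch.equal(head_mapping, kv_head_idx)  # both are tensor([0, 0, 0, 0, 1, 1, 1, 1])
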
@@ -401,7 +402,7 @@ __global__ void paged_attention_v1_kernel(
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,   // [num_heads]
+  const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
@@ -412,7 +413,7 @@ __global__ void paged_attention_v1_kernel(
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
 }
 
@@ -430,7 +431,7 @@ __global__ void paged_attention_v2_kernel(
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,   // [num_heads]
+  const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
@@ -440,7 +441,7 @@ __global__ void paged_attention_v2_kernel(
   const int kv_block_stride,
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride);
 }
@@ -556,7 +557,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -576,7 +577,7 @@ void paged_attention_v1_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -602,7 +603,6 @@ void paged_attention_v1_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -651,7 +651,7 @@ void paged_attention_v1_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
@@ -681,7 +681,7 @@ void paged_attention_v1(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
+  int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
@@ -708,7 +708,7 @@ void paged_attention_v1(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -739,7 +739,7 @@ void paged_attention_v2_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -768,7 +768,6 @@ void paged_attention_v2_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -823,7 +822,7 @@ void paged_attention_v2_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
@@ -856,7 +855,7 @@ void paged_attention_v2(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
+  int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
diff --git a/csrc/ops.h b/csrc/ops.h
index e12c34f0aafa..7a174179f724 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -5,7 +5,7 @@ void paged_attention_v1(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -21,7 +21,7 @@ void paged_attention_v2(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index a65d4d54d7c8..614b65f82ccb 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -131,9 +131,6 @@ def test_paged_attention(
 
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
@@ -170,7 +167,7 @@ def test_paged_attention(
             query,
             key_cache,
             value_cache,
-            head_mapping,
+            num_kv_heads,
             scale,
             block_tables,
             context_lens,
@@ -202,7 +199,7 @@ def test_paged_attention(
             query,
             key_cache,
             value_cache,
-            head_mapping,
+            num_kv_heads,
             scale,
             block_tables,
             context_lens,
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 3f4ecb5d2ae7..cdc3d8b5d2e9 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -54,9 +54,6 @@ def __init__(
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        self.head_mapping = torch.repeat_interleave(
-            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
-            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
" @@ -172,7 +169,7 @@ def forward( key_cache, value_cache, input_metadata, - self.head_mapping, + self.num_kv_heads, self.scale, self.alibi_slopes, ) @@ -217,7 +214,7 @@ def _paged_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, input_metadata: InputMetadata, - head_mapping: torch.Tensor, + num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: @@ -244,7 +241,7 @@ def _paged_attention( query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, input_metadata.block_tables, input_metadata.context_lens, @@ -274,7 +271,7 @@ def _paged_attention( query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, input_metadata.block_tables, input_metadata.context_lens, From 4c8d6478cb1c501e9db74a7143672fa2d93b1f2a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 10 Dec 2023 18:11:02 +0000 Subject: [PATCH 2/2] Minor --- vllm/model_executor/layers/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index cdc3d8b5d2e9..d0f0f28cb3fe 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -74,7 +74,7 @@ def forward( Args: query: shape = [batch_size, seq_len, num_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size,