From f597cc9bfc773ce93aafa25f3fef1dbf73ba8097 Mon Sep 17 00:00:00 2001
From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Date: Tue, 7 Oct 2025 06:03:41 +0000
Subject: [PATCH 1/7] Add gather_indexer_k_quant_cache kernel

Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
---
 csrc/cache.h            |   8 ++++
 csrc/cache_kernels.cu   | 102 ++++++++++++++++++++++++++++++++++++++++
 csrc/torch_bindings.cpp |   6 +++
 vllm/_custom_ops.py     |  12 +++++
 4 files changed, 128 insertions(+)

diff --git a/csrc/cache.h b/csrc/cache.h
index 427bd0d54fac..b162a4a2bc31 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -64,3 +64,11 @@ void indexer_k_quant_and_cache(
     torch::Tensor& slot_mapping,  // [num_tokens]
     int64_t quant_block_size,     // quantization block size
     const std::string& scale_fmt);
+
+// Gather the quantized K cache (and its scales) for the given tokens
+void cp_gather_indexer_k_quant_cache(
+    const torch::Tensor& kv_cache,  // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,      // [num_tokens, head_dim]
+    torch::Tensor& dst_scale,  // [num_tokens, head_dim / quant_block_size * 4]
+    const torch::Tensor& block_table,   // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens);  // [batch_size + 1]
\ No newline at end of file
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 84c2345b44d8..dd33bc989b74 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -572,6 +572,69 @@ __global__ void indexer_k_quant_and_cache_kernel(
   }
 }
 
+__global__ void cp_gather_indexer_k_quant_cache_kernel(
+    const char* __restrict__ kv_cache,  // [num_blocks, block_size,
+                                        // cache_stride]
+    char* __restrict__ dst_k,      // [num_tokens, head_dim]
+    char* __restrict__ dst_scale,  // [num_tokens, head_dim / quant_block_size *
+                                   // 4]
+    const int* __restrict__ block_table,  // [batch_size, num_blocks]
+    const int* __restrict__ cu_seq_lens,  // [batch_size + 1]
+    const int batch_size,          // batch size
+    const int token_stride,        // stride for each token in dst_k
+    const int head_dim,            // dimension of each head
+    const int block_stride,        // stride for each block in kv_cache
+    const int cache_token_stride,  // stride for each token in kv_cache
+    const int cache_block_size,    // num_tokens for each block in kv_cache
+    const int num_blocks,          // number of blocks
+    const int quant_block_size     // quantization block size
+) {
+  constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t head_idx = (blockIdx.y * blockDim.y * blockDim.x +
+                            threadIdx.y * blockDim.x + threadIdx.x) *
+                           VEC_SIZE;
+  if (head_idx >= head_dim) {
+    return;
+  }
+
+  // Find batch index within a block
+  __shared__ int batch_idx;
+  for (int iter = 0;
+       iter < cuda_utils::ceil_div(batch_size, int(blockDim.x * blockDim.y));
+       iter++) {
+    int tid =
+        iter * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
+    if (tid < batch_size) {
+      const int seq_start = cu_seq_lens[tid];
+      const int seq_end = cu_seq_lens[tid + 1];
+      if (token_idx >= seq_start && token_idx < seq_end) {
+        batch_idx = tid;
+      }
+    }
+  }
+  __syncthreads();
+
+  const int64_t inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx];
+  const int64_t block_idx =
+      block_table[batch_idx * num_blocks + inbatch_seq_idx / cache_block_size];
+  const int64_t src_block_offset = block_idx * block_stride;
+  const int64_t cache_inblock_offset =
+      (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
+  const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
+  const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
+
+  reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
+      reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
+  if (threadIdx.x == 0) {
+    const int64_t src_scale_offset =
+        src_block_offset + cache_block_size * head_dim +
+        cache_inblock_offset * 4 / quant_block_size;
+    reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
+        reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
+  }
+}
+
 }  // namespace vllm
 
 // KV_T is the data type of key and value tensors.
@@ -1173,3 +1236,42 @@ void indexer_k_quant_and_cache(
   DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
                              CALL_INDEXER_K_QUANT_AND_CACHE);
 }
+
+void cp_gather_indexer_k_quant_cache(
+    const torch::Tensor& kv_cache,  // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,      // [num_tokens, head_dim]
+    torch::Tensor& dst_scale,  // [num_tokens, head_dim / quant_block_size * 4]
+    const torch::Tensor& block_table,  // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens   // [batch_size + 1]
+) {
+  int batch_size = block_table.size(0);
+  int num_blocks = block_table.size(1);
+  int num_tokens = dst_k.size(0);
+  int head_dim = dst_k.size(1);
+  int quant_block_size = head_dim * 4 / dst_scale.size(1);
+
+  TORCH_CHECK(kv_cache.device() == dst_k.device(),
+              "kv_cache and dst_k must be on the same device");
+  TORCH_CHECK(kv_cache.device() == dst_scale.device(),
+              "kv_cache and dst_scale must be on the same device");
+  TORCH_CHECK(kv_cache.device() == block_table.device(),
+              "kv_cache and block_table must be on the same device");
+  TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
+              "kv_cache and cu_seq_lens must be on the same device");
+  TORCH_CHECK(head_dim % quant_block_size == 0,
+              "head_dim must be divisible by quant_block_size");
+
+  constexpr int vec_size = 16;
+  dim3 grid(num_tokens, (head_dim + 128 * vec_size - 1) / (128 * vec_size));
+  dim3 block(8, 16);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  vllm::cp_gather_indexer_k_quant_cache_kernel<<<grid, block, 0, stream>>>(
+      reinterpret_cast<char*>(kv_cache.data_ptr()),
+      reinterpret_cast<char*>(dst_k.data_ptr()),
+      reinterpret_cast<char*>(dst_scale.data_ptr()),
+      block_table.data_ptr<int>(), cu_seq_lens.data_ptr<int>(),
+      batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0),
+      kv_cache.stride(1), kv_cache.size(1), block_table.size(1),
+      quant_block_size);
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 64a345eb66cc..82658ff95e0d 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -720,6 +720,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "int quant_block_size, str kv_cache_dtype) -> ()");
   cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
                  &indexer_k_quant_and_cache);
+
+  cache_ops.def(
+      "cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
+      "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
+  cache_ops.impl("cp_gather_indexer_k_quant_cache", torch::kCUDA,
+                 &cp_gather_indexer_k_quant_cache);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index b8cbb1ad90a6..9fa346cca56d 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2108,6 +2108,18 @@ def indexer_k_quant_and_cache(
     )
 
 
+def cp_gather_indexer_k_quant_cache(
+    kv_cache: torch.Tensor,
+    dst_k: torch.Tensor,
+    dst_scale: torch.Tensor,
+    block_table: torch.Tensor,
+    cu_seq_lens: torch.Tensor,
+) -> None:
+    torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache(
+        kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
+    )
+
+
 def get_device_attribute(attribute: int, device: int) -> int:
     return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
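Note: the patch above wires the new op end to end (CUDA kernel, C++ entry point, torch binding, Python wrapper). The sketch below shows how the Python wrapper can be driven once a CUDA build with this patch is available; every size, dtype, and tensor value in it is an illustrative assumption based on the shape comments in the patch, not code taken from vLLM.

```python
import torch
from vllm import _custom_ops as ops

# Illustrative sizes only; real values come from the model and cache config.
num_blocks, block_size, head_dim, quant_block_size = 8, 64, 128, 128
batch_size = 2
# Per block: block_size * head_dim quantized bytes, then one fp32 scale
# (4 bytes) per quant_block_size elements of each token.
cache_stride = head_dim + head_dim // quant_block_size * 4

kv_cache = torch.zeros(num_blocks, block_size, cache_stride,
                       dtype=torch.uint8, device="cuda")
block_table = torch.zeros(batch_size, num_blocks, dtype=torch.int32,
                          device="cuda")
cu_seq_lens = torch.tensor([0, 48, 96], dtype=torch.int32, device="cuda")
num_tokens = int(cu_seq_lens[-1])

# Outputs: gathered quantized K bytes and the raw scale bytes per token.
dst_k = torch.empty(num_tokens, head_dim, dtype=torch.uint8, device="cuda")
dst_scale = torch.empty(num_tokens, head_dim // quant_block_size * 4,
                        dtype=torch.uint8, device="cuda")
ops.cp_gather_indexer_k_quant_cache(kv_cache, dst_k, dst_scale,
                                    block_table, cu_seq_lens)
```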
From a9b2f183c922b2fa34993b9a02c21414ad8d0a1b Mon Sep 17 00:00:00 2001
From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Date: Tue, 7 Oct 2025 07:34:10 +0000
Subject: [PATCH 2/7] Fix on large num_tokens

Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
---
 csrc/cache_kernels.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index dd33bc989b74..f9fde2693162 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -594,10 +594,6 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
   const int64_t head_idx = (blockIdx.y * blockDim.y * blockDim.x +
                             threadIdx.y * blockDim.x + threadIdx.x) *
                            VEC_SIZE;
-  if (head_idx >= head_dim) {
-    return;
-  }
-
   // Find batch index within a block
   __shared__ int batch_idx;
   for (int iter = 0;
@@ -614,6 +610,10 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
     }
   }
   __syncthreads();
+
+  if (head_idx >= head_dim) {
+    return;
+  }
   const int64_t inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx];
   const int64_t block_idx =
       block_table[batch_idx * num_blocks + inbatch_seq_idx / cache_block_size];
" + "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"); + cache_ops.impl("cp_gather_indexer_k_quant_cache", torch::kCUDA, + &cp_gather_indexer_k_quant_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b8cbb1ad90a6..9fa346cca56d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2108,6 +2108,18 @@ def indexer_k_quant_and_cache( ) +def cp_gather_indexer_k_quant_cache( + kv_cache: torch.Tensor, + dst_k: torch.Tensor, + dst_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache( + kv_cache, dst_k, dst_scale, block_table, cu_seq_lens + ) + + def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) From a9b2f183c922b2fa34993b9a02c21414ad8d0a1b Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Tue, 7 Oct 2025 07:34:10 +0000 Subject: [PATCH 2/7] Fix on large num_tokens Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --- csrc/cache_kernels.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index dd33bc989b74..f9fde2693162 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -594,10 +594,6 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel( const int64_t head_idx = (blockIdx.y * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; - if (head_idx >= head_dim) { - return; - } - // Find batch index within a block __shared__ int batch_idx; for (int iter = 0; @@ -614,6 +610,10 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel( } } __syncthreads(); + + if (head_idx >= head_dim) { + return; + } const int64_t inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx]; const int64_t block_idx = block_table[batch_idx * num_blocks + inbatch_seq_idx / cache_block_size]; From bd88b9b1ddee09f279e767f5b254147ca7ee1f20 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Tue, 7 Oct 2025 08:44:58 +0000 Subject: [PATCH 3/7] Perf optimization Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --- csrc/cache_kernels.cu | 80 ++++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 32 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index f9fde2693162..bec827ae76db 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -572,6 +572,7 @@ __global__ void indexer_k_quant_and_cache_kernel( } } +template __global__ void cp_gather_indexer_k_quant_cache_kernel( const char* __restrict__ kv_cache, // [num_blocks, block_size, // cache_stride] @@ -581,42 +582,39 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel( const int* __restrict__ block_table, // [batch_size, num_blocks] const int* __restrict__ cu_seq_lens, // [batch_size + 1] const int batch_size, // batch size - const int token_stride, // stride for each token in dst_k - const int head_dim, // dimension of each head - const int block_stride, // stride for each block in kv_cache - const int cache_token_stride, // stride for each token in kv_cache - const int cache_block_size, // num_tokens for each block in kv_cache - const int num_blocks, // number of blocks - const int quant_block_size // quantization block size + const int64_t token_stride, // stride for each token in 
From 9bbca67ea04069d938447645a7b72ad6900d9fb3 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Tue, 7 Oct 2025 12:45:37 -0700
Subject: [PATCH 4/7] Update csrc/cache_kernels.cu

Co-authored-by: Yongye Zhu
Signed-off-by: Simon Mo
---
 csrc/cache_kernels.cu | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index bec827ae76db..ccaa32af6aad 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -607,7 +607,9 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
       }
     }
   }
+#ifndef USE_ROCM
   __syncwarp();
+#endif
 
   if (head_idx >= head_dim || token_idx >= num_tokens) {
     return;
From 7bf9fe156e7d2aa058faf1a15ece08066d889aef Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Tue, 7 Oct 2025 19:14:17 -0700
Subject: [PATCH 5/7] skip the whole kernel for rocm

Signed-off-by: Chen Zhang
---
 csrc/cache_kernels.cu | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index ccaa32af6aad..530360c6cd11 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -591,6 +591,7 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
     const int num_tokens,              // number of tokens
     const int quant_block_size         // quantization block size
 ) {
+#ifndef USE_ROCM
   constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
   const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
   const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
@@ -607,9 +608,7 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
       }
     }
   }
-#ifndef USE_ROCM
   __syncwarp();
-#endif
 
   if (head_idx >= head_dim || token_idx >= num_tokens) {
     return;
@@ -633,6 +632,9 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
     reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
         reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
   }
+#else
+  assert false;  // TODO: this kernel has compilation errors with ROCm.
+#endif
 }
 
 }  // namespace vllm

From ce8576d6e7ab7b3357831260a3865b229baf2072 Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Tue, 7 Oct 2025 19:20:28 -0700
Subject: [PATCH 6/7] fix

Signed-off-by: Chen Zhang
---
 csrc/cache_kernels.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 530360c6cd11..1866b4e12ff1 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -633,7 +633,7 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
         reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
   }
 #else
-  assert false;  // TODO: this kernel has compilation errors with ROCm.
+  assert(false);  // TODO: This kernel has compilation errors with ROCm.
 #endif
 }
From e46273e5d577b707fea02545f20debb6de34d336 Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Tue, 7 Oct 2025 19:34:43 -0700
Subject: [PATCH 7/7] try fix

Signed-off-by: Chen Zhang
---
 csrc/cache_kernels.cu | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 1866b4e12ff1..f4b116c94f19 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -591,7 +591,6 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
     const int num_tokens,              // number of tokens
     const int quant_block_size         // quantization block size
 ) {
-#ifndef USE_ROCM
   constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
   const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
   const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
@@ -608,7 +607,10 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
       }
     }
   }
+
+#ifndef USE_ROCM
   __syncwarp();
+#endif
 
   if (head_idx >= head_dim || token_idx >= num_tokens) {
     return;
@@ -632,9 +634,6 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
     reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
         reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
   }
-#else
-  assert(false);  // TODO: This kernel has compilation errors with ROCm.
-#endif
 }
 
 }  // namespace vllm
@@ -1261,7 +1260,6 @@ void cp_gather_indexer_k_quant_cache(
     const torch::Tensor& cu_seq_lens   // [batch_size + 1]
 ) {
   int batch_size = block_table.size(0);
-  int num_blocks = block_table.size(1);
   int num_tokens = dst_k.size(0);
   int head_dim = dst_k.size(1);
   int quant_block_size = head_dim * 4 / dst_scale.size(1);
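Once the series is applied, one way to sanity-check the op is to fill a small cache with random bytes and compare the CUDA gather against the pure-PyTorch gather_ref sketched after patch 3. The harness below is an illustrative assumption with made-up sizes and is not a test that ships with this PR; it presumes the byte layout described in that sketch.

```python
import torch
from vllm import _custom_ops as ops

torch.manual_seed(0)
num_blocks, block_size, head_dim, quant_block_size = 8, 64, 128, 128
cache_stride = head_dim + head_dim // quant_block_size * 4

kv_cache = torch.randint(0, 256, (num_blocks, block_size, cache_stride),
                         dtype=torch.uint8, device="cuda")
block_table = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device="cuda")
cu_seq_lens = torch.tensor([0, 40, 100], dtype=torch.int32, device="cuda")
num_tokens = int(cu_seq_lens[-1])

dst_k = torch.empty(num_tokens, head_dim, dtype=torch.uint8, device="cuda")
dst_scale = torch.empty(num_tokens, head_dim // quant_block_size * 4,
                        dtype=torch.uint8, device="cuda")
ops.cp_gather_indexer_k_quant_cache(kv_cache, dst_k, dst_scale,
                                    block_table, cu_seq_lens)

# Compare against the CPU reference from the earlier sketch.
ref_k, ref_scale = gather_ref(kv_cache.cpu(), block_table.cpu(),
                              cu_seq_lens.cpu(), head_dim, quant_block_size)
assert torch.equal(dst_k.cpu(), ref_k)
assert torch.equal(dst_scale.cpu(), ref_scale)
```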