From 43af31054aee037558e1977b84cad6e23d447734 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 08:19:43 +0000 Subject: [PATCH 01/21] port dtype_float16.cuh and cache_kernels.cu --- csrc/attention/dtype_float16.cuh | 69 ++++++++++++++++++++++++++++++-- csrc/cache_kernels.cu | 17 ++++---- setup.py | 64 ----------------------------- 3 files changed, 74 insertions(+), 76 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index e67921128d52..079fa607f96b 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -21,6 +21,10 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" +#ifdef USE_ROCM + #include +#endif + #include namespace vllm { @@ -63,30 +67,49 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); +#else + uint32_t b = a; + b <<= 16; + b |= a; +#endif return b; } inline __device__ float half_to_float(uint16_t h) { +#ifndef USE_ROCM float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; +#else + return __half2float(__ushort_as_half(h)); +#endif } inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); +#else + uint16_t hi = (v >> 16) & 0xFFFF; + uint16_t lo = v & 0xFFFF; +#endif return make_float2(half_to_float(lo), half_to_float(hi)); } inline __device__ uint16_t float_to_half(float f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); return tmp.u16[0]; +#else + return __half_as_ushort(__float2half(f)); +#endif } inline __device__ uint32_t float2_to_half2(float2 f) { @@ -95,26 +118,48 @@ inline __device__ uint32_t float2_to_half2(float2 f) { uint16_t u16[2]; } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + __half2 h = __float22half2_rn(f); + tmp.u16[0] = h.x; + tmp.u16[1] = h.y; #endif return tmp.u32; } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + return __half_as_ushort(__hadd(__ushort_as_half(a), __ushort_as_half(b))); +#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + __half2 h = __hadd2(a, b); + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = h.x; + tmp.u16[1] = h.y; + + return tmp.u32; +#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -157,16 +202,32 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. 
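// NOTE: illustrative sketch, not taken from this diff. The ROCm fallbacks above
// work because the packed uint16_t/uint32_t storage used by these vector types
// shares its bit layout with __half/__half2, so values can be moved between the
// two views with __half_as_ushort()/__ushort_as_half() or a union. Assuming that
// layout, a round trip looks like this (the helper name is made up):
inline __device__ uint32_t pack_halves_sketch(float x, float y) {
  union {
    __half2 h2;    // .x lives in the low 16 bits, .y in the high 16 bits
    uint32_t u32;  // the packed representation used throughout this header
  } tmp;
  // __float22half2_rn is available in both cuda_fp16.h and hip_fp16.h.
  tmp.h2 = __float22half2_rn(make_float2(x, y));
  return tmp.u32;  // same packing the PTX cvt.rn.f16x2.f32 path produces
}
// The multiplication overloads below follow the same bit-reinterpretation pattern.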
template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + return __half_as_ushort(__hmul(__ushort_as_half(a), __ushort_as_half(b))); +#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + __half2 h = __hmul2(a, b); + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = h.x; + tmp.u16[1] = h.y; + + return tmp.u32; +#endif } template<> diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ddad2b5a29b9..2d54ac5eab01 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -28,8 +29,8 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - void *src_ptr = src.data_ptr(); - void *dst_ptr = dst.data_ptr(); + char *src_ptr = static_cast(src.data_ptr()); + char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -176,8 +177,8 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); - value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + key_cache[tgt_key_idx] = VLLM_LDB(&key[src_key_idx]); + value_cache[tgt_value_idx] = VLLM_LDB(&value[src_value_idx]); } } @@ -262,8 +263,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDB(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDB(&value_cache[src_value_idx]); } } @@ -328,8 +329,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = __ldg(&key_cache[src_key_idx]); - values_to_store[j] = __ldg(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDB(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDB(&value_cache[src_value_idx]); } #pragma unroll diff --git a/setup.py b/setup.py index 8b2ad97dd540..75433b336c35 100644 --- a/setup.py +++ b/setup.py @@ -24,10 +24,6 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") - def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -64,66 +60,6 @@ def get_torch_arch_list() -> Set[str]: return set(arch_list) -# First, check the TORCH_CUDA_ARCH_LIST environment variable. -compute_capabilities = get_torch_arch_list() -if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. 
- device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) -if not compute_capabilities: - # If no GPU is specified nor available, add all supported architectures - # based on the NVCC CUDA version. - compute_capabilities = set(SUPPORTED_ARCHS) - if nvcc_cuda_version < Version("11.1"): - compute_capabilities.remove("8.6") - if nvcc_cuda_version < Version("11.8"): - compute_capabilities.remove("8.9") - compute_capabilities.remove("9.0") - -# Validate the NVCC CUDA version. -if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") -if nvcc_cuda_version < Version("11.1"): - if any(cc.startswith("8.6") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.1 or higher is required for compute capability 8.6.") -if nvcc_cuda_version < Version("11.8"): - if any(cc.startswith("8.9") for cc in compute_capabilities): - # CUDA 11.8 is required to generate the code targeting compute capability 8.9. - # However, GPUs with compute capability 8.9 can also run the code generated by - # the previous versions of CUDA 11 and targeting compute capability 8.0. - # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 - # instead of 8.9. - warnings.warn( - "CUDA 11.8 or higher is required for compute capability 8.9. " - "Targeting compute capability 8.0 instead.") - compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) - compute_capabilities.add("8.0+PTX") - if any(cc.startswith("9.0") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.8 or higher is required for compute capability 9.0.") - -# Add target compute capabilities to NVCC flags. -for capability in compute_capabilities: - num = capability[0] + capability[2] - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] - -# Use NVCC threads to parallelize the build. -if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] - ext_modules = [] # Cache operations. 
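Note on the setup.py hunk above: it removes the hard CUDA_HOME requirement and the NVCC
compute-capability selection, which only apply to a CUDA toolchain; a later patch in this
series starts gating CUDA-only extensions on torch.version.hip. A minimal sketch of that
kind of guard follows (the helper name _is_hip_build and the ROCM_HOME handling are
assumptions for illustration, not code from this series):

    # Sketch only: branch the build between CUDA and ROCm toolchains using
    # PyTorch's build metadata. _is_hip_build and the ROCM_HOME check are
    # assumed names, not part of this patch series.
    import torch
    from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

    def _is_hip_build() -> bool:
        # torch.version.hip is a version string on ROCm builds of PyTorch, None otherwise.
        return torch.version.hip is not None

    if _is_hip_build():
        if ROCM_HOME is None:
            raise RuntimeError(
                "Cannot find ROCM_HOME. ROCm must be available to build the package.")
    else:
        if CUDA_HOME is None:
            raise RuntimeError(
                "Cannot find CUDA_HOME. CUDA must be available to build the package.")
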
From cc818664c403b0a0c0683b565a428f0863f65a37 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 08:53:46 +0000 Subject: [PATCH 02/21] port dtype_bfloat16.cuh --- csrc/attention/dtype_bfloat16.cuh | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 2154bfcf8631..9ad2e299c7aa 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -21,8 +21,17 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include + + typedef __hip_bfloat162 __nv_bfloat162; + typedef __hip_bfloat16 __nv_bfloat16; +#endif + #include namespace vllm { @@ -98,7 +107,17 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else - return a + b; + #ifndef USE_ROCM + return a + b; + #else + // See https://github.com/RadeonOpenCompute/ROCm/issues/2534 + hip_bfloat16 A, B; + __hip_bfloat16 c; + A.data = a.data; + B.data = b.data; + c.data = (A + B).data; + return c; + #endif #endif } From 475b5e2875f9f870b88206bf087ff6adc99517a9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 09:07:33 +0000 Subject: [PATCH 03/21] port attention_utils.cuh --- csrc/attention/attention_kernels.cu | 30 ++++++++++++++++++++++++++++- csrc/attention/attention_utils.cuh | 4 ++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 505c63d2efd7..423e784dca94 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -39,7 +39,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(uint32_t(-1), sum, mask); +#endif } // Warp leaders store the data to shared memory. @@ -58,11 +62,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(uint32_t(-1), sum, mask); +#endif } // Broadcast to other threads. +#ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); +#else + return __shfl(uint32_t(-1), sum, 0); +#endif } // Grid: (num_heads, num_seqs). @@ -196,7 +208,11 @@ __global__ void single_query_cached_kv_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); +#endif } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -208,10 +224,18 @@ __global__ void single_query_cached_kv_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); +#endif } // Broadcast the max qk value to all threads. 
+#ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); +#else + qk_max = __shfl(uint32_t(-1), qk_max, 0); +#endif // Get the sum of the exp values. float exp_sum = 0.f; @@ -284,7 +308,11 @@ __global__ void single_query_cached_kv_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); +#else + acc += __shfl_xor(uint32_t(-1), acc, mask); +#endif } accs[i] = acc; } @@ -342,7 +370,7 @@ __global__ void single_query_cached_kv_attention_kernel( #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - vllm::single_query_cached_kv_attention_kernel, \ + (void*)vllm::single_query_cached_kv_attention_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::single_query_cached_kv_attention_kernel \ <<>>( \ diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index bb7df25b14f0..1c3ea9369414 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -39,7 +39,11 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk += __shfl_xor_sync(uint32_t(-1), qk, mask); +#else + qk += __shfl_xor(uint32_t(-1), qk, mask); +#endif } return qk; } From ddc496c7166c055b2adf0ca772475a8add24b3d6 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 17:35:41 +0000 Subject: [PATCH 04/21] port more kernels --- csrc/activation_kernels.cu | 7 ++++--- csrc/cuda_utils_kernels.cu | 4 ++++ csrc/pos_encoding_kernels.cu | 9 +++++---- csrc/reduction_utils.cuh | 4 ++++ setup.py | 4 +++- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index c6ae5db8f9c4..617cf6c0e4a5 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel( const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); - const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; } } @@ -57,7 +58,7 @@ __global__ void activation_kernel( const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index f1c30fe7ea99..2439f5922a3f 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,3 +1,7 @@ +#ifdef USE_ROCM + #include +#endif + int get_device_attribute( int attribute, int device_id) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b4351ee0d794..1e977fa92837 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding( // 
GPT-NeoX style rotary embedding. x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = __ldg(cos_ptr + x_index); - sin = __ldg(sin_ptr + x_index); + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); } else { // GPT-J style rotary embedding. x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = __ldg(cos_ptr + x_index / 2); - sin = __ldg(sin_ptr + x_index / 2); + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); } const scalar_t x = arr[x_index]; diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bc35aa0424b5..382ad162dfef 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -23,7 +23,11 @@ template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) +#ifndef USE_ROCM val += __shfl_xor_sync(0xffffffff, val, mask, 32); +#else + val += __shfl_xor(val, mask, 32); +#endif return val; } diff --git a/setup.py b/setup.py index 75433b336c35..2d9b2afc067f 100644 --- a/setup.py +++ b/setup.py @@ -117,6 +117,7 @@ def get_torch_arch_list() -> Set[str]: ) ext_modules.append(activation_extension) + # Quantization kernels. quantization_extension = CUDAExtension( name="vllm.quantization_ops", @@ -129,7 +130,8 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS, }, ) -ext_modules.append(quantization_extension) +if not torch.version.hip: + ext_modules.append(quantization_extension) # Misc. CUDA utils. cuda_utils_extension = CUDAExtension( From 5eaa7a10052ba77f2794dcd35316b371bddd92ea Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 17:41:06 +0000 Subject: [PATCH 05/21] fix typo --- csrc/cache_kernels.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 2d54ac5eab01..1a9376b3103e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -177,8 +177,8 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = VLLM_LDB(&key[src_key_idx]); - value_cache[tgt_value_idx] = VLLM_LDB(&value[src_value_idx]); + key_cache[tgt_key_idx] = VLLM_LDG(&key[src_key_idx]); + value_cache[tgt_value_idx] = VLLM_LDG(&value[src_value_idx]); } } @@ -263,8 +263,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = VLLM_LDB(&key_cache[src_key_idx]); - value[tgt_value_idx] = VLLM_LDB(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); } } @@ -329,8 +329,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = VLLM_LDB(&key_cache[src_key_idx]); - values_to_store[j] = VLLM_LDB(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); } #pragma unroll From f7273c6831f87ffa8a2b2a569695cd876f47778c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 18:18:08 +0000 Subject: [PATCH 06/21] add cuda_compat.h --- csrc/cuda_compat.h | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 csrc/cuda_compat.h diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h new file mode 100644 index 000000000000..3348b78cfa19 --- /dev/null +++ b/csrc/cuda_compat.h @@ -0,0 +1,7 @@ +#pragma once + +#ifndef USE_ROCM + #define VLLM_LDG(arg) 
__ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif From f8093dc0dfeda653589c5f0db2fe5be46be614b3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 20:27:30 +0000 Subject: [PATCH 07/21] sync branches --- csrc/attention/attention_kernels.cu | 498 +++++----------------------- csrc/attention/attention_utils.cuh | 2 +- csrc/attention/dtype_bfloat16.cuh | 5 - csrc/attention/dtype_float16.cuh | 65 ++-- 4 files changed, 121 insertions(+), 449 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index da7dedb0faf3..debde463786e 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -26,7 +26,6 @@ #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) namespace vllm { @@ -43,7 +42,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { #ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); #else - sum += __shfl_xor(uint32_t(-1), sum, mask); + sum += __shfl_xor(sum, mask); #endif } @@ -66,7 +65,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { #ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); #else - sum += __shfl_xor(uint32_t(-1), sum, mask); + sum += __shfl_xor(sum, mask); #endif } @@ -74,22 +73,18 @@ inline __device__ float block_sum(float* red_smem, float sum) { #ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); #else - return __shfl(uint32_t(-1), sum, 0); + return __shfl(sum, 0); #endif } -// TODO(woosuk): Merge the last two dimensions of the grid. -// Grid: (num_heads, num_seqs, max_num_partitions). +// Grid: (num_heads, num_seqs). template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE = 0> // Zero means no partitioning. -__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + int NUM_THREADS> +__global__ void single_query_cached_kv_attention_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -102,33 +97,10 @@ __device__ void paged_attention_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { - const int seq_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int max_num_partitions = gridDim.z; - constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { - // No work to do. Terminate the thread block. - return; - } - - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; - - // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? 
partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); - const int num_blocks = end_block_idx - start_block_idx; - - // [start_token_idx, end_token_idx) is the range of tokens to process. - const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); - const int num_tokens = end_token_idx - start_token_idx; - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -137,6 +109,7 @@ __device__ void paged_attention_kernel( const int head_idx = blockIdx.x; const int num_heads = gridDim.x; const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -181,12 +154,15 @@ __device__ void paged_attention_kernel( constexpr int x = 16 / sizeof(scalar_t); float qk_max = -FLT_MAX; + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; // Load a key to registers. @@ -220,7 +196,7 @@ __device__ void paged_attention_kernel( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. const bool mask = token_idx >= context_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; + logits[token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); } @@ -235,7 +211,7 @@ __device__ void paged_attention_kernel( #ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); #else - qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); #endif } if (lane == 0) { @@ -251,19 +227,19 @@ __device__ void paged_attention_kernel( #ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); #else - qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); #endif } // Broadcast the max qk value to all threads. #ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); #else - qk_max = __shfl(uint32_t(-1), qk_max, 0); + qk_max = __shfl(qk_max, 0); #endif // Get the sum of the exp values. 
float exp_sum = 0.f; - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; @@ -272,23 +248,11 @@ __device__ void paged_attention_kernel( // Compute softmax. const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { logits[i] *= inv_sum; } __syncthreads(); - // If partitioning is enabled, store the max logit and exp_sum. - if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *exp_sums_ptr = exp_sum; - } - // Each thread will fetch 16 bytes from the value cache at a time. constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; @@ -297,7 +261,7 @@ __device__ void paged_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -308,12 +272,12 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; @@ -323,7 +287,7 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_context_blocks - 1) { + if (block_idx == num_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 @@ -347,7 +311,7 @@ __device__ void paged_attention_kernel( #ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); #else - acc += __shfl_xor(uint32_t(-1), acc, mask); + acc += __shfl_xor(acc, mask); #endif } accs[i] = acc; @@ -391,9 +355,7 @@ __device__ void paged_attention_kernel( // Write the final output. 
if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -404,167 +366,13 @@ __device__ void paged_attention_kernel( } } -// Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs). 
-template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { - out_ptr[i] = tmp_out_ptr[i]; - } - // Terminate the thread block. - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warp_idx = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - - // Size: 2 * num_partitions. - extern __shared__ char shared_mem[]; - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float max_logit = -FLT_MAX; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - const float l = max_logits_ptr[i]; - shared_max_logits[i] = l; - max_logit = fmaxf(max_logit, l); - } - __syncthreads(); - - // Get the global max logit. - // Reduce within the warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = max_logit; - } - __syncthreads(); - // Reduce across warps. - max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); - } - // Broadcast the max value to all threads. - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); - - // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - float l = shared_max_logits[i]; - float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); - global_exp_sum += rescaled_exp_sum; - shared_exp_sums[i] = rescaled_exp_sum; - } - __syncthreads(); - global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); - const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); - - // Aggregate tmp_out to out. 
- const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { - float acc = 0.0f; - for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; - } - from_float(out_ptr[i], acc); - } -} - } // namespace vllm -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - (void*)vllm::paged_attention_v1_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::paged_attention_v1_kernel \ + (void*)vllm::single_query_cached_kv_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::single_query_cached_kv_attention_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -585,7 +393,7 @@ template< typename T, int BLOCK_SIZE, int NUM_THREADS = 128> -void paged_attention_v1_launcher( +void single_query_cached_kv_attention_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -621,206 +429,45 @@ void paged_attention_v1_launcher( int* context_lens_ptr = context_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size); - dim3 grid(num_heads, num_seqs, 1); - dim3 block(NUM_THREADS); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V1(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V1(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V1(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V1(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V1(128); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V1(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - head_mapping, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes); - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. 
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& head_mapping, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - head_mapping_ptr, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - context_lens_ptr, \ - max_num_partitions); - -template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> -void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& head_mapping, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? 
- reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - - // For paged attention v2 kernel. - dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); - // For paged attention v2 reduce kernel. - dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); - + dim3 grid(num_heads, num_seqs); dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. + // case 32: + // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; case 64: - LAUNCH_PAGED_ATTENTION_V2(64); + LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); break; case 80: - LAUNCH_PAGED_ATTENTION_V2(80); + LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); break; case 96: - LAUNCH_PAGED_ATTENTION_V2(96); + LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); break; case 112: - LAUNCH_PAGED_ATTENTION_V2(112); + LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); break; case 128: - LAUNCH_PAGED_ATTENTION_V2(128); + LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); break; + // case 160: + // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; case 256: - LAUNCH_PAGED_ATTENTION_V2(256); + LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -828,12 +475,9 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_launcher( \ +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_attention_launcher( \ out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ query, \ key_cache, \ value_cache, \ @@ -846,27 +490,42 @@ void paged_attention_v2_launcher( // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ case 8: \ - CALL_V2_LAUNCHER(T, 8); \ + CALL_KERNEL_LAUNCHER(T, 8); \ break; \ case 16: \ - CALL_V2_LAUNCHER(T, 16); \ + CALL_KERNEL_LAUNCHER(T, 16); \ break; \ case 32: \ - CALL_V2_LAUNCHER(T, 32); \ + CALL_KERNEL_LAUNCHER(T, 32); \ break; \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -void paged_attention_v2( +void single_query_cached_kv_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] @@ -878,11 +537,11 @@ void paged_attention_v2( int max_context_len, const c10::optional& alibi_slopes) { if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float); + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } @@ -891,4 +550,3 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index 1c3ea9369414..7e6b64eea96f 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -42,7 +42,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { #ifndef USE_ROCM qk += __shfl_xor_sync(uint32_t(-1), qk, mask); #else - qk += __shfl_xor(uint32_t(-1), qk, mask); + qk += __shfl_xor(qk, mask); #endif } return qk; diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 7f2b29de0d93..9ad2e299c7aa 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -439,11 +439,6 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { #endif } -// From bfloat16 to float32. -inline __device__ float to_float(__nv_bfloat16 u) { - return __bfloat162float(u); -} - // Zero-out a variable. 
inline __device__ void zero(__nv_bfloat16& dst) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 079fa607f96b..dc45dbf3daea 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -92,11 +92,15 @@ inline __device__ float2 half2_to_float2(uint32_t v) { #ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); #else - uint16_t hi = (v >> 16) & 0xFFFF; - uint16_t lo = v & 0xFFFF; + union { + __half2 h2; + uint32_t u32; + } V; + V.u32 = v; + return make_float2(half_to_float(V.h2.x), half_to_float(V.h2.y)); #endif - return make_float2(half_to_float(lo), half_to_float(hi)); } inline __device__ uint16_t float_to_half(float f) { @@ -113,24 +117,29 @@ inline __device__ uint16_t float_to_half(float f) { } inline __device__ uint32_t float2_to_half2(float2 f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif + return tmp.u32; #else - __half2 h = __float22half2_rn(f); - tmp.u16[0] = h.x; - tmp.u16[1] = h.y; + union { + __half2 h2; + uint32_t u32; + } R; + + R.h2.x = __half_as_ushort(__float2half_rn(f.x)); + R.h2.y = __half_as_ushort(__float2half_rn(f.y)); + return R.u32; #endif - return tmp.u32; } // Vector addition. @@ -150,15 +159,14 @@ inline __device__ uint32_t add(uint32_t a, uint32_t b) { asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; #else - __half2 h = __hadd2(a, b); union { + __half2 h2; uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = h.x; - tmp.u16[1] = h.y; - - return tmp.u32; + } A, B, C; + A.u32 = a; + B.u32 = b; + C.h2 = __hadd2(A.h2, B.h2); + return C.u32; #endif } @@ -218,15 +226,14 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; #else - __half2 h = __hmul2(a, b); - union { + union { + __half2 h2; uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = h.x; - tmp.u16[1] = h.y; - - return tmp.u32; + } A, B, C; + A.u32 = a; + B.u32 = b; + C.h2 = __hmul2(A.h2, B.h2); + return C.u32; #endif } @@ -332,9 +339,21 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. 
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { +#ifndef USE_ROCM uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; +#else + union { + __half2 h2; + uint32_t u32; + } A, B, C, D; + A.u32 = a; + B.u32 = b; + C.u32 = c; + D.h2 = __hfma2(A.h2, B.h2, C.h2); + return D.u32; +#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { From 41df6890a0398ed8535eeb644e70be78825cdd2c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 20:31:05 +0000 Subject: [PATCH 08/21] update --- csrc/attention/attention_kernels.cu | 30 +---------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index debde463786e..505c63d2efd7 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -39,11 +39,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif } // Warp leaders store the data to shared memory. @@ -62,19 +58,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif } // Broadcast to other threads. -#ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); -#else - return __shfl(sum, 0); -#endif } // Grid: (num_heads, num_seqs). @@ -208,11 +196,7 @@ __global__ void single_query_cached_kv_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -224,18 +208,10 @@ __global__ void single_query_cached_kv_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif } // Broadcast the max qk value to all threads. -#ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); -#else - qk_max = __shfl(qk_max, 0); -#endif // Get the sum of the exp values. 
float exp_sum = 0.f; @@ -308,11 +284,7 @@ __global__ void single_query_cached_kv_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); -#else - acc += __shfl_xor(acc, mask); -#endif } accs[i] = acc; } @@ -370,7 +342,7 @@ __global__ void single_query_cached_kv_attention_kernel( #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - (void*)vllm::single_query_cached_kv_attention_kernel, \ + vllm::single_query_cached_kv_attention_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::single_query_cached_kv_attention_kernel \ <<>>( \ From 93be9c5b32aa36b96a2376e41498e0efd0dbb329 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 20:34:57 +0000 Subject: [PATCH 09/21] update --- csrc/attention/attention_kernels.cu | 484 ++++++++++++++++++++++++---- 1 file changed, 413 insertions(+), 71 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 505c63d2efd7..ee6b715adaef 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -26,6 +26,7 @@ #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) namespace vllm { @@ -65,14 +66,18 @@ inline __device__ float block_sum(float* red_smem, float sum) { return __shfl_sync(uint32_t(-1), sum, 0); } -// Grid: (num_heads, num_seqs). +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS> -__global__ void single_query_cached_kv_attention_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + int NUM_THREADS, + int PARTITION_SIZE = 0> // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -85,10 +90,33 @@ __global__ void single_query_cached_kv_attention_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = USE_PARTITIONING ? 
partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -97,7 +125,6 @@ __global__ void single_query_cached_kv_attention_kernel( const int head_idx = blockIdx.x; const int num_heads = gridDim.x; const int kv_head_idx = head_mapping[head_idx]; - const int seq_idx = blockIdx.y; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -142,15 +169,12 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int x = 16 / sizeof(scalar_t); float qk_max = -FLT_MAX; - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int context_len = context_lens[seq_idx]; - const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; // Load a key to registers. @@ -184,7 +208,7 @@ __global__ void single_query_cached_kv_attention_kernel( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. const bool mask = token_idx >= context_len; - logits[token_idx] = mask ? 0.f : qk; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); } @@ -215,7 +239,7 @@ __global__ void single_query_cached_kv_attention_kernel( // Get the sum of the exp values. float exp_sum = 0.f; - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; @@ -224,11 +248,23 @@ __global__ void single_query_cached_kv_attention_kernel( // Compute softmax. const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { logits[i] *= inv_sum; } __syncthreads(); + // If partitioning is enabled, store the max logit and exp_sum. 
+ if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *exp_sums_ptr = exp_sum; + } + // Each thread will fetch 16 bytes from the value cache at a time. constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; @@ -237,7 +273,7 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -248,12 +284,12 @@ __global__ void single_query_cached_kv_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; @@ -263,7 +299,7 @@ __global__ void single_query_cached_kv_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_blocks - 1) { + if (block_idx == num_context_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 @@ -327,7 +363,9 @@ __global__ void single_query_cached_kv_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -338,13 +376,167 @@ __global__ void single_query_cached_kv_attention_kernel( } } +// Grid: (num_heads, num_seqs, 1). 
+template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, + out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. 
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + } // namespace vllm -#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::single_query_cached_kv_attention_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::single_query_cached_kv_attention_kernel \ + vllm::paged_attention_v1_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -365,7 +557,7 @@ template< typename T, int BLOCK_SIZE, int NUM_THREADS = 128> -void single_query_cached_kv_attention_launcher( +void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -401,45 +593,206 @@ void single_query_cached_kv_attention_launcher( int* context_lens_ptr = context_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size); - dim3 grid(num_heads, num_seqs); + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + context_lens_ptr, \ + max_num_partitions); + +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> +void paged_attention_v2_launcher( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? 
+ reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we omitted head sizes - // 32, 160, 192. - // case 32: - // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); - // break; + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. case 64: - LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(64); break; case 80: - LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(80); break; case 96: - LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(96); break; case 112: - LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(112); break; case 128: - LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(128); break; - // case 160: - // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); - // break; - // case 192: - // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); - // break; case 256: - LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(256); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -447,9 +800,12 @@ void single_query_cached_kv_attention_launcher( } } -#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - single_query_cached_kv_attention_launcher( \ +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_launcher( \ out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ query, \ key_cache, \ value_cache, \ @@ -462,42 +818,27 @@ void single_query_cached_kv_attention_launcher( // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
-#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ - /* case 1: */ \ - /* CALL_KERNEL_LAUNCHER(T, 1); */ \ - /* break; */ \ - /* case 2: */ \ - /* CALL_KERNEL_LAUNCHER(T, 2); */ \ - /* break; */ \ - /* case 4: */ \ - /* CALL_KERNEL_LAUNCHER(T, 4); */ \ - /* break; */ \ case 8: \ - CALL_KERNEL_LAUNCHER(T, 8); \ + CALL_V2_LAUNCHER(T, 8); \ break; \ case 16: \ - CALL_KERNEL_LAUNCHER(T, 16); \ + CALL_V2_LAUNCHER(T, 16); \ break; \ case 32: \ - CALL_KERNEL_LAUNCHER(T, 32); \ + CALL_V2_LAUNCHER(T, 32); \ break; \ - /* case 64: */ \ - /* CALL_KERNEL_LAUNCHER(T, 64); */ \ - /* break; */ \ - /* case 128: */ \ - /* CALL_KERNEL_LAUNCHER(T, 128); */ \ - /* break; */ \ - /* case 256: */ \ - /* CALL_KERNEL_LAUNCHER(T, 256); */ \ - /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -void single_query_cached_kv_attention( +void paged_attention_v2( torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] @@ -509,11 +850,11 @@ void single_query_cached_kv_attention( int max_context_len, const c10::optional& alibi_slopes) { if (query.dtype() == at::ScalarType::Float) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + CALL_V2_LAUNCHER_BLOCK_SIZE(float); } else if (query.dtype() == at::ScalarType::Half) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } @@ -522,3 +863,4 @@ void single_query_cached_kv_attention( #undef WARP_SIZE #undef MAX #undef MIN +#undef DIVIDE_ROUND_UP From d96fa3c2c7b45ae78052433e1b433244e8428ab0 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 21:18:11 +0000 Subject: [PATCH 10/21] fixes --- csrc/attention/attention_kernels.cu | 42 ++++++++++++++++++++++++++++- csrc/attention/dtype_bfloat16.cuh | 5 ++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index ee6b715adaef..8fe641aa307b 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -40,7 +40,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(sum, mask); +#endif } // Warp leaders store the data to shared memory. @@ -59,11 +63,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(sum, mask); +#endif } // Broadcast to other threads. 
+#ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); +#else + return __shfl(sum, 0); +#endif } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -220,7 +232,11 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); +#endif } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -232,10 +248,18 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); +#endif } // Broadcast the max qk value to all threads. +#ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); +#else + qk_max = __shfl(qk_max, 0); +#endif // Get the sum of the exp values. float exp_sum = 0.f; @@ -320,7 +344,11 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); +#else + acc += __shfl_xor(acc, mask); +#endif } accs[i] = acc; } @@ -486,7 +514,11 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); +#else + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); +#endif } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -496,10 +528,18 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); +#else + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); +#endif } // Broadcast the max value to all threads. +#ifndef USE_ROCM max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); +#else + max_logit = __shfl(max_logit, 0); +#endif // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); @@ -534,7 +574,7 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::paged_attention_v1_kernel, \ + (void*)vllm::paged_attention_v1_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 9ad2e299c7aa..7f2b29de0d93 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -439,6 +439,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { #endif } +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + // Zero-out a variable. 
inline __device__ void zero(__nv_bfloat16& dst) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 From 421365b5cf80268710c921645e31c8e48e0596fa Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 22:36:57 +0000 Subject: [PATCH 11/21] cleanup --- csrc/attention/attention_kernels.cu | 43 ++++++----------------------- csrc/cuda_compat.h | 6 ++++ 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8fe641aa307b..ebfe3f6a38b0 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -20,6 +20,7 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#include "../cuda_compat.h" #include @@ -40,11 +41,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Warp leaders store the data to shared memory. @@ -63,11 +60,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Broadcast to other threads. @@ -232,11 +225,7 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -#ifndef USE_ROCM - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -248,11 +237,7 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } // Broadcast the max qk value to all threads. #ifndef USE_ROCM @@ -344,11 +329,7 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); -#else - acc += __shfl_xor(acc, mask); -#endif + acc += VLLM_SHFL_XOR_SYNC(acc, mask); } accs[i] = acc; } @@ -514,11 +495,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); -#else - max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); -#endif + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -528,11 +505,7 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? 
red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); -#else - max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); -#endif + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. #ifndef USE_ROCM diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 3348b78cfa19..c5f170fcb475 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -5,3 +5,9 @@ #else #define VLLM_LDG(arg) *(arg) #endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif From 06b800e3d67097c05a0446b8d2f35047c6a794e7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 22:48:14 +0000 Subject: [PATCH 12/21] update --- csrc/attention/attention_kernels.cu | 1 - csrc/attention/attention_utils.cuh | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index ebfe3f6a38b0..3bea905b0679 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -20,7 +20,6 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" -#include "../cuda_compat.h" #include diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index 7e6b64eea96f..ff64c4bd8f80 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -17,6 +17,7 @@ */ #pragma once +#include "../cuda_compat.h" #include "attention_dtypes.h" #include @@ -39,11 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); -#else - qk += __shfl_xor(qk, mask); -#endif + qk += VLLM_SHFL_XOR_SYNC(qk, mask); } return qk; } From 2312beb1f41206bbe97ed17b8fd0b935d9014f28 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 22:58:34 +0000 Subject: [PATCH 13/21] update --- csrc/attention/attention_kernels.cu | 18 +++--------------- csrc/cuda_compat.h | 6 ++++++ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 3bea905b0679..babd15bb30fb 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -63,11 +63,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { } // Broadcast to other threads. -#ifndef USE_ROCM - return __shfl_sync(uint32_t(-1), sum, 0); -#else - return __shfl(sum, 0); -#endif + return VLLM_SHFL_SYNC(sum, 0); } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -239,11 +235,7 @@ __device__ void paged_attention_kernel( qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } // Broadcast the max qk value to all threads. -#ifndef USE_ROCM - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); -#else - qk_max = __shfl(qk_max, 0); -#endif + qk_max = VLLM_SHFL_SYNC(qk_max, 0); // Get the sum of the exp values. float exp_sum = 0.f; @@ -507,11 +499,7 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. 
-#ifndef USE_ROCM - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); -#else - max_logit = __shfl(max_logit, 0); -#endif + max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index c5f170fcb475..8991462a862e 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -11,3 +11,9 @@ #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif From 2958b39d9ad440ec90dd13fe597d5d5e49d59d3e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:11:25 +0000 Subject: [PATCH 14/21] update --- csrc/layernorm_kernels.cu | 1 + csrc/reduction_utils.cuh | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f932b9e2d615..9d4ada1f0715 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -4,6 +4,7 @@ #include "dispatch_utils.h" #include "reduction_utils.cuh" + namespace vllm { // TODO(woosuk): Further optimize this kernel. diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 382ad162dfef..b95ccef16207 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -17,17 +17,15 @@ */ #pragma once +#include "cuda_compat.h" + namespace vllm { template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) -#ifndef USE_ROCM - val += __shfl_xor_sync(0xffffffff, val, mask, 32); -#else - val += __shfl_xor(val, mask, 32); -#endif + val += VLLM_SHFL_XOR_SYNC(val, mask); return val; } From 3f8973403a3f9fdb21eb48bea1066db0ea4d4fcd Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:14:19 +0000 Subject: [PATCH 15/21] fmt --- csrc/layernorm_kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 9d4ada1f0715..f932b9e2d615 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -4,7 +4,6 @@ #include "dispatch_utils.h" #include "reduction_utils.cuh" - namespace vllm { // TODO(woosuk): Further optimize this kernel. From 5397a5748c19f3ceedc037ce101eadd83191edae Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:15:49 +0000 Subject: [PATCH 16/21] cleanup --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 5eeb36b5b75e..a8e9aa4af8c3 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,11 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] +if not torch.version.hip: + if CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") + def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. 
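The warp-shuffle changes above (PATCH 11 through PATCH 14) route every shuffle through the VLLM_SHFL_XOR_SYNC and VLLM_SHFL_SYNC macros in csrc/cuda_compat.h, so the attention, reduction, and layernorm kernels no longer need per-call #ifndef USE_ROCM blocks. A minimal usage sketch, mirroring those macro definitions; the warp_sum helper and sum_kernel below are illustrative names, not part of the series:

// Sketch (illustrative only): warp-level sum reduction built on the
// portability macros from csrc/cuda_compat.h.
#include <cstdint>

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#define WARP_SIZE 32

// Butterfly (XOR) reduction: after log2(WARP_SIZE) steps every lane holds the
// warp-wide sum; the final shuffle broadcasts lane 0's value.
inline __device__ float warp_sum(float val) {
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return VLLM_SHFL_SYNC(val, 0);
}

// One warp sums 32 floats; lane 0 writes the result.
__global__ void sum_kernel(const float* __restrict__ in, float* __restrict__ out) {
  const float total = warp_sum(in[threadIdx.x]);
  if (threadIdx.x == 0) {
    *out = total;
  }
}

Only one branch of each macro is ever compiled: the ROCm side is selected by the -DUSE_ROCM flag that setup.py adds later in the series (PATCH 19).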
From 90e02d25b64757b2ccf1b7fc452db207591b8cb8 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:29:19 +0000 Subject: [PATCH 17/21] refactor --- setup.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/setup.py b/setup.py index a8e9aa4af8c3..a44e33bfcea2 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,72 @@ def get_torch_arch_list() -> Set[str]: f"{valid_archs}.") return arch_list +def get_cuda_compute_capabilities(nvcc_cuda_version): + # First, check the TORCH_CUDA_ARCH_LIST environment variable. + compute_capabilities = get_torch_arch_list() + if not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + + if not compute_capabilities: + # If no GPU is specified nor available, add all supported architectures + # based on the NVCC CUDA version. + compute_capabilities = SUPPORTED_ARCHS.copy() + if nvcc_cuda_version < Version("11.1"): + compute_capabilities.remove("8.6") + if nvcc_cuda_version < Version("11.8"): + compute_capabilities.remove("8.9") + compute_capabilities.remove("9.0") + + return compute_capabilities + +def validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities): + if nvcc_cuda_version < Version("11.0"): + raise RuntimeError("CUDA 11.0 or higher is required to build the package.") + if nvcc_cuda_version < Version("11.1"): + if any(cc.startswith("8.6") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.1 or higher is required for compute capability 8.6.") + if nvcc_cuda_version < Version("11.8"): + if any(cc.startswith("8.9") for cc in compute_capabilities): + # CUDA 11.8 is required to generate the code targeting compute capability 8.9. + # However, GPUs with compute capability 8.9 can also run the code generated by + # the previous versions of CUDA 11 and targeting compute capability 8.0. + # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 + # instead of 8.9. + warnings.warn( + "CUDA 11.8 or higher is required for compute capability 8.9. " + "Targeting compute capability 8.0 instead.") + compute_capabilities = set(cc for cc in compute_capabilities + if not cc.startswith("8.9")) + compute_capabilities.add("8.0+PTX") + if any(cc.startswith("9.0") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.8 or higher is required for compute capability 9.0.") + +if not torch.version.hip: + nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) + compute_capabilities = get_cuda_compute_capabilities(nvcc_cuda_version) + validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities) + + # Add target compute capabilities to NVCC flags. + for capability in compute_capabilities: + num = capability[0] + capability[2] + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + + # Use NVCC threads to parallelize the build. 
+ if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] ext_modules = [] From a42020206afea67123304f206c186523f31d5582 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:31:29 +0000 Subject: [PATCH 18/21] update --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index a44e33bfcea2..c2e5ad80aa75 100644 --- a/setup.py +++ b/setup.py @@ -201,7 +201,6 @@ def validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities): ) ext_modules.append(activation_extension) - # Quantization kernels. quantization_extension = CUDAExtension( name="vllm.quantization_ops", From 2d1e43581c2a3eefc09ad1a93ab2c83145ae38ae Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Tue, 17 Oct 2023 06:59:09 +0000 Subject: [PATCH 19/21] detecting rocm and adding flag for compiling --- setup.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c2e5ad80aa75..55a4358f734f 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from packaging.version import parse, Version import setuptools import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) @@ -24,12 +24,15 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] +if torch.version.hip: + if ROCM_HOME is not None: + NVCC_FLAGS += [f"-DUSE_ROCM"] + if not torch.version.hip: if CUDA_HOME is None: raise RuntimeError( "Cannot find CUDA_HOME. CUDA must be available to build the package.") - def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. 
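PATCH 19 above is what activates the #ifndef USE_ROCM guards in csrc/: when torch.version.hip is set and ROCM_HOME is found, -DUSE_ROCM is appended to NVCC_FLAGS, so the HIP build compiles the ROCm branches while the CUDA build keeps the PTX inline assembly and __ldg paths. A small sketch of the resulting pattern, mirroring the VLLM_LDG macro from csrc/cuda_compat.h; the read_elem helper is an illustrative name, not from the patches:

// Sketch: compile-time selection between CUDA and ROCm code paths, driven by
// the -DUSE_ROCM define that setup.py adds under torch.version.hip.
#ifndef USE_ROCM
  // NVIDIA: route the load through the read-only data cache.
  #define VLLM_LDG(arg) __ldg(arg)
#else
  // ROCm: __ldg is not required; a plain load behaves the same here.
  #define VLLM_LDG(arg) *(arg)
#endif

template <typename scalar_t>
inline __device__ scalar_t read_elem(const scalar_t* __restrict__ ptr, int idx) {
  return VLLM_LDG(&ptr[idx]);
}

Keeping the switch in one header means individual kernels never have to test the toolchain themselves.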
From e231b7903dabbc698639d9a303edcb13d4e34d0f Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Tue, 17 Oct 2023 06:59:46 +0000 Subject: [PATCH 20/21] using asm volatile instead of hip api --- csrc/attention/dtype_float16.cuh | 100 ++++++++++++------------------- 1 file changed, 39 insertions(+), 61 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index dc45dbf3daea..8e670d81ff3a 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -71,21 +71,25 @@ inline __device__ uint32_t h0_h0(uint16_t a) { uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); #else - uint32_t b = a; - b <<= 16; - b |= a; + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + b = tmp.u32 #endif return b; } inline __device__ float half_to_float(uint16_t h) { -#ifndef USE_ROCM float f; +#ifndef USE_ROCM asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; #else - return __half2float(__ushort_as_half(h)); + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); #endif + return f; } inline __device__ float2 half2_to_float2(uint32_t v) { @@ -95,79 +99,68 @@ inline __device__ float2 half2_to_float2(uint32_t v) { return make_float2(half_to_float(lo), half_to_float(hi)); #else union { - __half2 h2; uint32_t u32; - } V; - V.u32 = v; - return make_float2(half_to_float(V.h2.x), half_to_float(V.h2.y)); + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; #endif } inline __device__ uint16_t float_to_half(float f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; +#ifndef USE_ROCM asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); - return tmp.u16[0]; #else - return __half_as_ushort(__float2half(f)); + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); #endif + return tmp.u16[0]; } inline __device__ uint32_t float2_to_half2(float2 f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; - +#ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif - return tmp.u32; #else - union { - __half2 h2; - uint32_t u32; - } R; - - R.h2.x = __half_as_ushort(__float2half_rn(f.x)); - R.h2.y = __half_as_ushort(__float2half_rn(f.y)); - return R.u32; + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); #endif + return tmp.u32; } // Vector addition. 
inline __device__ uint16_t add(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; +#ifndef USE_ROCM asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; #else - return __half_as_ushort(__hadd(__ushort_as_half(a), __ushort_as_half(b))); + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } inline __device__ uint32_t add(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; +#ifndef USE_ROCM asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; #else - union { - __half2 h2; - uint32_t u32; - } A, B, C; - A.u32 = a; - B.u32 = b; - C.h2 = __hadd2(A.h2, B.h2); - return C.u32; + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -210,31 +203,24 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; +#ifndef USE_ROCM asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; #else - return __half_as_ushort(__hmul(__ushort_as_half(a), __ushort_as_half(b))); + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; +#ifndef USE_ROCM asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; #else - union { - __half2 h2; - uint32_t u32; - } A, B, C; - A.u32 = a; - B.u32 = b; - C.h2 = __hmul2(A.h2, B.h2); - return C.u32; + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } template<> @@ -339,21 +325,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { -#ifndef USE_ROCM uint32_t d; +#ifndef USE_ROCM asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; #else - union { - __half2 h2; - uint32_t u32; - } A, B, C, D; - A.u32 = a; - B.u32 = b; - C.u32 = c; - D.h2 = __hfma2(A.h2, B.h2, C.h2); - return D.u32; + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return d; } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { From 31bb33568f1b14cbcc4e6aee84eed11c0da10c16 Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Tue, 17 Oct 2023 07:15:01 +0000 Subject: [PATCH 21/21] using asm volatile for type casting of f16 --- csrc/attention/dtype_float16.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 8e670d81ff3a..b9c9275aae3f 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -67,8 +67,8 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { -#ifndef USE_ROCM uint32_t b; +#ifndef USE_ROCM asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); #else union { @@ -77,7 +77,7 @@ inline __device__ uint32_t h0_h0(uint16_t a) { } tmp; tmp.u16[0] = a; tmp.u16[1] = a; - b = tmp.u32 + b = tmp.u32; #endif return b; }
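In the final hunk above, the ROCm branch of the packed-FP16 fma() is written with v_pk_mul_f16; for reference, the same fused multiply-add can be expressed with the __hfma2 intrinsic that the earlier revision of this file used. A minimal sketch, assuming cuda_fp16.h on the NVIDIA side and hip/hip_fp16.h under ROCm for the __half2 type; the fma_via_intrinsic name is illustrative only:

// Sketch (for reference only): packed FP16 fused multiply-add, d = a * b + c
// on two half values at once, expressed with __hfma2 and the bit-cast unions
// used elsewhere in dtype_float16.cuh instead of raw v_* assembly.
#include <cstdint>
#ifndef USE_ROCM
  #include <cuda_fp16.h>
#else
  #include <hip/hip_fp16.h>
#endif

inline __device__ uint32_t fma_via_intrinsic(uint32_t a, uint32_t b, uint32_t c) {
  union {
    __half2 h2;
    uint32_t u32;
  } A, B, C, D;
  A.u32 = a;
  B.u32 = b;
  C.u32 = c;
  D.h2 = __hfma2(A.h2, B.h2, C.h2);
  return D.u32;
}

The inline-asm form in the patch targets the packed VALU instructions directly; the intrinsic form lets the compiler pick them while keeping the accumulator argument explicit.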