From cfc38865d40cf7f72c11f48ae90255d439ec3d80 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Sun, 30 Apr 2023 22:29:19 +0000 Subject: [PATCH 01/26] Minor --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e96c73033379..116dad140484 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ ) ext_modules.append(attention_extension) -# Positional encodings. +# Positional encoding kernels. positional_encoding_extension = cpp_extension.CUDAExtension( name='cacheflow.pos_encoding_ops', sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'], @@ -39,6 +39,7 @@ ) ext_modules.append(layernorm_extension) +# Activation kernels. activation_extension = cpp_extension.CUDAExtension( name='cacheflow.activation_ops', sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'], From cffca2019510cdc8359a4aa47005a60f774d1b0e Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Sun, 30 Apr 2023 22:29:36 +0000 Subject: [PATCH 02/26] Remove unused kernels --- csrc/attention_kernels.cu | 495 -------------------------------------- 1 file changed, 495 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index c25acbb8be6f..1bab3d75f325 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -396,501 +396,6 @@ void single_query_cached_kv_attention( } } -// namespace cacheflow { - -// // Grid: (num_heads, num_query_tokens). -// template< -// typename scalar_t, -// int HEAD_SIZE, -// int BLOCK_SIZE, -// int NUM_THREADS> -// __device__ void multi_query_cached_kv_attention_kernel_unoptimized_( -// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -// const int seq_start_idx, -// const int seq_len, -// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] -// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] -// const float scale, -// const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] -// const int context_len, -// const int max_num_blocks_per_seq, -// const int q_stride) { -// constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; -// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -// const int thread_idx = threadIdx.x; -// const int warp_idx = thread_idx / WARP_SIZE; -// const int lane = thread_idx % WARP_SIZE; - -// const int head_idx = blockIdx.x; -// const int num_heads = gridDim.x; -// const int seq_idx = blockIdx.y; - -// // A vector type to store a part of a key or a query. -// // The vector size is configured in such a way that the threads in a thread group -// // fetch or comput 16 bytes at a time. -// // For example, if the size of a thread group is 4 and the data type is half, -// // then the vector size is 16 / (4 * sizeof(half)) == 2. -// constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); -// using K_vec = typename Vec::Type; -// using Q_vec = typename Vec::Type; - -// constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; -// constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - -// const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; -// const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - -// // Load the query to registers. -// // Each thread in a thread group has a different part of the query. -// // For example, if the the thread group size is 4, then the first thread in the group -// // has 0, 4, 8, ... 
th vectors of the query, and the second thread has 1, 5, 9, ... -// // th vectors of the query, and so on. -// // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. -// const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; -// Q_vec q_vecs[NUM_VECS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { -// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -// q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); -// } - -// // Memory planning. -// extern __shared__ char shared_mem[]; -// // NOTE(woosuk): We use FP32 logits and accumulation. -// float *logits = reinterpret_cast(shared_mem); -// // Workspace for reduction. -// __shared__ float red_smem[2 * NUM_WARPS]; - -// // x == THREAD_GROUP_SIZE * VEC_SIZE -// // Each thread group fetches x elements from the key at a time. -// constexpr int x = 16 / sizeof(scalar_t); -// float qk_max = -FLT_MAX; - -// const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; -// const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx); - -// // Iterate over the key blocks. -// // Each warp fetches a block of keys for each iteration. -// // Each thread group in a warp fetches a key from the block, and computes -// // dot product with the query. -// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { -// const int physical_block_number = block_table[block_idx]; -// const int physical_block_offset = thread_group_idx % BLOCK_SIZE; -// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - -// // Load a key to registers. -// // Each thread in a thread group has a different part of the key. -// // For example, if the the thread group size is 4, then the first thread in the group -// // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th -// // vectors of the key, and so on. -// K_vec k_vecs[NUM_VECS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { -// const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE -// + head_idx * HEAD_SIZE * BLOCK_SIZE -// + physical_block_offset * x; -// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -// const int offset1 = (vec_idx * VEC_SIZE) / x; -// const int offset2 = (vec_idx * VEC_SIZE) % x; -// k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); -// } - -// // Compute dot product. -// // This includes a reduction across the threads in the same thread group. -// const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); -// const bool mask = token_idx >= mask_boundary; - -// if (thread_group_offset == 0) { -// // Store the partial reductions to shared memory. -// // NOTE(woosuk): It is required to zero out the masked logits. -// logits[token_idx] = mask ? 0.f : qk; -// // Update the max value. -// qk_max = mask ? qk_max : fmaxf(qk_max, qk); -// } -// } - -// // Perform reduction across the threads in the same warp to get the -// // max qk value for each "warp" (not across the thread block yet). -// // The 0-th thread of each thread group already has its max qk value. -// #pragma unroll -// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -// } -// if (lane == 0) { -// red_smem[warp_idx] = qk_max; -// } -// __syncthreads(); - -// // TODO(woosuk): Refactor this part. -// // Get the max qk value for the sequence. 
-// qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -// #pragma unroll -// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -// } -// // Broadcast the max qk value to all threads. -// qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - -// // Get the sum of the exp values. -// float exp_sum = 0.f; -// for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) { -// float val = __expf(logits[i] - qk_max); -// logits[i] = val; -// exp_sum += val; -// } -// exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - -// // Compute softmax. -// const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); -// for (int i = thread_idx; i < context_len; i += NUM_THREADS) { -// logits[i] *= inv_sum; -// } -// __syncthreads(); - -// // Each thread will fetch 16 bytes from the value cache at a time. -// constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); -// using V_vec = typename Vec::Type; -// using L_vec = typename FloatVec::Type; - -// constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; -// constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; -// constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; - -// float accs[NUM_ROWS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// accs[i] = 0.f; -// } - -// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { -// const int physical_block_number = block_table[block_idx]; -// const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; -// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; -// L_vec logits_vec = *reinterpret_cast(logits + token_idx); - -// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE -// + head_idx * HEAD_SIZE * BLOCK_SIZE; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE) { -// const int offset = row_idx * BLOCK_SIZE + physical_block_offset; -// V_vec v_vec = *reinterpret_cast(v_ptr + offset); -// accs[i] += dot(logits_vec, cast_to_float(v_vec)); -// } -// } -// } - -// // Perform reduction within each warp. -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// float acc = accs[i]; -// #pragma unroll -// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -// acc += __shfl_xor_sync(uint32_t(-1), acc, mask); -// } -// accs[i] = acc; -// } - -// // NOTE(woosuk): A barrier is required because the shared memory space for logits -// // is reused for the output. -// __syncthreads(); - -// // Perform reduction across warps. -// float* out_smem = reinterpret_cast(shared_mem); -// #pragma unroll -// for (int i = NUM_WARPS; i > 1; i /= 2) { -// int mid = i / 2; -// // Upper warps write to shared memory. -// if (warp_idx >= mid && warp_idx < i) { -// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// dst[row_idx] = accs[i]; -// } -// } -// } -// __syncthreads(); - -// // Lower warps update the output. 
-// if (warp_idx < mid) { -// const float* src = &out_smem[warp_idx * HEAD_SIZE]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// accs[i] += src[row_idx]; -// } -// } -// } -// __syncthreads(); -// } - -// // Write the final output. -// if (warp_idx == 0) { -// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// convert_from_float(*(out_ptr + row_idx), accs[i]); -// } -// } -// } -// } - - -// // Grid: (num_heads, num_query_tokens). -// template< -// typename scalar_t, -// int HEAD_SIZE, -// int BLOCK_SIZE, -// int NUM_THREADS> -// __global__ void multi_query_cached_kv_attention_kernel( -// const int* cu_query_lens, // [num_prompts+1] -// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx -// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] -// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] -// const float scale, -// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq] -// const int* __restrict__ context_lens, // [num_prompts] -// const int max_num_blocks_per_seq, -// const int q_stride) { -// const int seq_idx = blockIdx.y; -// const int prompt_idx = seq_prompt_mapping[seq_idx]; -// const int seq_start_idx = cu_query_lens[prompt_idx]; -// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx; -// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; -// const int context_len = context_lens[prompt_idx]; -// multi_query_cached_kv_attention_kernel_unoptimized_< -// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( -// out, -// q, -// seq_start_idx, -// seq_len, -// k_cache, -// v_cache, -// scale, -// block_table, -// context_len, -// max_num_blocks_per_seq, -// q_stride); -// } - -// } // namespace cacheflow - -// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ -// cacheflow::multi_query_cached_kv_attention_kernel \ -// <<>>( \ -// cu_query_lens_ptr, \ -// seq_prompt_mapping_ptr, \ -// out_ptr, \ -// query_ptr, \ -// key_cache_ptr, \ -// value_cache_ptr, \ -// scale, \ -// block_tables_ptr, \ -// context_lens_ptr, \ -// max_num_blocks_per_seq, \ -// query_stride); - - -// // TODO(woosuk): Tune NUM_THREADS. 
-// template< -// typename T, -// int BLOCK_SIZE, -// int NUM_THREADS = 128> -// void multi_query_cached_kv_attention_launcher( -// torch::Tensor& cu_query_lens, -// torch::Tensor& seq_prompt_mapping, -// torch::Tensor& out, -// torch::Tensor& query, -// torch::Tensor& key_cache, -// torch::Tensor& value_cache, -// float scale, -// torch::Tensor& block_tables, -// torch::Tensor& context_lens, -// int max_context_len) { -// int num_seqs = query.size(0); -// int num_heads = query.size(1); -// int head_size = query.size(2); -// int max_num_blocks_per_seq = block_tables.size(1); -// int query_stride = query.stride(0); - -// int* cu_query_lens_ptr = cu_query_lens.data_ptr(); -// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr(); -// T* out_ptr = reinterpret_cast(out.data_ptr()); -// T* query_ptr = reinterpret_cast(query.data_ptr()); -// T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); -// T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); -// int* block_tables_ptr = block_tables.data_ptr(); -// int* context_lens_ptr = context_lens.data_ptr(); - -// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; -// int logits_size = padded_max_context_len * sizeof(float); -// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); -// int shared_mem_size = std::max(logits_size, outputs_size); - -// dim3 grid(num_heads, num_seqs); -// dim3 block(NUM_THREADS); -// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); -// switch (head_size) { -// case 32: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); -// break; -// case 64: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); -// break; -// case 80: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); -// break; -// case 96: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); -// break; -// case 128: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); -// break; -// case 160: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); -// break; -// case 192: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); -// break; -// case 256: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); -// break; -// default: -// assert(false); -// break; -// } -// } - -// void multi_query_cached_kv_attention( -// torch::Tensor& cu_query_lens, -// torch::Tensor& out, -// torch::Tensor& query, -// torch::Tensor& key_cache, -// torch::Tensor& value_cache, -// float scale, -// torch::Tensor& block_tables, -// torch::Tensor& context_lens, -// int block_size, -// int max_context_len) { - -// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU); - -// int num_queries = query_lens.size(0) - 1; -// const int* query_lens_ptr = query_lens.data_ptr(); -// int num_seqs = query.size(0); - -// torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); -// auto accessor = cpu_tensor.accessor(); -// for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { -// if (i >= query_lens_ptr[query_cursor + 1]) { -// ++query_cursor; -// } -// accessor[i] = query_cursor; -// } - -// // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA) -// // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving -// // the mapping as an input parameter. Let's do this optimization in a later PR. 
-// torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); - -// // TODO(woosuk): Support BF16. -// if (query.element_size() == 2) { -// // Half. -// if (block_size == 8) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 16) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 32) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else { -// assert(false); -// } -// } else if (query.element_size() == 4) { -// // Float. -// if (block_size == 8) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 16) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 32) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else { -// assert(false); -// } -// } else { -// assert(false); -// } -// } - #undef WARP_SIZE #undef MAX #undef MIN From a97500fe581efd172bec243bd76f735cf79d7236 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Sun, 30 Apr 2023 23:12:37 +0000 Subject: [PATCH 03/26] Add support for bfloat16 --- csrc/activation_kernels.cu | 4 +- csrc/cache_kernels.cu | 98 ++++++++++++++++++++---------------- csrc/layernorm_kernels.cu | 4 +- csrc/pos_encoding_kernels.cu | 4 +- 4 files changed, 64 insertions(+), 46 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 12ee6c54827c..a13b1b9cf290 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -34,7 +34,9 @@ void silu_and_mul( dim3 grid(num_tokens); dim3 block(std::min(d, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, input.scalar_type(), "silu_and_mul_kernel", [&] { diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 5f97af254142..ddd2d3505780 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -14,14 +14,16 @@ void swap_blocks( torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; if (src_device.is_cuda() && dst_device.is_cuda()) { - assert(src_device.index() == dst_device.index()); + TORCH_CHECK( + src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); memcpy_type = cudaMemcpyDeviceToDevice; } else if (src_device.is_cuda() && dst_device.is_cpu()) { memcpy_type = cudaMemcpyDeviceToHost; } else if (src_device.is_cpu() && dst_device.is_cuda()) { memcpy_type = cudaMemcpyHostToDevice; } else { - assert(false); + TORCH_CHECK(false, "Invalid device combination"); } void *src_ptr = src.data_ptr(); @@ -29,6 
+31,7 @@ void swap_blocks( const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // NOTE(woosuk): This can be slow if the number of blocks is large. for (const auto& pair : block_mapping) { int64_t src_block_number = pair.first; int64_t dst_block_number = pair.second; @@ -122,7 +125,9 @@ void copy_blocks( dim3 grid(num_layers, num_pairs); dim3 block(std::min(1024, numel_per_block)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { cacheflow::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), @@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel( } } +} // namespace cacheflow + +void reshape_and_cache( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping) // [num_tokens] +{ + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + key.scalar_type(), + "reshape_and_cache_kernel", + [&] { + cacheflow::reshape_and_cache_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size, + x); + }); +} + +namespace cacheflow { + // Grid: (num_blocks, block_size). 
template __global__ void gather_cached_kv_kernel( @@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized( } // namespace cacheflow -void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping) // [num_tokens] -{ - int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(3); - int x = key_cache.size(4); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - key.scalar_type(), - "reshape_and_cache_kernel", - [&] { - cacheflow::reshape_and_cache_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - slot_mapping.data_ptr(), - key_stride, - value_stride, - num_heads, - head_size, - block_size, - x); - }); -} - - void gather_cached_kv( torch::Tensor& key, // [out] [num_tokens, num_heads, head_size] torch::Tensor& value, // [out] [num_tokens, num_heads, head_size] @@ -354,7 +364,9 @@ void gather_cached_kv( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, key.scalar_type(), "gather_cached_kv_kernel_optimized", [&] { diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 84372ed2dd60..ba430bad2ff1 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -46,7 +46,9 @@ void rms_norm( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, input.scalar_type(), "rms_norm_kernel", [&] { diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 527fe2cd97c8..637e233c9a9a 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -64,7 +64,9 @@ void rotary_embedding_neox( dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, query.scalar_type(), "rotary_embedding_neox", [&] { From 46936fa7b810666cf818b87fdb8747178d8db71f Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Sun, 30 Apr 2023 23:13:15 +0000 Subject: [PATCH 04/26] Add bfloat16 option --- cacheflow/master/server.py | 4 ++-- cacheflow/models/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 311251800d22..90fdfffc421a 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -211,8 +211,8 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') # KV cache arguments parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 
4, 8, 16, 32, 64, 128, 256], help='token block size') - # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. - parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type') + # NOTE(woosuk): FlashAttention does not support float32. + parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=0, help='random seed') parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') diff --git a/cacheflow/models/utils.py b/cacheflow/models/utils.py index 84e7fbce6ccd..3d5240cb0c77 100644 --- a/cacheflow/models/utils.py +++ b/cacheflow/models/utils.py @@ -7,6 +7,7 @@ 'float': torch.float, 'float16': torch.float16, 'float32': torch.float32, + 'bfloat16': torch.bfloat16, } @@ -21,4 +22,3 @@ def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: def get_dtype_size(dtype: Union[torch.dtype, str]) -> int: torch_dtype = get_torch_dtype(dtype) return torch.tensor([], dtype=torch_dtype).element_size() - From b4bef98cd0c89ce55c412a4d25700b2a4246bc27 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Mon, 1 May 2023 09:22:52 +0000 Subject: [PATCH 05/26] [WIP] Support bfloat16 in attention kernel --- csrc/attention_kernels.cu | 86 +++++++++++++++++++++++---------------- csrc/attention_utils.h | 37 +++++++++++++++++ csrc/cuda_primitives.h | 55 +++++++++++++------------ setup.py | 13 ++++++ 4 files changed, 129 insertions(+), 62 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 1bab3d75f325..fb88f99df6b2 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -175,6 +175,7 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. 
float accs[NUM_ROWS_PER_THREAD]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { @@ -340,21 +341,55 @@ void single_query_cached_kv_attention_launcher( LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); break; default: - assert(false); + TORCH_CHECK(false, "Unsupported head size: ", head_size); break; } } #define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ single_query_cached_kv_attention_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len); + out, \ + query, \ + key_cache, \ + value_cache, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len); + +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 1: \ + CALL_KERNEL_LAUNCHER(T, 1); \ + break; \ + case 2: \ + CALL_KERNEL_LAUNCHER(T, 2); \ + break; \ + case 4: \ + CALL_KERNEL_LAUNCHER(T, 4); \ + break; \ + case 8: \ + CALL_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_KERNEL_LAUNCHER(T, 32); \ + break; \ + case 64: \ + CALL_KERNEL_LAUNCHER(T, 64); \ + break; \ + case 128: \ + CALL_KERNEL_LAUNCHER(T, 128); \ + break; \ + case 256: \ + CALL_KERNEL_LAUNCHER(T, 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } void single_query_cached_kv_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] @@ -366,33 +401,14 @@ void single_query_cached_kv_attention( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len) { - // TODO(woosuk): Support BF16. - if (query.element_size() == 2) { - // Half. - if (block_size == 1) { - CALL_KERNEL_LAUNCHER(uint16_t, 1); - } else if (block_size == 2) { - CALL_KERNEL_LAUNCHER(uint16_t, 2); - } else if (block_size == 4) { - CALL_KERNEL_LAUNCHER(uint16_t, 4); - } else if (block_size == 8) { - CALL_KERNEL_LAUNCHER(uint16_t, 8); - } else if (block_size == 16) { - CALL_KERNEL_LAUNCHER(uint16_t, 16); - } else if (block_size == 32) { - CALL_KERNEL_LAUNCHER(uint16_t, 32); - } else if (block_size == 64) { - CALL_KERNEL_LAUNCHER(uint16_t, 64); - } else if (block_size == 128) { - CALL_KERNEL_LAUNCHER(uint16_t, 128); - } else if (block_size == 256) { - CALL_KERNEL_LAUNCHER(uint16_t, 256); - } else { - assert(false); - } + if (query.dtype() == at::ScalarType::Half) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); +#ifdef ENABLE_BF16 + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); +#endif // ENABLE_BF16 } else { - // Float. - assert(false); + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } } diff --git a/csrc/attention_utils.h b/csrc/attention_utils.h index 049555390715..1fc0d0a95bda 100644 --- a/csrc/attention_utils.h +++ b/csrc/attention_utils.h @@ -41,7 +41,26 @@ template<> struct Vec { using Type = uint4; }; +#ifdef ENABLE_BF16 +template<> +struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template<> +struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +// A vector type to store logits. 
template struct FloatVec {}; template<> @@ -72,6 +91,24 @@ template<> struct FloatVec { using Type = Float8_; }; +#ifdef ENABLE_BF16 +template<> +struct FloatVec<__nv_bfloat16> { + using Type = float; +}; +template<> +struct FloatVec<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; +#endif // ENABLE_BF16 template inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) diff --git a/csrc/cuda_primitives.h b/csrc/cuda_primitives.h index 10e730fd7bda..f49d42f8fdcf 100644 --- a/csrc/cuda_primitives.h +++ b/csrc/cuda_primitives.h @@ -1268,56 +1268,57 @@ inline __device__ float convert_to_float(uint4 u) //////////////////////////////////////////////////////////////////////////////////////////////////// -// inline __device__ float cast_to_float(float u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ float2 cast_to_float(float2 u) -// { -// return u; -// } +inline __device__ float cast_to_float(uint16_t u) +{ + return half_to_float(u); +} //////////////////////////////////////////////////////////////////////////////////////////////////// -// inline __device__ float4 cast_to_float(float4 u) -// { -// return u; -// } +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} //////////////////////////////////////////////////////////////////////////////////////////////////// -// inline __device__ Float4_ cast_to_float(Float4_ u) -// { -// return u; -// } +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} //////////////////////////////////////////////////////////////////////////////////////////////////// -// inline __device__ Float8_ cast_to_float(Float8_ u) -// { -// return u; -// } +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float cast_to_float(uint16_t u) +inline __device__ float cast_to_float(__nv_bfloat16 u) { return half_to_float(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float2 cast_to_float(uint32_t u) +inline __device__ float2 cast_to_float(__nv_bfloat162 u) { return half2_to_float2(u); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ Float4_ cast_to_float(uint2 u) +inline __device__ Float4_ cast_to_float(bf16_4_t u) { Float4_ tmp; tmp.x = half2_to_float2(u.x); @@ -1327,7 +1328,7 @@ inline __device__ Float4_ cast_to_float(uint2 u) //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ Float8_ cast_to_float(uint4 u) +inline __device__ Float8_ cast_to_float(bf16_8_t u) { Float8_ tmp; tmp.x = half2_to_float2(u.x); diff --git a/setup.py b/setup.py index 116dad140484..6585aebc8ae5 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,22 @@ import setuptools +import torch from torch.utils import cpp_extension +if not torch.cuda.is_available(): + raise RuntimeError( + 'Cannot find CUDA. 
' + 'CUDA must be available in order to build the package.') + CXX_FLAGS = ['-g'] NVCC_FLAGS = ['-O2'] +# Enable bfloat16 support if the compute capability is >= 8.0. +# TODO(woosuk): Consider the case where the machine has multiple GPUs with +# different compute capabilities. +compute_capability = torch.cuda.get_device_capability() +major, minor = compute_capability +if major >= 8: + NVCC_FLAGS.append('-DENABLE_BF16') ext_modules = [] From e1b13035812de1dd2ad719f3bbac37eb56be6bae Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 05:12:59 +0000 Subject: [PATCH 06/26] Use reduced precision for attention computation --- csrc/attention_kernels.cu | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index c25acbb8be6f..1cf7560c37b5 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -1,8 +1,8 @@ #include #include +#include "attention_dtypes.h" #include "attention_utils.h" -#include "cuda_primitives.h" #include "reduction_utils.h" #include @@ -71,8 +71,8 @@ __global__ void single_query_cached_kv_attention_kernel( // Memory planning. extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 logits and accumulation. - float *logits = reinterpret_cast(shared_mem); + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); // Workspace for reduction. __shared__ float red_smem[2 * NUM_WARPS]; @@ -145,7 +145,7 @@ __global__ void single_query_cached_kv_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } // Broadcast the max qk value to all threads. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); @@ -169,12 +169,13 @@ __global__ void single_query_cached_kv_attention_kernel( // Each thread will fetch 16 bytes from the value cache at a time. constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; - using L_vec = typename FloatVec::Type; + using L_vec = typename Vec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. 
float accs[NUM_ROWS_PER_THREAD]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { @@ -185,7 +186,8 @@ __global__ void single_query_cached_kv_attention_kernel( const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec = *reinterpret_cast(logits + token_idx); + L_vec logits_vec; + convert_from_float(logits_vec, *reinterpret_cast(logits + token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + head_idx * HEAD_SIZE * BLOCK_SIZE; @@ -195,7 +197,7 @@ __global__ void single_query_cached_kv_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec = *reinterpret_cast(v_ptr + offset); - accs[i] += dot(logits_vec, cast_to_float(v_vec)); + accs[i] += dot(logits_vec, v_vec); } } } @@ -307,7 +309,7 @@ void single_query_cached_kv_attention_launcher( constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); + int logits_size = padded_max_context_len * sizeof(T); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int shared_mem_size = std::max(logits_size, outputs_size); From b576e465d21f1d1b1b4869668df57a48c45eca97 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 05:19:12 +0000 Subject: [PATCH 07/26] Code cleaning --- csrc/attention_kernels.cu | 577 ++++---------------------------------- 1 file changed, 48 insertions(+), 529 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 1cf7560c37b5..8fe5b97d69f7 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -342,21 +342,55 @@ void single_query_cached_kv_attention_launcher( LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); break; default: - assert(false); + TORCH_CHECK(false, "Unsupported head size: ", head_size); break; } } #define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ single_query_cached_kv_attention_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len); + out, \ + query, \ + key_cache, \ + value_cache, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len); + +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 1: \ + CALL_KERNEL_LAUNCHER(T, 1); \ + break; \ + case 2: \ + CALL_KERNEL_LAUNCHER(T, 2); \ + break; \ + case 4: \ + CALL_KERNEL_LAUNCHER(T, 4); \ + break; \ + case 8: \ + CALL_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_KERNEL_LAUNCHER(T, 32); \ + break; \ + case 64: \ + CALL_KERNEL_LAUNCHER(T, 64); \ + break; \ + case 128: \ + CALL_KERNEL_LAUNCHER(T, 128); \ + break; \ + case 256: \ + CALL_KERNEL_LAUNCHER(T, 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } void single_query_cached_kv_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] @@ -369,530 +403,15 @@ void single_query_cached_kv_attention( int block_size, int max_context_len) { // TODO(woosuk): Support BF16. - if (query.element_size() == 2) { - // Half. 
- if (block_size == 1) { - CALL_KERNEL_LAUNCHER(uint16_t, 1); - } else if (block_size == 2) { - CALL_KERNEL_LAUNCHER(uint16_t, 2); - } else if (block_size == 4) { - CALL_KERNEL_LAUNCHER(uint16_t, 4); - } else if (block_size == 8) { - CALL_KERNEL_LAUNCHER(uint16_t, 8); - } else if (block_size == 16) { - CALL_KERNEL_LAUNCHER(uint16_t, 16); - } else if (block_size == 32) { - CALL_KERNEL_LAUNCHER(uint16_t, 32); - } else if (block_size == 64) { - CALL_KERNEL_LAUNCHER(uint16_t, 64); - } else if (block_size == 128) { - CALL_KERNEL_LAUNCHER(uint16_t, 128); - } else if (block_size == 256) { - CALL_KERNEL_LAUNCHER(uint16_t, 256); - } else { - assert(false); - } + if (query.dtype() == at::ScalarType::Half) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::Float) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); } else { - // Float. - assert(false); + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } } -// namespace cacheflow { - -// // Grid: (num_heads, num_query_tokens). -// template< -// typename scalar_t, -// int HEAD_SIZE, -// int BLOCK_SIZE, -// int NUM_THREADS> -// __device__ void multi_query_cached_kv_attention_kernel_unoptimized_( -// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -// const int seq_start_idx, -// const int seq_len, -// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] -// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] -// const float scale, -// const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] -// const int context_len, -// const int max_num_blocks_per_seq, -// const int q_stride) { -// constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; -// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -// const int thread_idx = threadIdx.x; -// const int warp_idx = thread_idx / WARP_SIZE; -// const int lane = thread_idx % WARP_SIZE; - -// const int head_idx = blockIdx.x; -// const int num_heads = gridDim.x; -// const int seq_idx = blockIdx.y; - -// // A vector type to store a part of a key or a query. -// // The vector size is configured in such a way that the threads in a thread group -// // fetch or comput 16 bytes at a time. -// // For example, if the size of a thread group is 4 and the data type is half, -// // then the vector size is 16 / (4 * sizeof(half)) == 2. -// constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); -// using K_vec = typename Vec::Type; -// using Q_vec = typename Vec::Type; - -// constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; -// constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - -// const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; -// const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - -// // Load the query to registers. -// // Each thread in a thread group has a different part of the query. -// // For example, if the the thread group size is 4, then the first thread in the group -// // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... -// // th vectors of the query, and so on. -// // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. 
-// const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; -// Q_vec q_vecs[NUM_VECS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { -// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -// q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); -// } - -// // Memory planning. -// extern __shared__ char shared_mem[]; -// // NOTE(woosuk): We use FP32 logits and accumulation. -// float *logits = reinterpret_cast(shared_mem); -// // Workspace for reduction. -// __shared__ float red_smem[2 * NUM_WARPS]; - -// // x == THREAD_GROUP_SIZE * VEC_SIZE -// // Each thread group fetches x elements from the key at a time. -// constexpr int x = 16 / sizeof(scalar_t); -// float qk_max = -FLT_MAX; - -// const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; -// const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx); - -// // Iterate over the key blocks. -// // Each warp fetches a block of keys for each iteration. -// // Each thread group in a warp fetches a key from the block, and computes -// // dot product with the query. -// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { -// const int physical_block_number = block_table[block_idx]; -// const int physical_block_offset = thread_group_idx % BLOCK_SIZE; -// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - -// // Load a key to registers. -// // Each thread in a thread group has a different part of the key. -// // For example, if the the thread group size is 4, then the first thread in the group -// // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th -// // vectors of the key, and so on. -// K_vec k_vecs[NUM_VECS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { -// const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE -// + head_idx * HEAD_SIZE * BLOCK_SIZE -// + physical_block_offset * x; -// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -// const int offset1 = (vec_idx * VEC_SIZE) / x; -// const int offset2 = (vec_idx * VEC_SIZE) % x; -// k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); -// } - -// // Compute dot product. -// // This includes a reduction across the threads in the same thread group. -// const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); -// const bool mask = token_idx >= mask_boundary; - -// if (thread_group_offset == 0) { -// // Store the partial reductions to shared memory. -// // NOTE(woosuk): It is required to zero out the masked logits. -// logits[token_idx] = mask ? 0.f : qk; -// // Update the max value. -// qk_max = mask ? qk_max : fmaxf(qk_max, qk); -// } -// } - -// // Perform reduction across the threads in the same warp to get the -// // max qk value for each "warp" (not across the thread block yet). -// // The 0-th thread of each thread group already has its max qk value. -// #pragma unroll -// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -// } -// if (lane == 0) { -// red_smem[warp_idx] = qk_max; -// } -// __syncthreads(); - -// // TODO(woosuk): Refactor this part. -// // Get the max qk value for the sequence. -// qk_max = lane < NUM_WARPS ? 
red_smem[lane] : -FLT_MAX; -// #pragma unroll -// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -// } -// // Broadcast the max qk value to all threads. -// qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - -// // Get the sum of the exp values. -// float exp_sum = 0.f; -// for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) { -// float val = __expf(logits[i] - qk_max); -// logits[i] = val; -// exp_sum += val; -// } -// exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - -// // Compute softmax. -// const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); -// for (int i = thread_idx; i < context_len; i += NUM_THREADS) { -// logits[i] *= inv_sum; -// } -// __syncthreads(); - -// // Each thread will fetch 16 bytes from the value cache at a time. -// constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); -// using V_vec = typename Vec::Type; -// using L_vec = typename FloatVec::Type; - -// constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; -// constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; -// constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; - -// float accs[NUM_ROWS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// accs[i] = 0.f; -// } - -// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { -// const int physical_block_number = block_table[block_idx]; -// const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; -// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; -// L_vec logits_vec = *reinterpret_cast(logits + token_idx); - -// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE -// + head_idx * HEAD_SIZE * BLOCK_SIZE; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE) { -// const int offset = row_idx * BLOCK_SIZE + physical_block_offset; -// V_vec v_vec = *reinterpret_cast(v_ptr + offset); -// accs[i] += dot(logits_vec, cast_to_float(v_vec)); -// } -// } -// } - -// // Perform reduction within each warp. -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// float acc = accs[i]; -// #pragma unroll -// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -// acc += __shfl_xor_sync(uint32_t(-1), acc, mask); -// } -// accs[i] = acc; -// } - -// // NOTE(woosuk): A barrier is required because the shared memory space for logits -// // is reused for the output. -// __syncthreads(); - -// // Perform reduction across warps. -// float* out_smem = reinterpret_cast(shared_mem); -// #pragma unroll -// for (int i = NUM_WARPS; i > 1; i /= 2) { -// int mid = i / 2; -// // Upper warps write to shared memory. -// if (warp_idx >= mid && warp_idx < i) { -// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// dst[row_idx] = accs[i]; -// } -// } -// } -// __syncthreads(); - -// // Lower warps update the output. 
-// if (warp_idx < mid) { -// const float* src = &out_smem[warp_idx * HEAD_SIZE]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// accs[i] += src[row_idx]; -// } -// } -// } -// __syncthreads(); -// } - -// // Write the final output. -// if (warp_idx == 0) { -// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// convert_from_float(*(out_ptr + row_idx), accs[i]); -// } -// } -// } -// } - - -// // Grid: (num_heads, num_query_tokens). -// template< -// typename scalar_t, -// int HEAD_SIZE, -// int BLOCK_SIZE, -// int NUM_THREADS> -// __global__ void multi_query_cached_kv_attention_kernel( -// const int* cu_query_lens, // [num_prompts+1] -// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx -// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] -// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] -// const float scale, -// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq] -// const int* __restrict__ context_lens, // [num_prompts] -// const int max_num_blocks_per_seq, -// const int q_stride) { -// const int seq_idx = blockIdx.y; -// const int prompt_idx = seq_prompt_mapping[seq_idx]; -// const int seq_start_idx = cu_query_lens[prompt_idx]; -// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx; -// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; -// const int context_len = context_lens[prompt_idx]; -// multi_query_cached_kv_attention_kernel_unoptimized_< -// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( -// out, -// q, -// seq_start_idx, -// seq_len, -// k_cache, -// v_cache, -// scale, -// block_table, -// context_len, -// max_num_blocks_per_seq, -// q_stride); -// } - -// } // namespace cacheflow - -// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ -// cacheflow::multi_query_cached_kv_attention_kernel \ -// <<>>( \ -// cu_query_lens_ptr, \ -// seq_prompt_mapping_ptr, \ -// out_ptr, \ -// query_ptr, \ -// key_cache_ptr, \ -// value_cache_ptr, \ -// scale, \ -// block_tables_ptr, \ -// context_lens_ptr, \ -// max_num_blocks_per_seq, \ -// query_stride); - - -// // TODO(woosuk): Tune NUM_THREADS. 
-// template< -// typename T, -// int BLOCK_SIZE, -// int NUM_THREADS = 128> -// void multi_query_cached_kv_attention_launcher( -// torch::Tensor& cu_query_lens, -// torch::Tensor& seq_prompt_mapping, -// torch::Tensor& out, -// torch::Tensor& query, -// torch::Tensor& key_cache, -// torch::Tensor& value_cache, -// float scale, -// torch::Tensor& block_tables, -// torch::Tensor& context_lens, -// int max_context_len) { -// int num_seqs = query.size(0); -// int num_heads = query.size(1); -// int head_size = query.size(2); -// int max_num_blocks_per_seq = block_tables.size(1); -// int query_stride = query.stride(0); - -// int* cu_query_lens_ptr = cu_query_lens.data_ptr(); -// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr(); -// T* out_ptr = reinterpret_cast(out.data_ptr()); -// T* query_ptr = reinterpret_cast(query.data_ptr()); -// T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); -// T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); -// int* block_tables_ptr = block_tables.data_ptr(); -// int* context_lens_ptr = context_lens.data_ptr(); - -// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; -// int logits_size = padded_max_context_len * sizeof(float); -// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); -// int shared_mem_size = std::max(logits_size, outputs_size); - -// dim3 grid(num_heads, num_seqs); -// dim3 block(NUM_THREADS); -// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); -// switch (head_size) { -// case 32: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); -// break; -// case 64: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); -// break; -// case 80: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); -// break; -// case 96: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); -// break; -// case 128: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); -// break; -// case 160: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); -// break; -// case 192: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); -// break; -// case 256: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); -// break; -// default: -// assert(false); -// break; -// } -// } - -// void multi_query_cached_kv_attention( -// torch::Tensor& cu_query_lens, -// torch::Tensor& out, -// torch::Tensor& query, -// torch::Tensor& key_cache, -// torch::Tensor& value_cache, -// float scale, -// torch::Tensor& block_tables, -// torch::Tensor& context_lens, -// int block_size, -// int max_context_len) { - -// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU); - -// int num_queries = query_lens.size(0) - 1; -// const int* query_lens_ptr = query_lens.data_ptr(); -// int num_seqs = query.size(0); - -// torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); -// auto accessor = cpu_tensor.accessor(); -// for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { -// if (i >= query_lens_ptr[query_cursor + 1]) { -// ++query_cursor; -// } -// accessor[i] = query_cursor; -// } - -// // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA) -// // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving -// // the mapping as an input parameter. Let's do this optimization in a later PR. 
-// torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); - -// // TODO(woosuk): Support BF16. -// if (query.element_size() == 2) { -// // Half. -// if (block_size == 8) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 16) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 32) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else { -// assert(false); -// } -// } else if (query.element_size() == 4) { -// // Float. -// if (block_size == 8) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 16) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 32) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else { -// assert(false); -// } -// } else { -// assert(false); -// } -// } - #undef WARP_SIZE #undef MAX #undef MIN From 8bdb09c38bc0ac30d8e714421c73e67033e9d7e0 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 05:22:24 +0000 Subject: [PATCH 08/26] Create attention dir --- csrc/{ => attention}/attention_kernels.cu | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename csrc/{ => attention}/attention_kernels.cu (100%) diff --git a/csrc/attention_kernels.cu b/csrc/attention/attention_kernels.cu similarity index 100% rename from csrc/attention_kernels.cu rename to csrc/attention/attention_kernels.cu From 4b1013640f250869bb8dbd87948bad0388b16be6 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 05:22:46 +0000 Subject: [PATCH 09/26] Refactor attention_utils --- csrc/attention_utils.h | 81 +++++++----------------------------------- 1 file changed, 12 insertions(+), 69 deletions(-) diff --git a/csrc/attention_utils.h b/csrc/attention_utils.h index 049555390715..71f936125162 100644 --- a/csrc/attention_utils.h +++ b/csrc/attention_utils.h @@ -1,11 +1,10 @@ #pragma once -#include "cuda_primitives.h" +#include "attention_dtypes.h" #include #include -#define MMHA_USE_FP32_ACUM_FOR_FMA #define MMHA_USE_FP32_ACUM_FOR_OUT namespace cacheflow { @@ -42,6 +41,7 @@ struct Vec { using Type = uint4; }; +// A vector type to store accumulators. template struct FloatVec {}; template<> @@ -73,12 +73,13 @@ struct FloatVec { using Type = Float8_; }; -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { - using K_vec_acum = typename FloatVec::Type; + using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). 
- K_vec_acum qk_vec = mul(q[0], k[0]); + A_vec qk_vec = mul(q[0], k[0]); #pragma unroll for (int ii = 1; ii < N; ++ii) { qk_vec = fma(q[ii], k[ii], qk_vec); @@ -87,79 +88,21 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) // Finalize the reduction across lanes. float qk = sum(qk_vec); #pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { qk += __shfl_xor_sync(uint32_t(-1), qk, mask); } return qk; } -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template +template struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using K_vec_acum = typename FloatVec::Type; - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + return qk_dot_(q, k); } }; } // namespace cacheflow -#undef MMHA_USE_FP32_ACUM_FOR_FMA #undef MMHA_USE_FP32_ACUM_FOR_OUT From c25cea4e6b82a4a17f40b1a29db3805f728059dd Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 05:23:58 +0000 Subject: [PATCH 10/26] Move --- csrc/{ => attention}/attention_utils.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename csrc/{ => attention}/attention_utils.h (100%) diff --git a/csrc/attention_utils.h b/csrc/attention/attention_utils.h similarity index 100% rename from csrc/attention_utils.h rename to csrc/attention/attention_utils.h From 8ad2170b34932343b0349c649329961842cafdf1 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 05:26:36 +0000 Subject: [PATCH 11/26] Move blocksum to attention --- csrc/attention/attention_kernels.cu | 39 ++++++++++++++++++++++++-- csrc/reduction_utils.h | 43 +++-------------------------- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8fe5b97d69f7..f59aa9706e46 100644 --- a/csrc/attention/attention_kernels.cu 
+++ b/csrc/attention/attention_kernels.cu @@ -2,8 +2,7 @@ #include #include "attention_dtypes.h" -#include "attention_utils.h" -#include "reduction_utils.h" +#include "attention/attention_utils.h" #include @@ -13,6 +12,42 @@ namespace cacheflow { +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + // Grid: (num_heads, num_seqs). template< typename scalar_t, diff --git a/csrc/reduction_utils.h b/csrc/reduction_utils.h index f977ab70f1fe..5eb3e33a8c4e 100644 --- a/csrc/reduction_utils.h +++ b/csrc/reduction_utils.h @@ -1,46 +1,9 @@ #pragma once -namespace cacheflow { - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - - // Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. 
- return __shfl_sync(uint32_t(-1), sum, 0); -} - #define FINAL_MASK 0xffffffff +namespace cacheflow { + template __inline__ __device__ T warpReduceSum(T val) { @@ -74,3 +37,5 @@ __inline__ __device__ T blockReduceSum(T val) } } // namespace cacheflow + +#undef FINAL_MASK From 3a2337b0f5a4c84697ed76173162fe70dcadda3b Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:22:17 +0000 Subject: [PATCH 12/26] Refactor attention kernel --- csrc/attention/attention_dtypes.cuh | 5 + csrc/attention/attention_generic.cuh | 39 ++ csrc/attention/attention_kernels.cu | 21 +- ...{attention_utils.h => attention_utils.cuh} | 6 +- csrc/attention/dtype_float16.cuh | 390 ++++++++++++++++++ csrc/attention/dtype_float32.cuh | 222 ++++++++++ setup.py | 2 +- 7 files changed, 668 insertions(+), 17 deletions(-) create mode 100644 csrc/attention/attention_dtypes.cuh create mode 100644 csrc/attention/attention_generic.cuh rename csrc/attention/{attention_utils.h => attention_utils.cuh} (95%) create mode 100644 csrc/attention/dtype_float16.cuh create mode 100644 csrc/attention/dtype_float32.cuh diff --git a/csrc/attention/attention_dtypes.cuh b/csrc/attention/attention_dtypes.cuh new file mode 100644 index 000000000000..1d586ddf7522 --- /dev/null +++ b/csrc/attention/attention_dtypes.cuh @@ -0,0 +1,5 @@ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float16.cuh" +#include "dtype_float32.cuh" diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh new file mode 100644 index 000000000000..b7948981d233 --- /dev/null +++ b/csrc/attention/attention_generic.cuh @@ -0,0 +1,39 @@ +#pragma once + +#include + +namespace cacheflow { + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace cacheflow diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index f59aa9706e46..a4bd6aeb6867 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,8 +1,8 @@ #include #include -#include "attention_dtypes.h" -#include "attention/attention_utils.h" +#include "attention_dtypes.cuh" +#include "attention_utils.cuh" #include @@ -13,7 +13,7 @@ namespace cacheflow { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -34,13 +34,13 @@ inline __device__ float block_sum(float* red_smem, float sum) { __syncthreads(); // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { + if (lane < NUM_WARPS) { sum = red_smem[lane]; } // Parallel reduction inside the warp. 
#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { sum += __shfl_xor_sync(uint32_t(-1), sum, mask); } @@ -205,6 +205,7 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; @@ -222,7 +223,7 @@ __global__ void single_query_cached_kv_attention_kernel( const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - convert_from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + head_idx * HEAD_SIZE * BLOCK_SIZE; @@ -291,7 +292,7 @@ __global__ void single_query_cached_kv_attention_kernel( for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - convert_from_float(*(out_ptr + row_idx), accs[i]); + from_float(*(out_ptr + row_idx), accs[i]); } } } @@ -383,7 +384,7 @@ void single_query_cached_kv_attention_launcher( } #define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - single_query_cached_kv_attention_launcher( \ + single_query_cached_kv_attention_launcher( \ out, \ query, \ key_cache, \ @@ -437,11 +438,9 @@ void single_query_cached_kv_attention( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len) { - // TODO(woosuk): Support BF16. + // TODO(woosuk): Support FP32 and BF16. if (query.dtype() == at::ScalarType::Half) { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::Float) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } diff --git a/csrc/attention/attention_utils.h b/csrc/attention/attention_utils.cuh similarity index 95% rename from csrc/attention/attention_utils.h rename to csrc/attention/attention_utils.cuh index 71f936125162..5664da2d23fd 100644 --- a/csrc/attention/attention_utils.h +++ b/csrc/attention/attention_utils.cuh @@ -1,12 +1,10 @@ #pragma once -#include "attention_dtypes.h" +#include "attention_dtypes.cuh" #include #include -#define MMHA_USE_FP32_ACUM_FOR_OUT - namespace cacheflow { // A vector type to store Q, K, V elements. @@ -104,5 +102,3 @@ struct Qk_dot { }; } // namespace cacheflow - -#undef MMHA_USE_FP32_ACUM_FOR_OUT diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh new file mode 100644 index 000000000000..d37544e648e0 --- /dev/null +++ b/csrc/attention/dtype_float16.cuh @@ -0,0 +1,390 @@ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#include + +namespace cacheflow { + +// Utility functions for type conversions. 
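The conversion helpers below move values between IEEE half precision, stored in raw integer types (uint16_t for one half, uint32_t for a packed half2), and FP32, using inline PTX. A minimal usage sketch, assuming a device function placed inside the cacheflow namespace; the function and variable names are illustrative only:

__device__ void example_roundtrip() {
  // Illustrative only; assumes namespace cacheflow so the helpers resolve unqualified.
  uint16_t h  = float_to_half(1.5f);                       // one fp16 value in 16 bits
  uint32_t h2 = float2_to_half2(make_float2(1.5f, -2.f));  // two fp16 values packed in 32 bits
  float    f  = half_to_float(h);                          // back to fp32: 1.5f
  float2   f2 = half2_to_float2(h2);                       // back to fp32: (1.5f, -2.f)
  uint32_t bc = h0_h0(h);                                  // broadcast h into both fp16 lanes
  (void)f; (void)f2; (void)bc;
}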
+inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. 
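The multiplication overloads below follow the primary template declared in attention_generic.cuh, which appears to take the accumulator type as an explicit template argument, i.e. mul<Acc, A, B>(a, b); the explicit arguments are reconstructed below as an assumption. The practical effect is that the caller chooses whether a product stays in packed fp16 or is widened to fp32:

__device__ void example_mul(uint32_t a, uint32_t b) {
  // Illustrative only; a and b each pack two fp16 values, template arguments assumed.
  uint32_t p16 = mul<uint32_t, uint32_t, uint32_t>(a, b);  // lane-wise half2 product, kept in fp16
  float2   p32 = mul<float2, uint32_t, uint32_t>(a, b);    // same product, widened to fp32
  (void)p16; (void)p32;
}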
+template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template<> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. 
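The fused multiply-add overloads below are what qk_dot_ in attention_utils.cuh builds on: partial products are formed from fp16 inputs while the running sum is carried in float2/Float4_/Float8_. A minimal sketch of that pattern, with assumed template arguments and illustrative names:

__device__ float example_qk_partial(uint32_t q0, uint32_t k0,
                                    uint32_t q1, uint32_t k1) {
  // Illustrative only; each argument packs two fp16 values.
  float2 acc = mul<float2, uint32_t, uint32_t>(q0, k0);  // first two products, widened to fp32
  acc = fma(q1, k1, acc);                                // next two products, fp32 accumulation
  return acc.x + acc.y;                                  // collapse both lanes to a scalar
}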
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template<> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template<> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template<> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template<> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// Zero-out a vector. +inline __device__ void zero(uint16_t& dst) { + dst = uint16_t(0); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. 
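The from_float overloads above are what attention_kernels.cu now calls, in place of the old convert_from_float, to write its FP32 logits and output accumulators back to fp16; the to_float overloads below go the other way. A minimal sketch with illustrative names:

__device__ void example_store(uint16_t* out_ptr, float acc) {
  // Illustrative only; out_ptr stands in for one fp16 output element.
  uint16_t h;
  from_float(h, acc);         // fp32 accumulator -> fp16
  *out_ptr = h;
  float check = to_float(h);  // widen back when an fp32 copy is needed
  (void)check;
}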
+inline __device__ float to_float(uint16_t u) { + return half_to_float(u); +} + +inline __device__ float2 to_float(uint32_t u) { + return half2_to_float2(u); +} + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +} // namespace cacheflow diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh new file mode 100644 index 000000000000..e51888cc1397 --- /dev/null +++ b/csrc/attention/dtype_float32.cuh @@ -0,0 +1,222 @@ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace cacheflow { + +// Float vector types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { + return a + b; +} + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template<> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template<> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template<> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template<> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template<> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. 
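Float4_ and Float8_ above are plain aggregates of float2, serving as the FP32 accumulators that correspond to the 4- and 8-wide packed fp16 vectors (uint2 and uint4). Reducing such an accumulator to a scalar is a plain lane sum, which the sum and dot overloads below implement. An illustrative per-thread dot product over eight fp16 elements, with assumed template arguments:

__device__ float example_dot8(uint4 q8, uint4 k8) {
  // Illustrative only; each operand packs eight fp16 values, e.g. a 16-byte key vector.
  Float8_ prod = mul<Float8_, uint4, uint4>(q8, k8);  // eight fp32 partial products
  return sum(prod);                                   // reduce the eight lanes to a scalar
}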
+template<> +inline __device__ float sum(float v) { + return v; +} + +template<> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template<> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template<> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template<> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { + return a * b; +} + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { + dst = src; +} + +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} + +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} + +// From float to float. +inline __device__ float to_float(float u) { + return u; +} + +inline __device__ float2 to_float(float2 u) { + return u; +} + +inline __device__ float4 to_float(float4 u) { + return u; +} + +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} + +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} + +} // namespace cacheflow diff --git a/setup.py b/setup.py index e96c73033379..bac0b0f18c74 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ # Attention kernels. attention_extension = cpp_extension.CUDAExtension( name='cacheflow.attention_ops', - sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'], + sources=['csrc/attention.cpp', 'csrc/attention/attention_kernels.cu'], extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS}, ) ext_modules.append(attention_extension) From f299f32ea7e9c9fe2655164f59008e9f663058d2 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:30:45 +0000 Subject: [PATCH 13/26] Move vector data types --- csrc/attention/attention_generic.cuh | 8 ++++ csrc/attention/attention_utils.cuh | 64 ---------------------------- csrc/attention/dtype_float16.cuh | 36 ++++++++++++++++ csrc/attention/dtype_float32.cuh | 30 ++++++++++++- 4 files changed, 73 insertions(+), 65 deletions(-) diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh index b7948981d233..799f873f462a 100644 --- a/csrc/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -4,6 +4,14 @@ namespace cacheflow { +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + // Template vector operations. template inline __device__ Acc mul(A a, B b); diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index 5664da2d23fd..b55a45fb4716 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -7,70 +7,6 @@ namespace cacheflow { -// A vector type to store Q, K, V elements. 
-template -struct Vec {}; -template<> -struct Vec { - using Type = float; -}; -template<> -struct Vec { - using Type = float2; -}; -template<> -struct Vec { - using Type = float4; -}; -template<> -struct Vec { - using Type = uint16_t; -}; -template<> -struct Vec { - using Type = uint32_t; -}; -template<> -struct Vec { - using Type = uint2; -}; -template<> -struct Vec { - using Type = uint4; -}; - -// A vector type to store accumulators. -template -struct FloatVec {}; -template<> -struct FloatVec { - using Type = float; -}; -template<> -struct FloatVec { - using Type = float2; -}; -template<> -struct FloatVec { - using Type = float4; -}; -template<> -struct FloatVec { - using Type = float; -}; -template<> -struct FloatVec { - using Type = float2; -}; -template<> -struct FloatVec { - using Type = Float4_; -}; -template<> -struct FloatVec { - using Type = Float8_; -}; - // Q*K^T operation. template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index d37544e648e0..92a7172c1a75 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -7,6 +7,42 @@ namespace cacheflow { +// FP16 vector types for Q, K, V. +template<> +struct Vec { + using Type = uint16_t; +}; +template<> +struct Vec { + using Type = uint32_t; +}; +template<> +struct Vec { + using Type = uint2; +}; +template<> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; + // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { uint32_t b; diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index e51888cc1397..7c3e85e18838 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -6,7 +6,7 @@ namespace cacheflow { -// Float vector types. +// Define FP32 vector data types. struct Float4_ { float2 x; float2 y; @@ -19,6 +19,34 @@ struct Float8_ { float2 w; }; +// FP32 vector types for Q, K, V. +template<> +struct Vec { + using Type = float; +}; +template<> +struct Vec { + using Type = float2; +}; +template<> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = float4; +}; + // Vector addition. 
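After this move each dtype header carries its own Vec and FloatVec table, and attention_utils.cuh keeps only the Q*K^T helper. The kernels pick their register types from these tables; a sketch of the lookup, assuming the templates are keyed as Vec<T, VEC_SIZE> and FloatVec<T> (the explicit template arguments are reconstructed here as an assumption):

template<typename scalar_t>                          // e.g. scalar_t == uint16_t for fp16
__device__ void example_vec_lookup() {
  // Illustrative only; assumes a thread-group vector width of two elements.
  using K_vec = typename Vec<scalar_t, 2>::Type;     // uint32_t when scalar_t is uint16_t
  using A_vec = typename FloatVec<K_vec>::Type;      // float2: the matching fp32 accumulator
  (void)sizeof(K_vec); (void)sizeof(A_vec);
}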
inline __device__ float add(float a, float b) { return a + b; From e500681236edc45bb2a73a72065b3ea3cbd90b22 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:31:11 +0000 Subject: [PATCH 14/26] Remove cuda_primitives --- csrc/cuda_primitives.h | 1340 ---------------------------------------- 1 file changed, 1340 deletions(-) delete mode 100644 csrc/cuda_primitives.h diff --git a/csrc/cuda_primitives.h b/csrc/cuda_primitives.h deleted file mode 100644 index 10e730fd7bda..000000000000 --- a/csrc/cuda_primitives.h +++ /dev/null @@ -1,1340 +0,0 @@ -#pragma once - -#include - -namespace cacheflow { -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float4_ { - float2 x; - float2 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -struct bf16_4_t { - __nv_bfloat162 x; - __nv_bfloat162 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct bf16_8_t { - __nv_bfloat162 x; - __nv_bfloat162 y; - __nv_bfloat162 z; - __nv_bfloat162 w; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, float b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(float2 a, float2 b) -{ - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 add(float4 a, float4 b) -{ - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t add(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t add(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 add(uint2 a, uint2 b) -{ - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 add(uint4 a, uint4 b) -{ - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t float_to_half(float f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? - float zero = 0.f; - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); -#endif - return tmp.u16[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t float2_to_half2(float2 f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float half_to_float(uint16_t h) -{ - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 half2_to_float2(uint32_t v) -{ - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(uint32_t a, float2 fb) -{ - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(uint2 a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(uint4 a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t h0_h0(uint16_t a) -{ - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(float a, float b, float c) -{ - return a * b + c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float2 a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float4 a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) -{ - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) -{ - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float2 add(__nv_bfloat162 a, float2 fb) -{ - float2 fa = bf1622float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) -{ - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) -{ - return fma(h0_h0(a), b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) -{ - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) -{ - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) -{ - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) -{ - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(uint16_t a, uint16_t b, float fc) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) -{ - return fma(h0_h0(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) -{ - uint32_t s = h0_h0(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) -{ - uint32_t s = h0_h0(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(bf162bf162(a), b, c); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) -{ - bf16_4_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) -{ - bf16_8_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, 
b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) -{ - return __bfloat162float(a) * __bfloat162float(b) + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) -{ - return fma(bf162bf162(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ Acc mul(A a, B b); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float2 a, float2 b) -{ - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float a, float2 b) -{ - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float4 a, float4 b) -{ - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float a, float4 b) -{ - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint2 a, uint2 b) -{ - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - uint2 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint4 a, uint4 b) -{ - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - uint4 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - c.z = mul(s, b.z); - c.w = mul(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, uint16_t b) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint32_t a, uint32_t b) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint2 a, uint2 b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint4 a, uint4 b) -{ - Float8_ fc; - 
fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -template<> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __hmul(a, b); -#else - return bf16hmul(a, b); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ - float fa = (float)a; - float fb = (float)b; - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return mul(fa, fb); -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float v) -{ - return v; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float2 v) -{ - return v.x + v.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float4 v) -{ - return v.x + v.y + v.z + v.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float sum(__nv_bfloat162 v) -{ - float2 vf = bf1622float2(v); - return vf.x + vf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_4_t v) -{ - return sum(v.x) + sum(v.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_8_t v) -{ - return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint16_t v) -{ - return half_to_float(v); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint32_t v) -{ - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint2 v) -{ - uint32_t c = add(v.x, v.y); - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint4 v) -{ -#if 1 - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); -#else - uint32_t c = add(v.x, v.y); - uint32_t d = add(v.z, v.w); - c = add(c, d); -#endif - return sum(c); -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float4_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float8_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(float2 a, float2 b) -{ - float2 c = mul(a, b); - return c.x + c.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(Float4_ a, Float4_ b) -{ - float2 acc = mul(a.x, b.x); - acc = fma(a.y, b.y, acc); - return acc.x + acc.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(Float8_ a, Float8_ b) -{ - float2 acc = mul(a.x, b.x); - acc = fma(a.y, b.y, acc); - acc = fma(a.z, b.z, acc); - acc = fma(a.w, b.w, acc); - return acc.x + acc.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t& dst) -{ - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void zero(T& dst) -{ - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ float cast_to_float(float u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ float2 cast_to_float(float2 u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ float4 cast_to_float(float4 u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ Float4_ cast_to_float(Float4_ u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ Float8_ cast_to_float(Float8_ u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(uint16_t u) -{ - return half_to_float(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -} From cd82a109eeb648caa476c67586a8f8edd8b264fd Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:31:53 +0000 Subject: [PATCH 15/26] Change .h to .cuh --- csrc/layernorm_kernels.cu | 2 +- csrc/{reduction_utils.h => reduction_utils.cuh} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename csrc/{reduction_utils.h => reduction_utils.cuh} (100%) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 84372ed2dd60..a9606b106721 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,7 +1,7 @@ #include #include -#include "reduction_utils.h" +#include "reduction_utils.cuh" namespace cacheflow { diff --git a/csrc/reduction_utils.h b/csrc/reduction_utils.cuh similarity index 100% rename from csrc/reduction_utils.h rename to csrc/reduction_utils.cuh From beb21bfad16cfbface403821ae9fddb669830a80 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:43:56 +0000 Subject: [PATCH 16/26] Minor fix --- csrc/attention/attention_utils.cuh | 36 ++++++++++++++---------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index b55a45fb4716..df529095d9c2 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -9,32 +9,30 @@ namespace cacheflow { // Q*K^T operation. template -inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) -{ - using A_vec = typename FloatVec::Type; - // Compute the parallel products for Q*K^T (treat vector lanes separately). - A_vec qk_vec = mul(q[0], k[0]); +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); #pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } - // Finalize the reduction across lanes. - float qk = sum(qk_vec); + // Finalize the reduction across lanes. 
+ float qk = sum(qk_vec); #pragma unroll - for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; } template struct Qk_dot { - template - static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) - { - return qk_dot_(q, k); - } + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } }; } // namespace cacheflow From b9355dbe40497aadbbd22fd063313cd859109efb Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:53:33 +0000 Subject: [PATCH 17/26] Raise a build error for old GPUs --- setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.py b/setup.py index 5ce42b93da62..513edfdd5e1e 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,10 @@ major, minor = compute_capability if major >= 8: NVCC_FLAGS.append('-DENABLE_BF16') +if major <= 6 or (major == 7 and minor < 5): + raise RuntimeError( + 'CacheFlow requires CUDA compute capability >= 7.5. ' + f'Found {major}.{minor}.') ext_modules = [] From 1818852ec7f6c9403f126534593422ec28c8f28d Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:54:49 +0000 Subject: [PATCH 18/26] Minor --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 513edfdd5e1e..f8123aad9092 100644 --- a/setup.py +++ b/setup.py @@ -10,17 +10,17 @@ CXX_FLAGS = ['-g'] NVCC_FLAGS = ['-O2'] -# Enable bfloat16 support if the compute capability is >= 8.0. -# TODO(woosuk): Consider the case where the machine has multiple GPUs with +# FIXME(woosuk): Consider the case where the machine has multiple GPUs with # different compute capabilities. compute_capability = torch.cuda.get_device_capability() major, minor = compute_capability -if major >= 8: - NVCC_FLAGS.append('-DENABLE_BF16') if major <= 6 or (major == 7 and minor < 5): raise RuntimeError( 'CacheFlow requires CUDA compute capability >= 7.5. ' f'Found {major}.{minor}.') +# Enable bfloat16 support if the compute capability is >= 8.0. +if major >= 8: + NVCC_FLAGS.append('-DENABLE_BF16') ext_modules = [] From e8a7855fe86abb7f26b186c3da641d310f5a9861 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 07:57:44 +0000 Subject: [PATCH 19/26] Minor --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index f8123aad9092..942bd01635e7 100644 --- a/setup.py +++ b/setup.py @@ -2,14 +2,14 @@ import torch from torch.utils import cpp_extension +CXX_FLAGS = ['-g'] +NVCC_FLAGS = ['-O2'] + if not torch.cuda.is_available(): raise RuntimeError( - 'Cannot find CUDA. ' + f'Cannot find CUDA at CUDA_HOME: {cpp_extension.CUDA_HOME}. ' 'CUDA must be available in order to build the package.') -CXX_FLAGS = ['-g'] -NVCC_FLAGS = ['-O2'] - # FIXME(woosuk): Consider the case where the machine has multiple GPUs with # different compute capabilities. 
compute_capability = torch.cuda.get_device_capability() From 1f91b8939fff3dcc8b294a1958f8fb761f46a2cf Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 08:29:40 +0000 Subject: [PATCH 20/26] Minor --- csrc/attention/dtype_float16.cuh | 16 ++++++++-------- csrc/attention/dtype_float32.cuh | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 92a7172c1a75..d2a60353e116 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -10,37 +10,37 @@ namespace cacheflow { // FP16 vector types for Q, K, V. template<> struct Vec { - using Type = uint16_t; + using Type = uint16_t; }; template<> struct Vec { - using Type = uint32_t; + using Type = uint32_t; }; template<> struct Vec { - using Type = uint2; + using Type = uint2; }; template<> struct Vec { - using Type = uint4; + using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. template<> struct FloatVec { - using Type = float; + using Type = float; }; template<> struct FloatVec { - using Type = float2; + using Type = float2; }; template<> struct FloatVec { - using Type = Float4_; + using Type = Float4_; }; template<> struct FloatVec { - using Type = Float8_; + using Type = Float8_; }; // Utility functions for type conversions. diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index 7c3e85e18838..fdb35bf4307d 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -6,7 +6,7 @@ namespace cacheflow { -// Define FP32 vector data types. +// Define custom FP32 vector data types. struct Float4_ { float2 x; float2 y; @@ -22,29 +22,29 @@ struct Float8_ { // FP32 vector types for Q, K, V. template<> struct Vec { - using Type = float; + using Type = float; }; template<> struct Vec { - using Type = float2; + using Type = float2; }; template<> struct Vec { - using Type = float4; + using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. template<> struct FloatVec { - using Type = float; + using Type = float; }; template<> struct FloatVec { - using Type = float2; + using Type = float2; }; template<> struct FloatVec { - using Type = float4; + using Type = float4; }; // Vector addition. 
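For context on the Vec<>/FloatVec<> trait structs touched by the patch above: Vec<scalar_t, N>::Type names the packed register type that holds N scalars (for example, uint2 for four FP16 values stored as 16-bit words), and FloatVec<V>::Type names the matching FP32 accumulator type (for example, Float4_). The sketch below is illustrative only and is not part of the patch series. It assumes the csrc/attention headers from this series are on the include path, and that the mul()/sum() overloads removed from the old monolithic header earlier in the series now live in the dtype_*.cuh files, which attention_utils.cuh already relies on for qk_dot_.

#include "attention/attention_utils.cuh"
#include <stdint.h>

namespace cacheflow {

// Load one vector of VEC_SIZE scalars from q and k, multiply elementwise with
// FP32 accumulation, and reduce to a single float. This is the per-vector core
// of qk_dot_ above, without the unrolled loop and the warp-level reduction.
template<typename scalar_t, int VEC_SIZE>
inline __device__ float vec_dot(const scalar_t* q_ptr, const scalar_t* k_ptr) {
  using V = typename Vec<scalar_t, VEC_SIZE>::Type;  // packed storage type
  using A = typename FloatVec<V>::Type;              // FP32 accumulator type
  // Assumes q_ptr and k_ptr are aligned to sizeof(V), as in the real kernel.
  V q = *reinterpret_cast<const V*>(q_ptr);
  V k = *reinterpret_cast<const V*>(k_ptr);
  A acc = mul<A, V, V>(q, k);  // elementwise products, widened to FP32
  return sum(acc);             // horizontal reduction to a scalar
}

// Example instantiation for the FP16 path: four halves per vector,
// so V = uint2 and A = Float4_ per the specializations above.
__global__ void vec_dot_demo(const uint16_t* q, const uint16_t* k, float* out) {
  *out = vec_dot<uint16_t, 4>(q, k);
}

}  // namespace cacheflow

Note that the FP16 path traffics in uint16_t/uint32_t rather than __half/__half2 because the mul/fma helpers operate on raw 16-bit words through PTX intrinsics (mul.f16, mul.f16x2), as seen in the header being removed earlier in this series.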
From 3a5e9f06d16481159b7b322491c195f2f850a70a Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 09:14:28 +0000 Subject: [PATCH 21/26] Change cu.h to .h --- csrc/attention/{attention_dtypes.cuh => attention_dtypes.h} | 4 ++++ 1 file changed, 4 insertions(+) rename csrc/attention/{attention_dtypes.cuh => attention_dtypes.h} (59%) diff --git a/csrc/attention/attention_dtypes.cuh b/csrc/attention/attention_dtypes.h similarity index 59% rename from csrc/attention/attention_dtypes.cuh rename to csrc/attention/attention_dtypes.h index 1d586ddf7522..b04ea9a1145e 100644 --- a/csrc/attention/attention_dtypes.cuh +++ b/csrc/attention/attention_dtypes.h @@ -3,3 +3,7 @@ #include "attention_generic.cuh" #include "dtype_float16.cuh" #include "dtype_float32.cuh" + +#ifdef ENABLE_BF16 +#include "dtype_bfloat16.cuh" +#endif // ENABLE_BF16 From 502a67821eb3a6e58aada93512a76d8c7fba34f8 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 09:14:53 +0000 Subject: [PATCH 22/26] Add Bfloat16 support to attention kernel --- csrc/attention/attention_kernels.cu | 8 +- csrc/attention/attention_utils.cuh | 2 +- csrc/attention/dtype_bfloat16.cuh | 362 ++++++++++++++++++++++++++++ 3 files changed, 369 insertions(+), 3 deletions(-) create mode 100644 csrc/attention/dtype_bfloat16.cuh diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index a4bd6aeb6867..83a2d42e6d46 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,7 +1,7 @@ #include #include -#include "attention_dtypes.cuh" +#include "attention_dtypes.h" #include "attention_utils.cuh" #include @@ -438,9 +438,13 @@ void single_query_cached_kv_attention( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len) { - // TODO(woosuk): Support FP32 and BF16. + // TODO(woosuk): Support FP32. if (query.dtype() == at::ScalarType::Half) { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); +#ifdef ENABLE_BF16 + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); +#endif } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index df529095d9c2..a4180b171e1d 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,6 +1,6 @@ #pragma once -#include "attention_dtypes.cuh" +#include "attention_dtypes.h" #include #include diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh new file mode 100644 index 000000000000..cd4750d7f27a --- /dev/null +++ b/csrc/attention/dtype_bfloat16.cuh @@ -0,0 +1,362 @@ +#pragma once + +#include "attention_generic.cuh" +// #include "bfloat16_utils.cuh" +#include "dtype_float32.cuh" + +#include +#include +#include + +namespace cacheflow { + +// Define custom BF16 vector data types. +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; + +// BF16 vector types for Q, K, V. +template<> +struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template<> +struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; + +// FP32 accumulator vector types corresponding to Vec. 
+template<> +struct FloatVec<__nv_bfloat16> { + using Type = float; +}; +template<> +struct FloatVec<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { + return __bfloat1622float2(val); +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { + return __bfloat162bfloat162(val); +} + +// Vector addition. +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { + return a + b; +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { + return __hadd2(a, b); +} + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template<> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { + return __hmul(a, b); +} + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + return __hmul2(a, b); +} + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +template<> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +template<> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +template<> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +template<> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +template<> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = __bfloat162float(a); + float fb = __bfloat162float(b); + return fa * fb; +} + +template<> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 
fb = bf1622float2(b); + return mul(fa, fb); +} + +template<> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); +} + +template<> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template<> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template<> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template<> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { + return __hfma2(a, b, c); +} + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { + return __hfma2(bf162bf162(a), b, c); +} + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { + return fma(bf162bf162(a), b, fc); +} + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. 
+template<> +inline __device__ float sum(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template<> +inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +template<> +inline __device__ float sum(bf16_4_t v) { + return sum(v.x) + sum(v.y); +} + +template<> +inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} + +// From float32 to bfloat16. +inline __device__ void from_float(__nv_bfloat16& dst, float src) { + dst = __float2bfloat16(src); +} + +inline __device__ void from_float(__nv_bfloat162& dst, float2 src) { + dst = __float22bfloat162_rn(src); +} + +inline __device__ void from_float(bf16_4_t& dst, Float4_ src) { + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +} + +inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +} + +} // namespace cacheflow From fa0d5d4d14609feaf715016723c4b717464f3f37 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 09:24:57 +0000 Subject: [PATCH 23/26] Minor --- csrc/attention/dtype_bfloat16.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index cd4750d7f27a..1e409296ee58 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,7 +1,6 @@ #pragma once #include "attention_generic.cuh" -// #include "bfloat16_utils.cuh" #include "dtype_float32.cuh" #include From ee82a9d441b206df618ee4b284b08273da9eb2b7 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Tue, 2 May 2023 09:27:54 +0000 Subject: [PATCH 24/26] Minor --- setup.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/setup.py b/setup.py index 942bd01635e7..52ff89f63b0e 100644 --- a/setup.py +++ b/setup.py @@ -14,10 +14,6 @@ # different compute capabilities. compute_capability = torch.cuda.get_device_capability() major, minor = compute_capability -if major <= 6 or (major == 7 and minor < 5): - raise RuntimeError( - 'CacheFlow requires CUDA compute capability >= 7.5. ' - f'Found {major}.{minor}.') # Enable bfloat16 support if the compute capability is >= 8.0. 
if major >= 8: NVCC_FLAGS.append('-DENABLE_BF16') From 6499bc89f071c4591c2adb2b7bf6e3d00d16e3b3 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 20:55:29 +0000 Subject: [PATCH 25/26] Fix bugs in merge --- csrc/attention/attention_dtypes.cuh | 4 ++++ csrc/attention/attention_dtypes.h | 9 --------- csrc/attention/attention_kernels.cu | 6 +++++- csrc/attention/dtype_float32.cuh | 2 +- 4 files changed, 10 insertions(+), 11 deletions(-) delete mode 100644 csrc/attention/attention_dtypes.h diff --git a/csrc/attention/attention_dtypes.cuh b/csrc/attention/attention_dtypes.cuh index 1d586ddf7522..b04ea9a1145e 100644 --- a/csrc/attention/attention_dtypes.cuh +++ b/csrc/attention/attention_dtypes.cuh @@ -3,3 +3,7 @@ #include "attention_generic.cuh" #include "dtype_float16.cuh" #include "dtype_float32.cuh" + +#ifdef ENABLE_BF16 +#include "dtype_bfloat16.cuh" +#endif // ENABLE_BF16 diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h deleted file mode 100644 index b04ea9a1145e..000000000000 --- a/csrc/attention/attention_dtypes.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "attention_generic.cuh" -#include "dtype_float16.cuh" -#include "dtype_float32.cuh" - -#ifdef ENABLE_BF16 -#include "dtype_bfloat16.cuh" -#endif // ENABLE_BF16 diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index a4bd6aeb6867..d141af54e23d 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -438,9 +438,13 @@ void single_query_cached_kv_attention( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len) { - // TODO(woosuk): Support FP32 and BF16. + // TODO(woosuk): Support FP32. if (query.dtype() == at::ScalarType::Half) { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); +#ifdef ENABLE_BF16 + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); +#endif } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index 517da64b3609..fdb35bf4307d 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -6,7 +6,7 @@ namespace cacheflow { -// Define FP32 vector data types. +// Define custom FP32 vector data types. 
struct Float4_ { float2 x; float2 y; From bd2f8ce66d3a011dfe14937227d7200928f521c7 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 21:08:19 +0000 Subject: [PATCH 26/26] Minor --- csrc/attention/{attention_dtypes.cuh => attention_dtypes.h} | 0 csrc/attention/attention_kernels.cu | 2 +- csrc/attention/attention_utils.cuh | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename csrc/attention/{attention_dtypes.cuh => attention_dtypes.h} (100%) diff --git a/csrc/attention/attention_dtypes.cuh b/csrc/attention/attention_dtypes.h similarity index 100% rename from csrc/attention/attention_dtypes.cuh rename to csrc/attention/attention_dtypes.h diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index d141af54e23d..83a2d42e6d46 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,7 +1,7 @@ #include #include -#include "attention_dtypes.cuh" +#include "attention_dtypes.h" #include "attention_utils.cuh" #include diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index df529095d9c2..a4180b171e1d 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,6 +1,6 @@ #pragma once -#include "attention_dtypes.cuh" +#include "attention_dtypes.h" #include #include
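Taken together, the last few patches leave the dtype plumbing as follows: setup.py adds -DENABLE_BF16 to NVCC_FLAGS when the detected compute capability is at least 8.0, attention_dtypes.h pulls in dtype_bfloat16.cuh only under that flag, and single_query_cached_kv_attention dispatches on query.dtype() to either the uint16_t (FP16) or __nv_bfloat16 instantiation via CALL_KERNEL_LAUNCHER_BLOCK_SIZE. The standalone snippet below is illustrative only and is not part of the series; it exercises three of the helpers added in dtype_bfloat16.cuh (from_float, mul, sum) and assumes it is built with nvcc -DENABLE_BF16 for sm_80 or newer, with csrc/ on the include path so that attention/attention_dtypes.h resolves.

#include "attention/attention_dtypes.h"
#include <cstdio>

using namespace cacheflow;

// Pack two FP32 Float8_ operands into bf16_8_t, multiply elementwise in bf16,
// and reduce the eight products to a single float.
__global__ void bf16_demo(float* result) {
  Float8_ a, b;
  a.x = make_float2(1.f, 2.f); a.y = make_float2(3.f, 4.f);
  a.z = make_float2(5.f, 6.f); a.w = make_float2(7.f, 8.f);
  b.x = make_float2(1.f, 1.f); b.y = make_float2(1.f, 1.f);
  b.z = make_float2(1.f, 1.f); b.w = make_float2(1.f, 1.f);

  bf16_8_t a16, b16;
  from_float(a16, a);  // FP32 -> packed bf16 (round to nearest)
  from_float(b16, b);
  bf16_8_t prod = mul<bf16_8_t, bf16_8_t, bf16_8_t>(a16, b16);
  *result = sum(prod);  // 1 + 2 + ... + 8 = 36
}

int main() {
  float* d_result = nullptr;
  cudaMalloc(&d_result, sizeof(float));
  bf16_demo<<<1, 1>>>(d_result);
  float h_result = 0.f;
  cudaMemcpy(&h_result, d_result, sizeof(float), cudaMemcpyDeviceToHost);
  printf("bf16 dot product = %f\n", h_result);  // expect 36.000000
  cudaFree(d_result);
  return 0;
}

The same trait-and-overload pattern is what lets the FP16 and BF16 builds share a single templated attention kernel body, with the scalar type selected at dispatch time in attention_kernels.cu.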