diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index 4a830c00b307..e2a9956faf98 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
+    # NOTE(woosuk): FlashAttention does not support float32.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
diff --git a/cacheflow/models/utils.py b/cacheflow/models/utils.py
index 4dda73ec5c32..ced58c54a0c8 100644
--- a/cacheflow/models/utils.py
+++ b/cacheflow/models/utils.py
@@ -17,6 +17,7 @@
     'float': torch.float,
     'float16': torch.float16,
     'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
 }
 
 
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 12ee6c54827c..a13b1b9cf290 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -34,7 +34,9 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
diff --git a/csrc/attention/attention_dtypes.cuh b/csrc/attention/attention_dtypes.h
similarity index 59%
rename from csrc/attention/attention_dtypes.cuh
rename to csrc/attention/attention_dtypes.h
index 1d586ddf7522..b04ea9a1145e 100644
--- a/csrc/attention/attention_dtypes.cuh
+++ b/csrc/attention/attention_dtypes.h
@@ -3,3 +3,7 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
+
+#ifdef ENABLE_BF16
+#include "dtype_bfloat16.cuh"
+#endif // ENABLE_BF16
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index a4bd6aeb6867..83a2d42e6d46 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 #include "attention_utils.cuh"
 
 #include <algorithm>
@@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,  // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32 and BF16.
+  // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+#ifdef ENABLE_BF16
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh
index df529095d9c2..a4180b171e1d 100644
--- a/csrc/attention/attention_utils.cuh
+++ b/csrc/attention/attention_utils.cuh
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 
 #include <float.h>
 #include <type_traits>
diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh
new file mode 100644
index 000000000000..1e409296ee58
--- /dev/null
+++ b/csrc/attention/dtype_bfloat16.cuh
@@ -0,0 +1,361 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float32.cuh"
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <stdint.h>
+
+namespace cacheflow {
+
+// Define custom BF16 vector data types.
+struct bf16_4_t {
+  __nv_bfloat162 x;
+  __nv_bfloat162 y;
+};
+
+struct bf16_8_t {
+  __nv_bfloat162 x;
+  __nv_bfloat162 y;
+  __nv_bfloat162 z;
+  __nv_bfloat162 w;
+};
+
+// BF16 vector types for Q, K, V.
+template<>
+struct Vec<__nv_bfloat16, 1> {
+  using Type = __nv_bfloat16;
+};
+template<>
+struct Vec<__nv_bfloat16, 2> {
+  using Type = __nv_bfloat162;
+};
+template<>
+struct Vec<__nv_bfloat16, 4> {
+  using Type = bf16_4_t;
+};
+template<>
+struct Vec<__nv_bfloat16, 8> {
+  using Type = bf16_8_t;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec<__nv_bfloat16> {
+  using Type = float;
+};
+template<>
+struct FloatVec<__nv_bfloat162> {
+  using Type = float2;
+};
+template<>
+struct FloatVec<bf16_4_t> {
+  using Type = Float4_;
+};
+template<>
+struct FloatVec<bf16_8_t> {
+  using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+  return __bfloat1622float2(val);
+}
+
+inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+  return __bfloat162bfloat162(val);
+}
+
+// Vector addition.
+inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+  return a + b;
+}
+
+inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+  return __hadd2(a, b);
+}
+
+inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
+  float2 fa = bf1622float2(a);
+  return add(fa, fb);
+}
+
+inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
+  Float4_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  return fc;
+}
+
+inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
+  Float8_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  fc.z = add(a.z, fb.z);
+  fc.w = add(a.w, fb.w);
+  return fc;
+}
+
+// Vector multiplication.
+template<>
+inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+  return __hmul(a, b);
+}
+
+template<>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+  return __hmul2(a, b);
+}
+
+template<>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
+  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+template<>
+inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  return c;
+}
+
+template<>
+inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_4_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  return c;
+}
+
+template<>
+inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+  return c;
+}
+
+template<>
+inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_8_t c;
+  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+  return c;
+}
+
+template<>
+inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+  float fa = __bfloat162float(a);
+  float fb = __bfloat162float(b);
+  return fa * fb;
+}
+
+template<>
+inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+  float2 fa = bf1622float2(a);
+  float2 fb = bf1622float2(b);
+  return mul<float2, float2, float2>(fa, fb);
+}
+
+template<>
+inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
+  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+template<>
+inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
+  Float4_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  return fc;
+}
+
+template<>
+inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  Float4_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
+  Float8_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+  return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
+  __nv_bfloat162 s = bf162bf162(a);
+  Float8_ fc;
+  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+  return fc;
+}
+
+// Vector fused multiply-add.
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+  return __hfma2(a, b, c);
+}
+
+inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+  return __hfma2(bf162bf162(a), b, c);
+}
+
+inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
+  bf16_4_t d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  return d;
+}
+
+inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_4_t d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  return d;
+}
+
+inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
+  bf16_8_t d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  d.z = fma(a.z, b.z, c.z);
+  d.w = fma(a.w, b.w, c.w);
+  return d;
+}
+
+inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
+  __nv_bfloat162 s = bf162bf162(a);
+  bf16_8_t d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  d.z = fma(s, b.z, c.z);
+  d.w = fma(s, b.w, c.w);
+  return d;
+}
+
+inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
+  return __bfloat162float(a) * __bfloat162float(b) + fc;
+}
+
+inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
+  float2 fa = bf1622float2(a);
+  float2 fb = bf1622float2(b);
+  return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
+  return fma(bf162bf162(a), b, fc);
+}
+
+inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
+  Float4_ fd;
+  fd.x = fma(a.x, b.x, fc.x);
+  fd.y = fma(a.y, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
+  __nv_bfloat162 s = bf162bf162(a);
+  Float4_ fd;
+  fd.x = fma(s, b.x, fc.x);
+  fd.y = fma(s, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
+  Float8_ fd;
+  fd.x = fma(a.x, b.x, fc.x);
+  fd.y = fma(a.y, b.y, fc.y);
+  fd.z = fma(a.z, b.z, fc.z);
+  fd.w = fma(a.w, b.w, fc.w);
+  return fd;
+}
+
+inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
+  __nv_bfloat162 s = bf162bf162(a);
+  Float8_ fd;
+  fd.x = fma(s, b.x, fc.x);
+  fd.y = fma(s, b.y, fc.y);
+  fd.z = fma(s, b.z, fc.z);
+  fd.w = fma(s, b.w, fc.w);
+  return fd;
+}
+
+// Vector sum.
+template<>
+inline __device__ float sum(__nv_bfloat16 v) {
+  return __bfloat162float(v);
+}
+
+template<>
+inline __device__ float sum(__nv_bfloat162 v) {
+  float2 vf = bf1622float2(v);
+  return vf.x + vf.y;
+}
+
+template<>
+inline __device__ float sum(bf16_4_t v) {
+  return sum(v.x) + sum(v.y);
+}
+
+template<>
+inline __device__ float sum(bf16_8_t v) {
+  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
+}
+
+// From float32 to bfloat16.
+inline __device__ void from_float(__nv_bfloat16& dst, float src) {
+  dst = __float2bfloat16(src);
+}
+
+inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+  dst = __float22bfloat162_rn(src);
+}
+
+inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+  dst.x = __float22bfloat162_rn(src.x);
+  dst.y = __float22bfloat162_rn(src.y);
+}
+
+inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+  dst.x = __float22bfloat162_rn(src.x);
+  dst.y = __float22bfloat162_rn(src.y);
+  dst.z = __float22bfloat162_rn(src.z);
+  dst.w = __float22bfloat162_rn(src.w);
+}
+
+} // namespace cacheflow
diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh
index 517da64b3609..fdb35bf4307d 100644
--- a/csrc/attention/dtype_float32.cuh
+++ b/csrc/attention/dtype_float32.cuh
@@ -6,7 +6,7 @@
 
 namespace cacheflow {
 
-// Define FP32 vector data types.
+// Define custom FP32 vector data types.
 struct Float4_ {
   float2 x;
   float2 y;
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 5f97af254142..ddd2d3505780 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -14,14 +14,16 @@ void swap_blocks(
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
   if (src_device.is_cuda() && dst_device.is_cuda()) {
-    assert(src_device.index() == dst_device.index());
+    TORCH_CHECK(
+      src_device.index() == dst_device.index(),
+      "src and dst must be on the same GPU");
     memcpy_type = cudaMemcpyDeviceToDevice;
   } else if (src_device.is_cuda() && dst_device.is_cpu()) {
     memcpy_type = cudaMemcpyDeviceToHost;
   } else if (src_device.is_cpu() && dst_device.is_cuda()) {
     memcpy_type = cudaMemcpyHostToDevice;
   } else {
-    assert(false);
+    TORCH_CHECK(false, "Invalid device combination");
   }
 
   void *src_ptr = src.data_ptr();
@@ -29,6 +31,7 @@
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
     int64_t src_block_number = pair.first;
     int64_t dst_block_number = pair.second;
@@ -122,7 +125,9 @@ void copy_blocks(
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+} // namespace cacheflow
+
+void reshape_and_cache(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [num_tokens]
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    key.scalar_type(),
+    "reshape_and_cache_kernel",
+    [&] {
+      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int>(),
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
+}
+
+namespace cacheflow {
+
 // Grid: (num_blocks, block_size).
 template<typename scalar_t>
 __global__ void gather_cached_kv_kernel(
@@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized(
 }
 
 } // namespace cacheflow
-void reshape_and_cache(
-  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
-  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [num_tokens]
-{
-  int num_tokens = key.size(0);
-  int num_heads = key.size(1);
-  int head_size = key.size(2);
-  int block_size = key_cache.size(3);
-  int x = key_cache.size(4);
-
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    key.scalar_type(),
-    "reshape_and_cache_kernel",
-    [&] {
-      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
-}
-
-
 void gather_cached_kv(
   torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
   torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
@@ -354,7 +364,9 @@
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index a9606b106721..ffdcb4176f58 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -46,7 +46,9 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "rms_norm_kernel",
     [&] {
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index 527fe2cd97c8..637e233c9a9a 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -64,7 +64,9 @@ void rotary_embedding_neox(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
diff --git a/setup.py b/setup.py
index bac0b0f18c74..52ff89f63b0e 100644
--- a/setup.py
+++ b/setup.py
@@ -1,9 +1,22 @@
 import setuptools
+import torch
 from torch.utils import cpp_extension
 
 CXX_FLAGS = ['-g']
 NVCC_FLAGS = ['-O2']
 
+if not torch.cuda.is_available():
+    raise RuntimeError(
+        f'Cannot find CUDA at CUDA_HOME: {cpp_extension.CUDA_HOME}. '
+        'CUDA must be available in order to build the package.')
+
+# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
+# different compute capabilities.
+compute_capability = torch.cuda.get_device_capability()
+major, minor = compute_capability
+# Enable bfloat16 support if the compute capability is >= 8.0.
+if major >= 8:
+    NVCC_FLAGS.append('-DENABLE_BF16')
 
 ext_modules = []
 
@@ -23,7 +36,7 @@
 )
 ext_modules.append(attention_extension)
 
-# Positional encodings.
+# Positional encoding kernels.
 positional_encoding_extension = cpp_extension.CUDAExtension(
     name='cacheflow.pos_encoding_ops',
     sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],
@@ -39,6 +52,7 @@
 )
 ext_modules.append(layernorm_extension)
 
+# Activation kernels.
 activation_extension = cpp_extension.CUDAExtension(
     name='cacheflow.activation_ops',
     sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
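
Usage sketch: the new 'bfloat16' choice is only useful on GPUs where the BF16 kernels are compiled in, i.e. compute capability 8.0 (Ampere) or newer, mirroring the ENABLE_BF16 check in setup.py above. The helper below is hypothetical and illustrative only; the '--dtype' choices ('half', 'bfloat16') and the capability threshold are the only parts taken from this diff.

    # pick_dtype.py -- assumes a CUDA-enabled PyTorch install.
    import torch

    def pick_dtype() -> str:
        # BF16 kernels are built only when ENABLE_BF16 is defined
        # (compute capability >= 8.0), so fall back to 'half' otherwise.
        major, _minor = torch.cuda.get_device_capability()
        return 'bfloat16' if major >= 8 else 'half'

    # The returned string matches the server's --dtype choices, e.g. --dtype bfloat16.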