From 3cefe5ec0185101eb58baa0b80cbee1e08e615e3 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 9 May 2024 15:29:43 -0700
Subject: [PATCH] refactor

---
 csrc/attention/attention_kernels.cu          |  1 -
 csrc/attention/dtype_fp8.cuh                 |  6 ++++++
 csrc/cache_kernels.cu                        |  2 +-
 csrc/quantization/fp8/amd/quant_utils.cuh    |  8 +++-----
 csrc/quantization/fp8/dtype_kv_cache.cuh     | 10 ----------
 csrc/quantization/fp8/nvidia/quant_utils.cuh |  9 +++------
 6 files changed, 13 insertions(+), 23 deletions(-)
 delete mode 100644 csrc/quantization/fp8/dtype_kv_cache.cuh

diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 70eea52d3458c..41b337dd91d36 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -24,7 +24,6 @@
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
-#include "../quantization/fp8/dtype_kv_cache.cuh"
 
 #ifdef USE_ROCM
 #include
 #include "../quantization/fp8/amd/quant_utils.cuh"
diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh
index a1591a56ab4ca..2b32ce372a64f 100644
--- a/csrc/attention/dtype_fp8.cuh
+++ b/csrc/attention/dtype_fp8.cuh
@@ -11,6 +11,12 @@
 
 namespace vllm {
 
+enum class Fp8KVCacheDataType {
+  kAuto = 0,
+  kFp8E4M3 = 1,
+  kFp8E5M2 = 2,
+};
+
 // fp8 vector types for quantization of kv cache
 template<>
 struct Vec {
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 55586b9a20d94..e5b74da6ad068 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -4,7 +4,7 @@
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
-#include "quantization/fp8/dtype_kv_cache.cuh"
+
 #ifdef USE_ROCM
 #include "quantization/fp8/amd/quant_utils.cuh"
 #else
diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh
index 7722d6667ef9b..df0329f79d361 100644
--- a/csrc/quantization/fp8/amd/quant_utils.cuh
+++ b/csrc/quantization/fp8/amd/quant_utils.cuh
@@ -5,7 +5,7 @@
 #include
 #include
-#include "../dtype_kv_cache.cuh"
+#include "../../../attention/dtype_fp8.cuh"
 #include "../../../attention/dtype_float32.cuh"
 #include "../../../attention/dtype_bfloat16.cuh"
 
@@ -522,8 +522,7 @@ __inline__ __device__ float4 scaled_vec_conversion(const uint3
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
 __inline__ __device__ Tout convert(const Tin &x) {
 #ifdef ENABLE_FP8
-  if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
-                kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
+  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
     return vec_conversion<Tout, Tin>(x);
   }
 #endif
@@ -533,8 +532,7 @@ __inline__ __device__ Tout convert(const Tin &x) {
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
 __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
 #ifdef ENABLE_FP8
-  if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
-                kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
+  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
     return scaled_vec_conversion<Tout, Tin>(x, scale);
   }
 #endif
diff --git a/csrc/quantization/fp8/dtype_kv_cache.cuh b/csrc/quantization/fp8/dtype_kv_cache.cuh
deleted file mode 100644
index 3b7b400e9ed96..0000000000000
--- a/csrc/quantization/fp8/dtype_kv_cache.cuh
+++ /dev/null
@@ -1,10 +0,0 @@
-#pragma once
-
-namespace vllm {
-
-enum class Fp8KVCacheDataType {
-  kAuto = 0,
-  kFp8E4M3 = 1,
-  kFp8E5M2 = 2,
-};
-}
diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh
index 04356c91cdebc..4eeacf7a6f9d9 100644
--- a/csrc/quantization/fp8/nvidia/quant_utils.cuh
+++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh
@@ -1,6 +1,5 @@
 #pragma once
 
-#include "../dtype_kv_cache.cuh"
 #include "../../../attention/attention_dtypes.h"
 #include
 #include
@@ -503,8 +502,7 @@ __inline__ __device__ float4 scaled_vec_conversion(
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
 __inline__ __device__ Tout convert(const Tin &x) {
 #if 0 // Disable the following code to reduce the binary size.
-  if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
-                kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
+  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
     return vec_conversion<Tout, Tin>(x, __NV_E4M3);
   } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
     return vec_conversion<Tout, Tin>(x, __NV_E5M2);
@@ -516,8 +514,7 @@ __inline__ __device__ Tout convert(const Tin &x) {
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
 __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
 #ifdef ENABLE_FP8
-  if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
-                kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
+  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
     return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
   } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
     return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
@@ -567,5 +564,5 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
 }
 
 } // namespace fp8
-#endif // USE_ROCM
+#endif // not USE_ROCM
 } // namespace vllm