Commit

refactor
comaniac committed May 9, 2024
1 parent 20cd490 · commit 3cefe5e
Showing 6 changed files with 13 additions and 23 deletions.
1 change: 0 additions & 1 deletion csrc/attention/attention_kernels.cu
@@ -24,7 +24,6 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"

#include "../quantization/fp8/dtype_kv_cache.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
6 changes: 6 additions & 0 deletions csrc/attention/dtype_fp8.cuh
@@ -11,6 +11,12 @@

namespace vllm {

enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
template<>
struct Vec<uint8_t, 1> {
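For reference, a minimal, self-contained C++17 sketch of the compile-time dispatch pattern that the Fp8KVCacheDataType enum above enables in the convert/scaled_convert hunks further down. This is not the vLLM code: `dequantize` and its decode arithmetic are hypothetical stand-ins for illustration only.

// Sketch only, not the vLLM implementation. The enum is used as a
// non-type template parameter so `if constexpr` picks the FP8 decode
// path at compile time; unused branches are never instantiated.
#include <cstdint>
#include <cstdio>

enum class Fp8KVCacheDataType {
  kAuto = 0,
  kFp8E4M3 = 1,
  kFp8E5M2 = 2,
};

template <Fp8KVCacheDataType kv_dt>
float dequantize(std::uint8_t x, float scale) {
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scale * static_cast<float>(x);         // stand-in for an E4M3 decode
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scale * 2.0f * static_cast<float>(x);  // stand-in for an E5M2 decode
  } else {
    return static_cast<float>(x);                 // kAuto: cache is not FP8
  }
}

int main() {
  // Only the E4M3 branch is instantiated for this call.
  std::printf("%f\n", dequantize<Fp8KVCacheDataType::kFp8E4M3>(16, 0.5f));
  return 0;
}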
2 changes: 1 addition & 1 deletion csrc/cache_kernels.cu
@@ -4,7 +4,7 @@

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/dtype_kv_cache.cuh"

#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
8 changes: 3 additions & 5 deletions csrc/quantization/fp8/amd/quant_utils.cuh
@@ -5,7 +5,7 @@
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

#include "../dtype_kv_cache.cuh"
#include "../../../attention/dtype_fp8.cuh"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"

@@ -522,8 +522,7 @@ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint3
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x);
}
#endif
@@ -533,8 +532,7 @@ __inline__ __device__ Tout convert(const Tin &x) {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
}
#endif
10 changes: 0 additions & 10 deletions csrc/quantization/fp8/dtype_kv_cache.cuh

This file was deleted.

9 changes: 3 additions & 6 deletions csrc/quantization/fp8/nvidia/quant_utils.cuh
@@ -1,6 +1,5 @@
#pragma once

#include "../dtype_kv_cache.cuh"
#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
@@ -503,8 +502,7 @@ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
#if 0 // Disable the following code to reduce the binary size.
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
@@ -516,8 +514,7 @@ __inline__ __device__ Tout convert(const Tin &x) {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ||
kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
@@ -567,5 +564,5 @@ __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
}

} // namespace fp8
#endif // USE_ROCM
#endif // not USE_ROCM
} // namespace vllm
