[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support #4535

Merged: 6 commits merged into vllm-project:main from e4m3-kv-cache on May 10, 2024

Conversation

comaniac
Collaborator

@comaniac comaniac commented May 1, 2024

The first PR for #4532.

Task list:

  • Add NVIDIA e4m3.
  • Refactor cache_kernel.cu.
  • Unit tests for reshape and cache.
  • Refactor attention_kernel.cu.
  • Unit tests for attention.
  • Refactor AMD.
  • Compatibility with FP8 disabled.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain code quality and improves the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes to the vLLM frontend (e.g., OpenAI API server, LLM class, etc.).
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to the Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well documented so that future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies user-facing behavior of vLLM. This helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and may not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient, and to make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, it will be assigned to a reviewer. Every reviewer will pick up PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if changes are required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@comaniac comaniac force-pushed the e4m3-kv-cache branch 5 times, most recently from dcf4178 to 79a4c19 on May 3, 2024 20:30
@comaniac comaniac marked this pull request as ready for review May 3, 2024 20:30
@comaniac
Collaborator Author

comaniac commented May 3, 2024

Per offline discussion, this PR only includes backend refactoring for FP8 kv-cache related kernels and utilities. A follow-up PR will then cover the scaling factor loading. Thus, this PR is ready for review.

cc @pcmoritz @robertgshaw2-neuralmagic @HaiShaw @WoosukKwon

@comaniac comaniac changed the title from "[WIP][Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support" to "[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support" on May 3, 2024
vllm/_custom_ops.py: review thread (outdated, resolved)
@pcmoritz
Collaborator

pcmoritz commented May 3, 2024

It is a bummer that GitHub doesn't render the diff between the old and new nvidia quant_utils.cuh -- for ease of review, here is the diff:

(base) pcmoritz@pcmoritz-DQ44HV60WX /tmp % diff quant_utils_old.cuh quant_utils_new.cuh 
2a3,6
> #include "../../../attention/attention_dtypes.h"
> #include "../../../attention/dtype_bfloat16.cuh"
> #include "../../../attention/dtype_float16.cuh"
> #include "../../../attention/dtype_float32.cuh"
4d7
< #include <stdint.h>
5a9
> #include <stdint.h>
7,10d10
< #include "../../attention/attention_dtypes.h"
< #include "../../attention/dtype_float32.cuh"
< #include "../../attention/dtype_float16.cuh"
< #include "../../attention/dtype_bfloat16.cuh"
12d11
< 
14,15d12
< #ifdef ENABLE_FP8_E5M2
< namespace fp8_e5m2_unscaled {
17,20c14,20
< template<typename Tout, typename Tin>
< __inline__ __device__ Tout vec_conversion(const Tin& x)
< {
<     return x;
---
> namespace fp8 {
> #ifdef ENABLE_FP8
> 
> template <typename Tout, typename Tin>
> __inline__ __device__ Tout
> vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
>   return x;
24,28c24,28
< template<>
< __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
< {
<     __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
<     return res.x;
---
> template <>
> __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
>     const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
>   return res.x;
32,42c32,42
< template<>
< __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
< {
<     union {
<         uint16_t u16[2];
<         uint32_t u32;
<     } tmp;
<     __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
<     tmp.u16[0] = res.x;
<     tmp.u16[1] = res.y;
<     return tmp.u32;
---
> template <>
> __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
>     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   union {
>     uint16_t u16[2];
>     uint32_t u32;
>   } tmp;
>   __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
>   tmp.u16[0] = res.x;
>   tmp.u16[1] = res.y;
>   return tmp.u32;
46,55c46,56
< template<>
< __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
< {
<     union {
<         uint2    u32x2;
<         uint32_t u32[2];
<     } tmp;
<     tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
<     tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
<     return tmp.u32x2;
---
> template <>
> __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
>     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   union {
>     uint2 u32x2;
>     uint32_t u32[2];
>   } tmp;
>   tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
>   tmp.u32[1] =
>       vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
>   return tmp.u32x2;
59,68c60,69
< template<>
< __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
< {
<     union {
<         uint4 u64x2;
<         uint2 u64[2];
<     } tmp;
<     tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
<     tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
<     return tmp.u64x2;
---
> template <>
> __inline__ __device__ uint4 vec_conversion<uint4, uint2>(
>     const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
>   union {
>     uint4 u64x2;
>     uint2 u64[2];
>   } tmp;
>   tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
>   tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
>   return tmp.u64x2;
72,80c73,81
< template<>
< __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
< {
<     // Note there is no direct convert function from fp8 to bf16.
<     // fp8 -> half
<     __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
<     // half -> float -> bf16
<     float tmp = half_to_float(res.x);
<     return __float2bfloat16(tmp);
---
> template <>
> __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
>     const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   // Note there is no direct convert function from fp8 to bf16.
>   // fp8 -> half
>   __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
>   // half -> float -> bf16
>   float tmp = half_to_float(res.x);
>   return __float2bfloat16(tmp);
84,90c85,91
< template<>
< __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
< {
<     __nv_bfloat162 res;
<     res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
<     res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
<     return res;
---
> template <>
> __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
>     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   __nv_bfloat162 res;
>   res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
>   res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
>   return res;
94,100c95,102
< template<>
< __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
< {
<     bf16_4_t res;
<     res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
<     res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
<     return res;
---
> template <>
> __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
>     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   bf16_4_t res;
>   res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
>   res.y =
>       vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
>   return res;
104,115c106,117
< template<>
< __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
< {
<     bf16_4_t tmp1, tmp2;
<     tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
<     tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
<     bf16_8_t res;
<     res.x = tmp1.x;
<     res.y = tmp1.y;
<     res.z = tmp2.x;
<     res.w = tmp2.y;
<     return res;
---
> template <>
> __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
>     const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
>   bf16_4_t tmp1, tmp2;
>   tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
>   tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
>   bf16_8_t res;
>   res.x = tmp1.x;
>   res.y = tmp1.y;
>   res.z = tmp2.x;
>   res.w = tmp2.y;
>   return res;
119,125c121,128
< template<>
< __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
< {
<     // fp8 -> half
<     uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
<     // half -> float
<     return half_to_float(tmp);
---
> template <>
> __inline__ __device__ float
> vec_conversion<float, uint8_t>(const uint8_t &a,
>                                const __nv_fp8_interpretation_t fp8_type) {
>   // fp8 -> half
>   uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
>   // half -> float
>   return half_to_float(tmp);
129,135c132,138
< template<>
< __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
< {
<     // fp8x2 -> half2
<     uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
<     // half2 -> float2
<     return half2_to_float2(tmp);
---
> template <>
> __inline__ __device__ float2 vec_conversion<float2, uint16_t>(
>     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   // fp8x2 -> half2
>   uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
>   // half2 -> float2
>   return half2_to_float2(tmp);
139,145c142,148
< template<>
< __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
< {
<     Float4_ res;
<     res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
<     res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
<     return res;
---
> template <>
> __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
>     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   Float4_ res;
>   res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
>   res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
>   return res;
149,160c152,163
< template<>
< __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
< {
<     Float4_ tmp1, tmp2;
<     tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
<     tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
<     Float8_ res;
<     res.x = tmp1.x;
<     res.y = tmp1.y;
<     res.z = tmp2.x;
<     res.w = tmp2.y;
<     return res;
---
> template <>
> __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
>     const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
>   Float4_ tmp1, tmp2;
>   tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
>   tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
>   Float8_ res;
>   res.x = tmp1.x;
>   res.y = tmp1.y;
>   res.z = tmp2.x;
>   res.w = tmp2.y;
>   return res;
163d165
< 
165,171c167,174
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
< {
<     __half_raw tmp;
<     tmp.x = a;
<     __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
<     return (uint8_t)res;
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
>     const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   __half_raw tmp;
>   tmp.x = a;
>   __nv_fp8_storage_t res =
>       __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
>   return (uint8_t)res;
175,177c178,180
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
< {
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
>     const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
179c182
<     assert(false);
---
>   assert(false);
181,182c184,186
<     __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
<     return (uint8_t)res;
---
>   __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
>       __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
>   return (uint8_t)res;
187,191c191,195
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
< {
<     __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
<     return (uint8_t)res;
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
>     const float &a, const __nv_fp8_interpretation_t fp8_type) {
>   __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
>   return (uint8_t)res;
195,200c199,204
< template<>
< __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
< {
<     Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
<     float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
<     return res;
---
> template <>
> __inline__ __device__ float4 vec_conversion<float4, uint32_t>(
>     const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
>   Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
>   float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
>   return res;
202a207,213
> template <>
> __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
>     const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
>   union {
>     half2 float16;
>     uint32_t uint32;
>   };
204,210c215,217
< template<>
< __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
< {
<     union {
<         half2    float16;
<         uint32_t uint32;
<     };
---
>   float16 = __float22half2_rn(a);
>   return uint32;
> }
212,213c219,232
<     float16 = __float22half2_rn(a);
<     return uint32;
---
> template <>
> __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
>     const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
>   uint2 b;
>   float2 val;
>   val.x = a.x.x;
>   val.y = a.x.y;
>   b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
> 
>   val.x = a.y.x;
>   val.y = a.y.y;
>   b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
> 
>   return b;
216,223c235,244
< template<>
< __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
< {
<     uint2  b;
<     float2 val;
<     val.x = a.x.x;
<     val.y = a.x.y;
<     b.x   = vec_conversion<uint32_t, float2>(val);
---
> template <>
> __inline__ __device__ float4 vec_conversion<float4, Float4_>(
>     const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
>   float4 b;
>   b.x = a.x.x;
>   b.y = a.x.y;
>   b.z = a.y.x;
>   b.w = a.y.y;
>   return b;
> }
225,227c246,255
<     val.x = a.y.x;
<     val.y = a.y.y;
<     b.y   = vec_conversion<uint32_t, float2>(val);
---
> template <>
> __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
>     const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
>   uint4 b;
>   b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
>   b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
>   b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
>   b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
>   return b;
> }
229c257,262
<     return b;
---
> template <>
> __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
>     const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
>   __nv_bfloat162 b;
>   from_float(b, a);
>   return b;
232,240c265,270
< template<>
< __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
< {
<     float4 b;
<     b.x = a.x.x;
<     b.y = a.x.y;
<     b.z = a.y.x;
<     b.w = a.y.y;
<     return b;
---
> template <>
> __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
>     const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
>   bf16_4_t b;
>   from_float(b, a);
>   return b;
243,251c273,278
< template<>
< __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
< {
<     uint4 b;
<     b.x = vec_conversion<uint32_t, float2>(a.x);
<     b.y = vec_conversion<uint32_t, float2>(a.y);
<     b.z = vec_conversion<uint32_t, float2>(a.z);
<     b.w = vec_conversion<uint32_t, float2>(a.w);
<     return b;
---
> template <>
> __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
>     const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
>   bf16_8_t b;
>   from_float(b, a);
>   return b;
254,258c281,290
< template<>
< __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
<     __nv_bfloat162 b;
<     from_float(b, a);
<     return b;
---
> /* Scaled and vectorized conversions, for data exchange between high and low
>    precision domains Convention of the scale in API, e.g: FP8_data =
>    Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
>      Dequant(FP8) * scale =>  HP
>  */
> 
> template <typename Tout, typename Tin>
> __inline__ __device__ Tout scaled_vec_conversion(
>     const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
>   return x;
261,265c293,299
< template<>
< __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
<     bf16_4_t b;
<     from_float(b, a);
<     return b;
---
> // fp8 -> half
> template <>
> __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
>     const uint8_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
>   return float_to_half(half_to_float(tmp.x) * scale);
268,272c302,314
< template<>
< __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
<     bf16_8_t b;
<     from_float(b, a);
<     return b;
---
> // fp8x2 -> half2
> template <>
> __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
>     const uint16_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   union {
>     uint16_t u16[2];
>     uint32_t u32;
>   } tmp;
>   __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
>   tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
>   tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
>   return tmp.u32;
275,276c317,576
< } // namespace fp8_e5m2_unscaled
< #endif // ENABLE_FP8_E5M2
---
> // fp8x4 -> half2x2
> template <>
> __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
>     const uint32_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   union {
>     uint2 u32x2;
>     uint32_t u32[2];
>   } tmp;
>   tmp.u32[0] =
>       scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
>   tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
>                                                          scale, fp8_type);
>   return tmp.u32x2;
> }
> 
> // fp8x8 -> half2x4
> template <>
> __inline__ __device__ uint4
> scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
>                                     const __nv_fp8_interpretation_t fp8_type) {
>   union {
>     uint4 u64x2;
>     uint2 u64[2];
>   } tmp;
>   tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
>   tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
>   return tmp.u64x2;
> }
> 
> // fp8 -> __nv_bfloat16
> template <>
> __inline__ __device__ __nv_bfloat16
> scaled_vec_conversion<__nv_bfloat16, uint8_t>(
>     const uint8_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   // Note there is no direct convert function from fp8 to bf16.
>   // fp8 -> half
>   __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
>   // half -> float -> bf16
>   float tmp = half_to_float(res.x);
>   return __float2bfloat16(tmp * scale);
> }
> 
> // fp8x2 -> __nv_bfloat162
> template <>
> __inline__ __device__ __nv_bfloat162
> scaled_vec_conversion<__nv_bfloat162, uint16_t>(
>     const uint16_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   __nv_bfloat162 res;
>   res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
>                                                         fp8_type);
>   res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
>                                                         scale, fp8_type);
>   return res;
> }
> 
> // fp8x4 -> bf16_4_t
> template <>
> __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
>     const uint32_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   bf16_4_t res;
>   res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
>                                                           fp8_type);
>   res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
>                                                           scale, fp8_type);
>   return res;
> }
> 
> // fp8x8 -> bf16_8_t
> template <>
> __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
>     const uint2 &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   bf16_4_t tmp1, tmp2;
>   tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
>   tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
>   bf16_8_t res;
>   res.x = tmp1.x;
>   res.y = tmp1.y;
>   res.z = tmp2.x;
>   res.w = tmp2.y;
>   return res;
> }
> 
> // fp8 -> float
> template <>
> __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
>     const uint8_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
> 
>   // fp8 -> half
>   uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
>   // half -> float
>   return half_to_float(tmp) * scale;
> }
> 
> // fp8x2 -> float2
> template <>
> __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
>     const uint16_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   // fp8x2 -> half2
>   uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
>   // half2 -> float2
>   return half2_to_float2(tmp);
> }
> 
> // fp8x4 -> float4
> template <>
> __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
>     const uint32_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   Float4_ res;
>   res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
>   res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
>                                                   fp8_type);
>   return res;
> }
> 
> // fp8x8 -> float8
> template <>
> __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
>     const uint2 &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   Float4_ tmp1, tmp2;
>   tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
>   tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
>   Float8_ res;
>   res.x = tmp1.x;
>   res.y = tmp1.y;
>   res.z = tmp2.x;
>   res.w = tmp2.y;
>   return res;
> }
> 
> // half -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
>     const uint16_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   __nv_fp8_storage_t res =
>       __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
>   return (uint8_t)res;
> }
> 
> // bf16 -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
>     const __nv_bfloat16 &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
> #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
>   assert(false);
> #else
>   __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
>                                                  __NV_SATFINITE, fp8_type);
>   return (uint8_t)res;
> #endif
> }
> 
> // float -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
>     const float &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   __nv_fp8_storage_t res =
>       __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
>   return (uint8_t)res;
> }
> 
> // fp8x4 -> float4
> template <>
> __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
>     const uint32_t &a, const float scale,
>     const __nv_fp8_interpretation_t fp8_type) {
>   Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
>   float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
>   return res;
> }
> #endif // ENABLE_FP8
> 
> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
> __inline__ __device__ Tout convert(const Tin &x) {
>   switch (kv_dt) {
> #ifdef ENABLE_FP8
>   case Fp8KVCacheDataType::kAuto:
>     // When the type is auto, Tin should be able to be converted to
>     // Tout directly. Thus, the corresponding vec_conversion function
>     // should ignore the last argument (e.g. __NV_E4M3).
>   case Fp8KVCacheDataType::kFp8E4m3:
>     return vec_conversion<Tout, Tin>(x, __NV_E4M3);
>   case Fp8KVCacheDataType::kFp8E5m2:
>     return vec_conversion<Tout, Tin>(x, __NV_E5M2);
> #endif
>   default:
>     assert(false);
>   }
> }
> 
> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
> __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
>   switch (kv_dt) {
> #ifdef ENABLE_FP8
>   case Fp8KVCacheDataType::kAuto:
>     // When the type is auto, Tin should be able to be converted to
>     // Tout directly. Thus, the corresponding vec_conversion function
>     // should ignore the last argument (e.g. __NV_E4M3).
>   case Fp8KVCacheDataType::kFp8E4m3:
>     return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
>   case Fp8KVCacheDataType::kFp8E5m2:
>     return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
> #endif
>   default:
>     assert(false);
>   }
> }
> 
> // The following macro is used to dispatch the conversion function based on the
> // data type of the key and value cache. The FN is a macro that calls a function
> // with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
> #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
>   if (KV_DTYPE == "auto") {                                                    \
>     if (SRC_DTYPE == at::ScalarType::Float) {                                  \
>       FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                       \
>     } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
>       FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                 \
>     } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
>       FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);       \
>     } else {                                                                   \
>       TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
>     }                                                                          \
>   } else {                                                                     \
>     if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                         \
>       if (SRC_DTYPE == at::ScalarType::Float) {                                \
>         FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3);                \
>       } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
>         FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3);             \
>       } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
>         FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3);        \
>       } else {                                                                 \
>         TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
>       }                                                                        \
>     } else if (KV_DTYPE == "fp8_e5m2") {                                       \
>       if (SRC_DTYPE == at::ScalarType::Float) {                                \
>         FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2);                \
>       } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
>         FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2);             \
>       } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
>         FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2);        \
>       } else {                                                                 \
>         TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
>       }                                                                        \
>     } else {                                                                   \
>       TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);     \
>     }                                                                          \
>   }
> 
> } // namespace fp8
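To make the calling convention of the new DISPATCH_BY_KV_CACHE_DTYPE macro concrete, here is a hypothetical call site. example_cache_kernel, CALL_EXAMPLE_KERNEL, and launch_example are placeholder names for illustration only; the real call sites are the cache and attention kernel launchers touched by this PR.

// Assumes the fp8 quant_utils header above (which defines
// vllm::Fp8KVCacheDataType and DISPATCH_BY_KV_CACHE_DTYPE) is included.
#include <cstdint>
#include <string>
#include <torch/all.h>

// Placeholder kernel template (declaration only).
template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt>
__global__ void example_cache_kernel(/* tensors, strides, scale, ... */);

// FN passed to the dispatch macro must itself be a macro taking
// (scalar_t, cache_t, Fp8KVCacheDataType).
#define CALL_EXAMPLE_KERNEL(T, CACHE_T, KV_DT)                                 \
  example_cache_kernel<T, CACHE_T, KV_DT>                                      \
      <<<grid, block, 0, stream>>>(/* kernel arguments */);

void launch_example(at::ScalarType src_dtype,
                    const std::string& kv_cache_dtype, dim3 grid, dim3 block,
                    cudaStream_t stream) {
  // Selects (scalar_t, cache_t, kv_dt) from the source dtype and the kv-cache
  // dtype string, e.g. ("fp8", Half) -> (uint16_t, uint8_t, kFp8E4m3), and
  // raises TORCH_CHECK for unsupported combinations.
  DISPATCH_BY_KV_CACHE_DTYPE(src_dtype, kv_cache_dtype, CALL_EXAMPLE_KERNEL);
}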

#ifdef ENABLE_FP8
case Fp8KVCacheDataType::kAuto:
// When the type is auto, Tin should be able to be converted to
// Tout directly.
Collaborator

Let's add a comment that we are falling through to the next statement here (same below).
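A minimal sketch of how the requested annotation could read in the new convert helper (the body mirrors the function in the diff above; only the fall-through comment is added, and since the build already uses C++17, a [[fallthrough]]; statement would be an equivalent alternative):

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
  switch (kv_dt) {
#ifdef ENABLE_FP8
  case Fp8KVCacheDataType::kAuto:
    // When the type is auto, Tin should be directly convertible to Tout, so
    // the matching vec_conversion overload ignores the fp8_type argument.
    // Intentional fall-through to the kFp8E4m3 case below.
  case Fp8KVCacheDataType::kFp8E4m3:
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  case Fp8KVCacheDataType::kFp8E5m2:
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
#endif
  default:
    assert(false);
  }
}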

@pcmoritz
Collaborator

pcmoritz commented May 3, 2024

Did you investigate the performance impact of passing __nv_fp8_interpretation_t around at runtime? Have you considered making the format a template parameter of the vec_conversion and related functions (e.g. by reusing Fp8KVCacheDataType)?

@comaniac
Collaborator Author

comaniac commented May 3, 2024

Did you investigate the performance impact of passing __nv_fp8_interpretation_t around at runtime? Have you considered making the format a template parameter of the vec_conversion and related functions (e.g. by reusing Fp8KVCacheDataType)?

Good question. It would be tedious to make this type a template parameter, because we have roughly 30 overloaded functions. Since C++ doesn't allow partial specialization of function templates, we would have to manually duplicate them into 60 functions to cover both formats...
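For illustration of that point only (this is not what the PR does): if the interpretation were a non-type template parameter, a specialization on (Tout, Tin) that leaves the interpretation generic would be a partial specialization of a function template, which C++ rejects, so every overload would need two full specializations, roughly as in this sketch (vec_conversion_t is a hypothetical name to avoid clashing with the real vec_conversion):

#include <cstdint>
#include <cuda_fp8.h>  // __nv_fp8_interpretation_t, __nv_cvt_fp8_to_halfraw

// Hypothetical alternative: FP8 interpretation as a non-type template parameter.
template <typename Tout, typename Tin, __nv_fp8_interpretation_t kFp8Type>
__inline__ __device__ Tout vec_conversion_t(const Tin &x) {
  return x;
}

// Ill-formed (partial specialization of a function template):
//   template <__nv_fp8_interpretation_t kFp8Type>
//   __inline__ __device__ uint16_t
//   vec_conversion_t<uint16_t, uint8_t, kFp8Type>(const uint8_t &a) { ... }

// So both formats have to be spelled out as full specializations per overload:
template <>
__inline__ __device__ uint16_t
vec_conversion_t<uint16_t, uint8_t, __NV_E4M3>(const uint8_t &a) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E4M3);
  return res.x;
}

template <>
__inline__ __device__ uint16_t
vec_conversion_t<uint16_t, uint8_t, __NV_E5M2>(const uint8_t &a) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
  return res.x;
}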

@pcmoritz
Collaborator

pcmoritz commented May 3, 2024

Why don't we test whether there is a performance overhead? The compiler is probably already smart enough to optimize that away -- it should be, since the argument is constant in https://github.com/vllm-project/vllm/pull/4535/files#diff-97c4751eafe4ec7333fe2f140e29c84ea054f43d17d4286cc8c4e69a095d09aaR502 and similarly for scaled_convert.

@pcmoritz pcmoritz left a comment (Collaborator)

If the performance is ok, we can go forward with this.

Thanks a lot for cleaning this up @comaniac ❤️
This code was not pretty and now it is much nicer!

The only thing I'm not a fan of is {nvidia, amd}/quant_utils.cuh. If anybody has ideas how to do that better, that would be very much appreciated!
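(One pattern sometimes used to hide a vendor split like this behind a single include is a thin dispatch header; the file name below is hypothetical and this is only an illustration, not a proposal adopted in this PR.)

// Hypothetical csrc/quantization/fp8/quant_utils_dispatch.cuh: kernel code
// includes this one header and the build picks the vendor-specific
// implementation via the USE_ROCM define already set by the ROCm build.
#pragma once
#ifdef USE_ROCM
  #include "amd/quant_utils.cuh"
#else
  #include "nvidia/quant_utils.cuh"
#endif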

@comaniac
Collaborator Author

comaniac commented May 3, 2024

If the performance is ok, we can go forward with this.

Thanks a lot for cleaning this up @comaniac ❤️ This code was not pretty and now it is much nicer!

The only thing I'm not a fan of is {nvidia, amd}/quant_utils.cuh. If anybody has ideas how to do that better, that would be very much appreciated!

I'll verify the performance. For naming, another option I can think of is {cuda,rocm/hip}/quant_utils.cuh, but I'm open to any proposal.

@pcmoritz
Collaborator

pcmoritz commented May 3, 2024

It is not about the naming; it's more about having all these special cases and little conversion utilities :)

@comaniac
Collaborator Author

comaniac commented May 3, 2024

Why don't we test whether there is a performance overhead? The compiler is probably already smart enough to optimize that away -- it should be, since the argument is constant in https://github.com/vllm-project/vllm/pull/4535/files#diff-97c4751eafe4ec7333fe2f140e29c84ea054f43d17d4286cc8c4e69a095d09aaR502 and similarly for scaled_convert.

I benchmarked on an L4 GPU and the latency difference is within 1-2%, which should be acceptable.

@comaniac
Collaborator Author

comaniac commented May 6, 2024

@HaiShaw @AdrianAbeyta I keep seeing the following error when building this PR with ROCm. It seems like the same attention_generic.cuh header is included twice from different locations. Since I don't get this problem with nvcc, do you have any clue how to resolve this for ROCm? Thanks.

#20 37.37 [3/13] Building HIP object CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o
#20 37.37 FAILED: CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o
#20 37.37 /opt/rocm/llvm/bin/clang++  -DTORCH_EXTENSION_NAME=_C -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_PROF_API=1 -DUSE_RPC -DUSE_TENSORPIPE -D_C_EXPORTS -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1 -I/vllm-workspace/csrc -isystem /opt/conda/envs/py_3.9/include/python3.9 -isystem /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include -isystem /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /opt/rocm-6.0.0/include/hiprand -O2 -g -DNDEBUG -std=gnu++17 --offload-arch=gfx906 --offload-arch=gfx908 --offload-arch=gfx90a --offload-arch=gfx1030 --offload-arch=gfx1100 --offload-arch=gfx940 --offload-arch=gfx941 --offload-arch=gfx942 -fPIC -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DUSE_ROCM -DENABLE_FP8 -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -fno-gpu-rdc -mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -D_GLIBCXX_USE_CXX11_ABI=1 -DTORCH_HIP_VERSION=600 -Wno-shift-count-negative -Wno-shift-count-overflow -Wno-duplicate-decl-specifier -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -MD -MT CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o -MF CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o.d -o CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o -x hip -c /vllm-workspace/build/temp.linux-x86_64-cpython-39/csrc/attention/attention_kernels.hip
#20 37.37 In file included from /vllm-workspace/build/temp.linux-x86_64-cpython-39/csrc/attention/attention_kernels.hip:31:
#20 37.37 In file included from /vllm-workspace/csrc/quantization/fp8/amd/quant_utils.cuh:8:
#20 37.37 In file included from /vllm-workspace/csrc/quantization/fp8/amd/../../../attention/attention_dtypes.h:3:
#20 37.37 /vllm-workspace/csrc/quantization/fp8/amd/../../../attention/attention_generic.cuh:26:8: error: redefinition of 'Vec'
#20 37.37 struct Vec {};
#20 37.37        ^
#20 37.37 /vllm-workspace/build/temp.linux-x86_64-cpython-39/csrc/attention/attention_generic.cuh:26:8: note: previous definition is here
#20 37.37 struct Vec {};
#20 37.37        ^

@comaniac comaniac force-pushed the e4m3-kv-cache branch 2 times, most recently from 54d0ee2 to b61b3b7 on May 7, 2024 05:29
@HaiShaw HaiShaw left a comment (Contributor)

I think it's better not to mix quant_utils with references to Fp8KVCacheDataType kv_dt in one file, as quant_utils could be used for other things as well (e.g., activation quantization); we should maintain its autonomy in that sense.

@HaiShaw
Contributor

HaiShaw commented May 7, 2024

@HaiShaw @AdrianAbeyta I keep seeing the following error when building this PR with ROCm. It seems like the same attention_generic.cuh headers in different places are included twice. Since I don't get this problem on nvcc, do you have any clue about how to resolve this for ROCm? Thanks.

[full ROCm build error log quoted above]

Can you rm -rf build and try again? I may give it a try soon.

@comaniac
Collaborator Author

comaniac commented May 7, 2024

I cannot do that since it's on the CI instead of my local workspace. The issue remains even after I removed the dtype_fp8.cuh header from quant_utils.cuh...
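(For context on the failure mode: the error above shows the same attention_generic.cuh contents being reached through two physical copies, one under csrc/ and one under the hipified build/temp tree. A classic named include guard, sketched below, deduplicates by macro name rather than by file identity the way #pragma once does; the template parameters are assumed for illustration, and this is only a sketch of the general mitigation, not necessarily the change that fixed this build.)

// Sketch only: csrc/attention/attention_generic.cuh with a named include
// guard, so a second, byte-identical copy included via a different path
// becomes a no-op.
#ifndef VLLM_ATTENTION_GENERIC_CUH_
#define VLLM_ATTENTION_GENERIC_CUH_

namespace vllm {

// Generic vector type, specialized elsewhere for concrete element types
// (template parameters assumed for illustration).
template <typename T, int VEC_SIZE>
struct Vec {};

}  // namespace vllm

#endif  // VLLM_ATTENTION_GENERIC_CUH_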

@comaniac
Collaborator Author

comaniac commented May 9, 2024

CI passed, so we should be good to go. For the comment about not mixing quant_utils with references to Fp8KVCacheDataType kv_dt in one file: we are also aware of this issue but haven't figured out a better solution yet, due to the strong coupling between quant_utils and the attention kernels. If needed, we could do another refactor to separate the attention kernels entirely.

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit c833101 into vllm-project:main May 10, 2024
55 checks passed