diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu
index cbd23975a773..83017250ebcd 100644
--- a/csrc/fused_qknorm_rope_kernel.cu
+++ b/csrc/fused_qknorm_rope_kernel.cu
@@ -35,10 +35,12 @@
   CHECK_TH_CUDA(x); \
   CHECK_CONTIGUOUS(x)
 
-#define FINAL_MASK 0xffffffff
+#ifdef USE_ROCM
+  #define FINAL_MASK 0xffffffffffffffffULL
+#else
+  #define FINAL_MASK 0xffffffff
+#endif
 
-// TODO: suport for AMD ROCM platform
-#ifndef USE_ROCM
 namespace tensorrt_llm::common {
 template
 struct packed_as;
@@ -60,7 +62,7 @@ struct packed_as {
 
 template
 __inline__ __device__ T warpReduceSum(T val) {
-  #pragma unroll
+#pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1)
     val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
   return val;
@@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
     int64_t const* position_ids,  // Position IDs for RoPE
     int const num_tokens          // Number of tokens
 ) {
-  #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   if constexpr ((std::is_same_v) ||
                 std::is_same_v) {
     return;
   } else {
-  #endif
+#endif
 
     using Converter = vllm::_typeConvert;
     static_assert(Converter::exists,
@@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec = *reinterpret_cast(&qkv[offsetThread]);
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Interpret the generic vector chunk as the specific packed type
       T2_in packed_val = *(reinterpret_cast(&vec) + i);
@@ -200,7 +202,7 @@ __global__ void fusedQKNormRopeKernel(
   float rms_rcp = rsqrtf(sumOfSquares / static_cast(head_dim) + eps);
 
   // Normalize elements
-  #pragma unroll
+#pragma unroll
   for (int i = 0; i < numElemsPerThread; i++) {
     int dim = laneId * numElemsPerThread + i;
     float weight = isQ ? Converter::convert(q_weight[dim])
@@ -222,7 +224,7 @@ __global__ void fusedQKNormRopeKernel(
 
   if constexpr (interleave) {
     // Perform interleaving. Use pre-computed cos/sin values.
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < numElemsPerThread / 2; ++i) {
       int const idx0 = 2 * i;
       int const idx1 = 2 * i + 1;
@@ -245,9 +247,9 @@ __global__ void fusedQKNormRopeKernel(
     __syncwarp();
     // Get the data from the other half of the warp. Use pre-computed cos/sin
    // values.
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < numElemsPerThread; i++) {
-      elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
       if (laneId < 16) {
         elements2[i] = -elements2[i];
       }
@@ -269,7 +271,7 @@ __global__ void fusedQKNormRopeKernel(
   {
     vec_T vec;
     constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
-  #pragma unroll
+#pragma unroll
     for (int i = 0; i < num_packed_elems; i++) {
       // Convert from float2 back to the specific packed type
       T2_in packed_val = Converter::convert(
@@ -280,21 +282,21 @@ __global__ void fusedQKNormRopeKernel(
     *reinterpret_cast(&qkv[offsetThread]) = vec;
   }
 
-  #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   }
-  #endif
+#endif
 }
 
-  // Borrowed from
-  // https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
-  #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
-    if (interleave) {                                      \
-      const bool INTERLEAVE = true;                        \
-      __VA_ARGS__                                          \
-    } else {                                               \
-      const bool INTERLEAVE = false;                       \
-      __VA_ARGS__                                          \
-    }
+// Borrowed from
+// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
+#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
+  if (interleave) {                                      \
+    const bool INTERLEAVE = true;                        \
+    __VA_ARGS__                                          \
+  } else {                                               \
+    const bool INTERLEAVE = false;                       \
+    __VA_ARGS__                                          \
+  }
 
 template
 void launchFusedQKNormRope(void* qkv, int const num_tokens,
@@ -413,6 +415,4 @@ void fused_qk_norm_rope(
         stream);
       });
     });
-}
-
-#endif // not USE_ROCM
\ No newline at end of file
+}
\ No newline at end of file
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index d4a69cbe7971..c3ae06a30e3e 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
-#ifndef USE_ROCM
   // Function for fused QK Norm and RoPE
   ops.def(
       "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
@@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
       "bool is_neox, Tensor position_ids) -> ()");
   ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
-#endif
 
   // Apply repetition penalties to logits in-place
   ops.def(
diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh
index 6da06f1e66cf..2678f69e19b6 100644
--- a/csrc/type_convert.cuh
+++ b/csrc/type_convert.cuh
@@ -67,9 +67,9 @@ struct _typeConvert {
   }
 };
 
-  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
 // CUDA_ARCH < 800 does not have BF16 support
-// TODO: Add in ROCm support once public headers handle bf16 maturely
+// ROCm 7.0+ supports bfloat16
 template <>
 struct _typeConvert {
   static constexpr bool exists = true;
@@ -89,7 +89,8 @@ struct _typeConvert {
     return __float22bfloat162_rn(x);
   }
 };
 
-  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  #endif  // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
+          // defined(USE_ROCM)
 #endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
         // 12000))
diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py
index 973123a3af92..511e50f5fdc2 100644
--- a/tests/compile/test_qk_norm_rope_fusion.py
+++ b/tests/compile/test_qk_norm_rope_fusion.py
@@ -113,8 +113,8 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
 @pytest.mark.parametrize("enable_rope_custom_op", [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.skipif(
-    not current_platform.is_cuda(),
-    reason="Only test on cuda platform",
+    not current_platform.is_cuda_alike(),
+    reason="Only test on cuda and rocm platform",
 )
 def test_qk_norm_rope_fusion(
     eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
diff --git a/tests/kernels/core/test_fused_qk_norm_rope.py b/tests/kernels/core/test_fused_qk_norm_rope.py
index 88bb7691ec3b..a23959e353da 100644
--- a/tests/kernels/core/test_fused_qk_norm_rope.py
+++ b/tests/kernels/core/test_fused_qk_norm_rope.py
@@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
 
 
 @pytest.mark.skipif(
-    not current_platform.is_cuda(),
-    reason="fused_qk_norm_rope custom op requires cuda platform",
+    not current_platform.is_cuda_alike(),
+    reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
 )
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("dtype", DTYPES)
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 9c9557df4e73..d098924f0db5 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -184,10 +184,10 @@ def __post_init__(self) -> None:
                 "Fusion enabled but reshape elimination disabled. "
                 "Allreduce + rms norm + quant (fp8) fusion might not work"
             )
-        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
+        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
             logger.warning_once(
                 "QK Norm + RoPE fusion enabled but the current platform is not "
-                "CUDA. The fusion will be disabled."
+                "CUDA or ROCm. The fusion will be disabled."
            )
             self.enable_qk_norm_rope_fusion = False
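
For context on the FINAL_MASK change: the patch widens the full-lane shuffle mask to 64 bits under `USE_ROCM`, since AMD wavefronts are typically 64 lanes while CUDA warps are 32. Below is a minimal standalone sketch (not part of the patch) that mirrors the macro and `warpReduceSum` from the kernel above; the host side uses the CUDA runtime API and would build under ROCm via hipcc after the usual header swap.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Full-lane shuffle mask, mirroring the patch: 64-bit on ROCm (64-lane
// wavefronts), 32-bit on CUDA (32-lane warps).
#ifdef USE_ROCM
  #define FINAL_MASK 0xffffffffffffffffULL
#else
  #define FINAL_MASK 0xffffffff
#endif

template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    // width = 32 keeps the butterfly reduction within a 32-lane group on
    // both platforms, matching the kernel in the patch.
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  return val;
}

__global__ void warpSumKernel(float const* in, float* out) {
  float sum = warpReduceSum(in[threadIdx.x]);
  if (threadIdx.x == 0) *out = sum;  // lane 0 holds the 32-element sum
}

int main() {
  float h_in[32], h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warpSumKernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);  // expected: 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```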
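
The `DISPATCH_INTERLEAVE` macro (borrowed from FlashInfer, quoted in the patch) maps a runtime flag onto a compile-time boolean so both kernel instantiations are generated and the right one is chosen at runtime. The sketch below shows the pattern with a hypothetical `launchDemo` stand-in; it is not vLLM's actual launch code.

```cuda
#include <cstdio>

// Same macro as in the patch.
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
  if (interleave) {                                      \
    const bool INTERLEAVE = true;                        \
    __VA_ARGS__                                          \
  } else {                                               \
    const bool INTERLEAVE = false;                       \
    __VA_ARGS__                                          \
  }

// Stand-in for a kernel launch that needs the flag as a template argument.
template <bool INTERLEAVE>
void launchDemo() {
  printf(INTERLEAVE ? "interleaved rotary layout\n"
                    : "non-interleaved rotary layout\n");
}

int main() {
  bool interleave = true;  // runtime value, e.g. derived from is_neox
  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { launchDemo<INTERLEAVE>(); })
  return 0;
}
```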