54 changes: 27 additions & 27 deletions csrc/fused_qknorm_rope_kernel.cu
@@ -35,10 +35,12 @@
CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x)

#define FINAL_MASK 0xffffffff
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#else
#define FINAL_MASK 0xffffffff
#endif

// TODO: suport for AMD ROCM platform
#ifndef USE_ROCM
namespace tensorrt_llm::common {
template <typename T, int num>
struct packed_as;
@@ -60,7 +62,7 @@ struct packed_as<uint, 4> {

template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
@@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
int64_t const* position_ids, // Position IDs for RoPE
int const num_tokens // Number of tokens
) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
return;
} else {
#endif
#endif

using Converter = vllm::_typeConvert<scalar_t_in>;
static_assert(Converter::exists,
@@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
{
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Interpret the generic vector chunk as the specific packed type
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
@@ -200,7 +202,7 @@
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

// Normalize elements
#pragma unroll
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
int dim = laneId * numElemsPerThread + i;
float weight = isQ ? Converter::convert(q_weight[dim])
@@ -222,7 +224,7 @@

if constexpr (interleave) {
// Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
#pragma unroll
for (int i = 0; i < numElemsPerThread / 2; ++i) {
int const idx0 = 2 * i;
int const idx1 = 2 * i + 1;
@@ -245,9 +247,9 @@
__syncwarp();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
#pragma unroll
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
if (laneId < 16) {
Review comment on lines 251 to 253:

P1: Limit warp shuffle width to 32 on ROCm

When compiling for ROCm, __shfl_xor_sync defaults to a width equal to the hardware wavefront (64 lanes). The kernel logic assumes 32‑lane groups – warpsPerBlock, laneId, and the head layout all treat 32 threads as one warp – so the shuffle in the non‑interleaved branch must also be restricted to 32 lanes. Leaving it at the default 64 mixes data between two logical warps on AMD GPUs, corrupting the RoPE transformation for Neox models. Pass an explicit width of 32 (as is done in warpReduceSum) to keep the shuffle confined to the intended half‑warp.

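A minimal sketch of the change the reviewer is describing, reusing the kernel's own names (elements, elements2, numElemsPerThread, laneId, FINAL_MASK); the only difference from the loop in the diff is the explicit width argument of 32, mirroring warpReduceSum:

```cuda
// Sketch of the suggested change (not the merged diff): pin the shuffle width
// to 32 so ROCm's 64-lane wavefront is treated as two independent 32-lane
// logical warps, matching the width already passed in warpReduceSum.
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
  // XOR by 16 exchanges values with the other half of the same 32-lane group.
  elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16, /*width=*/32);
  if (laneId < 16) {
    elements2[i] = -elements2[i];
  }
}
```

Passing the width explicitly also documents the kernel's 32-lane-warp assumption at the call site.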

elements2[i] = -elements2[i];
}
@@ -269,7 +271,7 @@
{
vec_T vec;
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Convert from float2 back to the specific packed type
T2_in packed_val = Converter::convert(
@@ -280,21 +282,21 @@
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
}

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
#endif
}

// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}

template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRope(void* qkv, int const num_tokens,
@@ -413,6 +415,4 @@ void fused_qk_norm_rope(
stream);
});
});
}

#endif // not USE_ROCM
}
2 changes: 0 additions & 2 deletions csrc/torch_bindings.cpp
@@ -175,15 +175,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

#ifndef USE_ROCM
// Function for fused QK Norm and RoPE
ops.def(
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
"int num_heads_k, int num_heads_v, int head_dim, float eps, "
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
"bool is_neox, Tensor position_ids) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#endif

// Apply repetition penalties to logits in-place
ops.def(
7 changes: 4 additions & 3 deletions csrc/type_convert.cuh
@@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
}
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
// ROCm 7.0+ supports bfloat16
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
@@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
// defined(USE_ROCM)
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))

4 changes: 2 additions & 2 deletions tests/compile/test_qk_norm_rope_fusion.py
@@ -113,8 +113,8 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
@pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="Only test on cuda platform",
not current_platform.is_cuda_alike(),
reason="Only test on cuda and rocm platform",
)
def test_qk_norm_rope_fusion(
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
4 changes: 2 additions & 2 deletions tests/kernels/core/test_fused_qk_norm_rope.py
@@ -44,8 +44,8 @@ def _apply_qk_norm_rope(


@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="fused_qk_norm_rope custom op requires cuda platform",
not current_platform.is_cuda_alike(),
reason="fused_qk_norm_rope custom op requires cuda and rocm platform",
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("dtype", DTYPES)
4 changes: 2 additions & 2 deletions vllm/config/compilation.py
@@ -184,10 +184,10 @@ def __post_init__(self) -> None:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA. The fusion will be disabled."
"CUDA or ROCm. The fusion will be disabled."
)
self.enable_qk_norm_rope_fusion = False
