diff --git a/CMakeLists.txt b/CMakeLists.txt index d1956f3d409b..fc4ac10b7669 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,7 +195,6 @@ set(VLLM_EXT_SRC "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" - "csrc/moe_align_block_size_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") @@ -405,6 +404,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/topk_softmax_kernels.cu") set_gencode_flags_for_srcs( diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu similarity index 59% rename from csrc/moe_align_block_size_kernels.cu rename to csrc/moe/moe_align_sum_kernels.cu index 1f8d75da83bb..fff7ce34c838 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -1,15 +1,17 @@ #include #include +#include #include #include -#include "cuda_compat.h" -#include "dispatch_utils.h" +#include "../cuda_compat.h" +#include "../dispatch_utils.h" #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { +namespace moe { namespace { __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, @@ -32,10 +34,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, extern __shared__ int32_t shared_mem[]; int32_t* tokens_cnts = - shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) int32_t* cumsum = - shared_mem + (num_experts + 1) * - num_experts; // 1d tensor with shape (num_experts + 1) + shared_mem + + (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1) for (int i = 0; i < num_experts; ++i) { tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; @@ -53,10 +55,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, __syncthreads(); // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } } __syncthreads(); @@ -79,9 +83,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, * For each expert, each thread processes the tokens of the corresponding * blocks and stores the corresponding expert_id for each block. */ - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } } /** @@ -106,6 +112,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; } } + +template +__global__ void moe_sum_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., topk, d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + scalar_t x = 0.0; +#pragma unroll + for (int k = 0; k < TOPK; ++k) { + x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]); + } + out[token_idx * d + idx] = x; + } +} + +} // namespace moe } // namespace vllm void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, @@ -117,18 +141,62 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); const int32_t shared_mem = - ((num_experts + 1) * num_experts + (num_experts + 1)) * + ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); // set dynamic shared mem - auto kernel = vllm::moe_align_block_size_kernel; + auto kernel = vllm::moe::moe_align_block_size_kernel; AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( (void*)kernel, shared_mem)); - kernel<<<1, num_experts, shared_mem, stream>>>( + kernel<<<1, num_thread, shared_mem, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); }); } + +void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // [num_tokens, hidden_size] +{ + const int hidden_size = input.size(-1); + const int num_tokens = output.numel() / hidden_size; + const int topk = input.size(1); + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (topk) { + case 2: + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + case 3: + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + case 4: + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel<<>>( + output.data_ptr(), input.data_ptr(), + hidden_size); + }); + break; + + default: + at::sum_out(output, input, 1); + break; + } +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a251730aa765..596cc0aa6c85 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -5,3 +5,10 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, torch::Tensor& gating_output); + +void moe_sum(torch::Tensor& input, torch::Tensor& output); + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 019c6cedd3d8..f3a558c14ab9 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -8,6 +8,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + // Calculate the result of moe by summing up the partial results + // from all selected experts. + m.def("moe_sum(Tensor! input, Tensor output) -> ()"); + m.impl("moe_sum", torch::kCUDA, &moe_sum); + + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + #ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " diff --git a/csrc/ops.h b/csrc/ops.h index c10c34e08575..04be2c65b12a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -142,11 +142,6 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, c10::optional const& scale_ub); -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); - void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b999028fe06a..06aeb06d7e97 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -332,15 +332,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); - // Aligning the number of tokens to be processed by each expert such - // that it is divisible by the block size. - ops.def( - "moe_align_block_size(Tensor topk_ids, int num_experts," - " int block_size, Tensor! sorted_token_ids," - " Tensor! experts_ids," - " Tensor! num_tokens_post_pad) -> ()"); - ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index b87fbc3f1937..c0053071258e 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -19,7 +19,7 @@ marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.scalar_type import scalar_types -from vllm.utils import seed_everything +from vllm.utils import is_hip, seed_everything @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @@ -103,6 +103,7 @@ def test_mixtral_moe(dtype: torch.dtype): @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("is_k_full", [True, False]) +@pytest.mark.skipif(is_hip(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, n: int, @@ -255,6 +256,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("is_k_full", [True, False]) +@pytest.mark.skipif(is_hip(), reason="Skip for rocm") def test_single_marlin_moe_multiply( m: int, n: int, @@ -345,6 +347,6 @@ def test_moe_align_block_size_opcheck(): dtype=torch.int32, device=topk_ids.device) - opcheck(torch.ops._C.moe_align_block_size, + opcheck(torch.ops._moe_C.moe_align_block_size, (topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a25f7abca549..a6efb32745a7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -807,13 +807,17 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, # moe +def moe_sum(input: torch.Tensor, output: torch.Tensor): + torch.ops._moe_C.moe_sum(input, output) + + def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None: - torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) + torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b1d3bc0a5f05..90a4209b5bce 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -589,9 +589,8 @@ def fused_experts(hidden_states: torch.Tensor, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16) - torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states