From 34d4ea6ad3b5b07d54c99b0e858d1b0b8af723d5 Mon Sep 17 00:00:00 2001
From: charlifu
Date: Tue, 15 Oct 2024 17:53:22 +0000
Subject: [PATCH 1/6] add moe_sum and improve the thread number of moe align

Signed-off-by: charlifu
---
 CMakeLists.txt                                    |   2 +-
 .../moe_align_sum_kernels.cu}                     | 101 +++++++++++++++---
 csrc/moe/moe_ops.h                                |   7 ++
 csrc/moe/torch_bindings.cpp                       |  14 +++
 csrc/ops.h                                        |   5 -
 csrc/torch_bindings.cpp                           |   9 --
 vllm/_custom_ops.py                               |   5 +-
 .../layers/fused_moe/fused_moe.py                 |   5 +-
 8 files changed, 115 insertions(+), 33 deletions(-)
 rename csrc/{moe_align_block_size_kernels.cu => moe/moe_align_sum_kernels.cu} (58%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3a424ad7b110..fc4465089535 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -213,7 +213,6 @@ set(VLLM_EXT_SRC
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
-  "csrc/moe_align_block_size_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
   "csrc/torch_bindings.cpp")
 
@@ -423,6 +422,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
+  "csrc/moe/moe_align_sum_kernels.cu"
   "csrc/moe/topk_softmax_kernels.cu")
 
 set_gencode_flags_for_srcs(
diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
similarity index 58%
rename from csrc/moe_align_block_size_kernels.cu
rename to csrc/moe/moe_align_sum_kernels.cu
index 1f8d75da83bb..54007888eeea 100644
--- a/csrc/moe_align_block_size_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -1,15 +1,17 @@
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>
 
-#include "cuda_compat.h"
-#include "dispatch_utils.h"
+#include "../cuda_compat.h"
+#include "../dispatch_utils.h"
 
 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
 
 namespace vllm {
+namespace moe {
 
 namespace {
 __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
@@ -32,9 +34,9 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   extern __shared__ int32_t shared_mem[];
 
   int32_t* tokens_cnts =
-      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
+      shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
   int32_t* cumsum =
-      shared_mem + (num_experts + 1) *
+      shared_mem + (blockDim.x + 1) *
                        num_experts;  // 1d tensor with shape (num_experts + 1)
 
   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -53,10 +55,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
-  for (int i = 1; i <= blockDim.x; ++i) {
-    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
-        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+    for (int i = 1; i <= blockDim.x; ++i) {
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+    }
   }
 
   __syncthreads();
@@ -79,9 +83,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
-       i += block_size) {
-    expert_ids[i / block_size] = threadIdx.x;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
+    }
   }
 
   /**
@@ -106,6 +112,25 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
   }
 }
+
+template <typename scalar_t, int TOPK>
+__global__ void moe_sum_kernel(
+    scalar_t* __restrict__ out, // [..., d]
+    const scalar_t* __restrict__ input, // [..., topk, d]
+    const int d)
+{
+    const int64_t token_idx = blockIdx.x;
+    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+        scalar_t x = 0.0;
+        #pragma unroll
+        for (int k = 0; k < TOPK; ++k) {
+            x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
+        }
+        out[token_idx * d + idx] = x;
+    }
+}
+
+}  // namespace moe
 }  // namespace vllm
 
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
@@ -117,18 +142,66 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
         // tensors
+        const int32_t num_thread = WARP_SIZE;
         const int32_t shared_mem =
-            ((num_experts + 1) * num_experts + (num_experts + 1)) *
+            ((num_thread + 1) * num_experts + (num_experts + 1)) *
             sizeof(int32_t);
 
         // set dynamic shared mem
-        auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
+        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
         AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
             (void*)kernel, shared_mem));
-        kernel<<<1, num_experts, shared_mem, stream>>>(
+        kernel<<<1, num_thread, shared_mem, stream>>>(
             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
             experts_ids.data_ptr<int32_t>(),
             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
             topk_ids.numel());
       });
 }
+
+void moe_sum(
+    torch::Tensor& input, // [num_tokens, topk, hidden_size]
+    torch::Tensor& output) // [num_tokens, hidden_size]
+{
+    const int hidden_size = input.size(-1);
+    const int num_tokens = output.numel() / hidden_size;
+    const int topk = input.size(1);
+
+    dim3 grid(num_tokens);
+    dim3 block(std::min(hidden_size, 1024));
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    switch (topk) {
+    case 2:
+        VLLM_DISPATCH_FLOATING_TYPES(
+            input.scalar_type(), "moe_sum_kernel", [&] {
+                vllm::moe::moe_sum_kernel<scalar_t, 2>
+                    <<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(),
+                        input.data_ptr<scalar_t>(), hidden_size);
+            });
+        break;
+
+    case 3:
+        VLLM_DISPATCH_FLOATING_TYPES(
+            input.scalar_type(), "moe_sum_kernel", [&] {
+                vllm::moe::moe_sum_kernel<scalar_t, 3>
+                    <<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(),
+                        input.data_ptr<scalar_t>(), hidden_size);
+            });
+        break;
+
+    case 4:
+        VLLM_DISPATCH_FLOATING_TYPES(
+            input.scalar_type(), "moe_sum_kernel", [&] {
+                vllm::moe::moe_sum_kernel<scalar_t, 4>
+                    <<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(),
+                        input.data_ptr<scalar_t>(), hidden_size);
+            });
+        break;
+
+    default:
+        at::sum_out(output, input, 1);
+        break;
+    }
+}
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index a251730aa765..596cc0aa6c85 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -5,3 +5,10 @@
 void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                   torch::Tensor& token_expert_indices,
                   torch::Tensor& gating_output);
+
+void moe_sum(torch::Tensor& input, torch::Tensor& output);
+
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
+                          int64_t block_size, torch::Tensor sorted_token_ids,
+                          torch::Tensor experts_ids,
+                          torch::Tensor num_tokens_post_pad);
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 18fbc57ac783..c2f79d5d232b 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -7,6 +7,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
       "token_expert_indices, Tensor gating_output) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+
+  // Calculate the result of moe by summing up the partial results
+  // from all selected experts.
+  m.def("moe_sum(Tensor! input, Tensor output) -> ()");
+  m.impl("moe_sum", torch::kCUDA, &moe_sum);
+
+  // Aligning the number of tokens to be processed by each expert such
+  // that it is divisible by the block size.
+  m.def(
+      "moe_align_block_size(Tensor topk_ids, int num_experts,"
+      " int block_size, Tensor! sorted_token_ids,"
+      " Tensor! experts_ids,"
+      " Tensor! num_tokens_post_pad) -> ()");
+  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
 #ifndef USE_ROCM
   m.def(
diff --git a/csrc/ops.h b/csrc/ops.h
index fce545f95a7c..da69827e4f2b 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -142,11 +142,6 @@ void dynamic_per_token_scaled_fp8_quant(
     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
     c10::optional<torch::Tensor> const& scale_ub);
 
-void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
-                          int64_t block_size, torch::Tensor sorted_token_ids,
-                          torch::Tensor experts_ids,
-                          torch::Tensor num_tokens_post_pad);
-
 void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                         const torch::Tensor& A, const torch::Tensor& B,
                         const torch::Tensor& C,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index a0100b4a85ed..796bf1473fb0 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -334,15 +334,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
 
-  // Aligning the number of tokens to be processed by each expert such
-  // that it is divisible by the block size.
-  ops.def(
-      "moe_align_block_size(Tensor topk_ids, int num_experts,"
-      " int block_size, Tensor! sorted_token_ids,"
-      " Tensor! experts_ids,"
-      " Tensor! num_tokens_post_pad) -> ()");
-  ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
-
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor!
out, Tensor input, Tensor scale," diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3a23692285ef..9c913daae5c9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -830,11 +830,14 @@ def selective_scan_fwd( # moe +def moe_sum(input: torch.Tensor, output: torch.Tensor): + torch.ops._moe_C.moe_sum(input, output) + def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None: - torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size, + torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids, num_tokens_post_pad) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b1d3bc0a5f05..da4225fe9898 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -589,9 +589,8 @@ def fused_experts(hidden_states: torch.Tensor, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16) - torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states From 2f0fef9c43e8c83ac8ff98e66139e521263510ef Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 15 Oct 2024 18:14:50 +0000 Subject: [PATCH 2/6] linting Signed-off-by: charlifu --- csrc/moe/torch_bindings.cpp | 4 ++-- vllm/_custom_ops.py | 5 +++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index c2f79d5d232b..b9bb2291d544 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -7,12 +7,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); - + // Calculate the result of moe by summing up the partial results // from all selected experts. m.def("moe_sum(Tensor! input, Tensor output) -> ()"); m.impl("moe_sum", torch::kCUDA, &moe_sum); - + // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. 
   m.def(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 9c913daae5c9..dd3962182397 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -833,13 +833,14 @@ def selective_scan_fwd(
 def moe_sum(input: torch.Tensor, output: torch.Tensor):
     torch.ops._moe_C.moe_sum(input, output)
 
+
 def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                          block_size: int, sorted_token_ids: torch.Tensor,
                          experts_ids: torch.Tensor,
                          num_tokens_post_pad: torch.Tensor) -> None:
     torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
-                                      sorted_token_ids, experts_ids,
-                                      num_tokens_post_pad)
+                                          sorted_token_ids, experts_ids,
+                                          num_tokens_post_pad)
 
 
 def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index da4225fe9898..90a4209b5bce 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -590,7 +590,7 @@ def fused_experts(hidden_states: torch.Tensor,
                             use_int8_w8a16=use_int8_w8a16)
 
         ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                  out_hidden_states[begin_chunk_idx:end_chunk_idx])
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx])
 
     return out_hidden_states
 
From 7bd636f1d07986a0485de1f7638e475975df39ce Mon Sep 17 00:00:00 2001
From: charlifu
Date: Tue, 15 Oct 2024 18:25:42 +0000
Subject: [PATCH 3/6] linting

Signed-off-by: charlifu
---
 csrc/moe/moe_align_sum_kernels.cu | 99 +++++++++++++++----------------
 1 file changed, 47 insertions(+), 52 deletions(-)

diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index 54007888eeea..25f4220cfd3c 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -36,8 +36,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   int32_t* tokens_cnts =
       shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
   int32_t* cumsum =
-      shared_mem + (blockDim.x + 1) *
-                       num_experts;  // 1d tensor with shape (num_experts + 1)
+      shared_mem +
+      (blockDim.x + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)
 
   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -115,19 +115,18 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
 
 template <typename scalar_t, int TOPK>
 __global__ void moe_sum_kernel(
-    scalar_t* __restrict__ out, // [..., d]
-    const scalar_t* __restrict__ input, // [..., topk, d]
-    const int d)
-{
-    const int64_t token_idx = blockIdx.x;
-    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-        scalar_t x = 0.0;
-        #pragma unroll
-        for (int k = 0; k < TOPK; ++k) {
-            x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
-        }
-        out[token_idx * d + idx] = x;
-    }
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., topk, d]
+    const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    scalar_t x = 0.0;
+#pragma unroll
+    for (int k = 0; k < TOPK; ++k) {
+      x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
+    }
+    out[token_idx * d + idx] = x;
+  }
 }
 
 }  // namespace moe
@@ -159,49 +158,45 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   });
 }
 
-void moe_sum(
-    torch::Tensor& input, // [num_tokens, topk, hidden_size]
-    torch::Tensor& output) // [num_tokens, hidden_size]
+void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
+             torch::Tensor& output)  // [num_tokens, hidden_size]
 {
-    const int hidden_size = input.size(-1);
-    const int num_tokens = output.numel() / hidden_size;
-    const int topk = input.size(1);
-
-    dim3 grid(num_tokens);
-    dim3 block(std::min(hidden_size, 1024));
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
-    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-    switch (topk) {
+  const int hidden_size = input.size(-1);
+  const int num_tokens = output.numel() / hidden_size;
+  const int topk = input.size(1);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  switch (topk) {
     case 2:
-        VLLM_DISPATCH_FLOATING_TYPES(
-            input.scalar_type(), "moe_sum_kernel", [&] {
-                vllm::moe::moe_sum_kernel<scalar_t, 2>
-                    <<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(),
-                        input.data_ptr<scalar_t>(), hidden_size);
-            });
-        break;
+      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
+        vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
+            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+            hidden_size);
+      });
+      break;
 
     case 3:
-        VLLM_DISPATCH_FLOATING_TYPES(
-            input.scalar_type(), "moe_sum_kernel", [&] {
-                vllm::moe::moe_sum_kernel<scalar_t, 3>
-                    <<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(),
-                        input.data_ptr<scalar_t>(), hidden_size);
-            });
-        break;
+      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
+        vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
+            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+            hidden_size);
+      });
+      break;
 
     case 4:
-        VLLM_DISPATCH_FLOATING_TYPES(
-            input.scalar_type(), "moe_sum_kernel", [&] {
-                vllm::moe::moe_sum_kernel<scalar_t, 4>
-                    <<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(),
-                        input.data_ptr<scalar_t>(), hidden_size);
-            });
-        break;
+      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
+        vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
+            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+            hidden_size);
+      });
+      break;
 
     default:
-        at::sum_out(output, input, 1);
-        break;
-    }
+      at::sum_out(output, input, 1);
+      break;
+  }
 }

From cf2c9900bf4896b67b50e4d10a3668a7ed16bce2 Mon Sep 17 00:00:00 2001
From: charlifu
Date: Mon, 21 Oct 2024 18:44:34 +0000
Subject: [PATCH 4/6] fix unit test

Signed-off-by: charlifu
---
 tests/kernels/test_moe.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index b73c45b9cd19..c5f5502e24e0 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -19,7 +19,7 @@
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.scalar_type import scalar_types
-from vllm.utils import seed_everything
+from vllm.utils import is_hip, seed_everything
 
 
 @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@@ -103,6 +103,7 @@ def test_mixtral_moe(dtype: torch.dtype):
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
 @pytest.mark.parametrize("is_k_full", [True, False])
+@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
 def test_fused_marlin_moe(
     m: int,
     n: int,
@@ -255,6 +256,7 @@ def test_fused_marlin_moe(
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
 @pytest.mark.parametrize("is_k_full", [True, False])
+@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
 def test_single_marlin_moe_multiply(
     m: int,
     n: int,
@@ -345,6 +347,6 @@ def test_moe_align_block_size_opcheck():
                              dtype=torch.int32,
                              device=topk_ids.device)
 
-    opcheck(torch.ops._C.moe_align_block_size,
+    opcheck(torch.ops._moe_C.moe_align_block_size,
             (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad)) From 4e5bd4952e9368ec3cdcacbc7ceba448e829873d Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 24 Oct 2024 19:43:35 +0000 Subject: [PATCH 5/6] fix num_expert 64 Signed-off-by: charlifu --- csrc/moe/moe_align_sum_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 25f4220cfd3c..1eada59d6f88 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -141,7 +141,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors - const int32_t num_thread = WARP_SIZE; + const int32_t num_thread = max(num_experts, WARP_SIZE); const int32_t shared_mem = ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); From 04af812941f2ebb4e5efe55d7c870670e17755a1 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 24 Oct 2024 19:45:56 +0000 Subject: [PATCH 6/6] cast to in32 Signed-off-by: charlifu --- csrc/moe/moe_align_sum_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 1eada59d6f88..fff7ce34c838 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -141,7 +141,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors - const int32_t num_thread = max(num_experts, WARP_SIZE); + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); const int32_t shared_mem = ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);