diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index b5321f748e6b..c93f9d54d780 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -418,6 +418,15 @@ __device__ inline T neg_inf() { return cuda_cast(-cuda::std::numeric_limits::infinity()); } +template +__device__ inline bool is_finite(const T val) { +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) + return cuda::std::isfinite(val); +#else + return isfinite(cuda_cast(val)); +#endif +} + template __device__ void topk_with_k2(T* output, T const* input, cg::thread_block_tile<32> const& tile, @@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel( // calculate group_idx int32_t target_num_min = WARP_SIZE - n_group + topk_group; // The check is necessary to avoid abnormal input - if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) { + if (lane_id < n_group && is_finite(group_scores[lane_id])) { value = group_scores[lane_id]; } @@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel( int32_t offset = i_group * num_experts_per_group; for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) { - T candidates = - (i < num_experts_per_group) && - cuda::std::isfinite(scores_with_bias[offset + i]) - ? scores_with_bias[offset + i] - : neg_inf(); + T candidates = (i < num_experts_per_group) && + is_finite(scores_with_bias[offset + i]) + ? scores_with_bias[offset + i] + : neg_inf(); queue.add(candidates, offset + i); } if (group_scores[i_group] == topk_group_value) {