Skip to content
20 changes: 14 additions & 6 deletions csrc/moe/grouped_topk_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,15 @@ __device__ inline T neg_inf() {
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
}

template <typename T>
__device__ inline bool is_finite(const T val) {
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
return cuda::std::isfinite(val);
#else
return isfinite(cuda_cast<float, T>(val));
#endif
}

template <typename T>
__device__ void topk_with_k2(T* output, T const* input,
cg::thread_block_tile<32> const& tile,
Expand Down Expand Up @@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel(
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
// The check is necessary to avoid abnormal input
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
value = group_scores[lane_id];
}

Expand Down Expand Up @@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates =
(i < num_experts_per_group) &&
cuda::std::isfinite(scores_with_bias[offset + i])
? scores_with_bias[offset + i]
: neg_inf<T>();
T candidates = (i < num_experts_per_group) &&
is_finite(scores_with_bias[offset + i])
? scores_with_bias[offset + i]
: neg_inf<T>();
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
Expand Down