Skip to content

Commit 76cee36

Browse files
authored
[CUDA] Fix build for sm<53 (#24582)
### Description There is a build error for `--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=52`. Some half2 functions like `__hfma2`, used in MatMul 8 bits, are not defined for sm < 53. Add an implementation that does not use half2 for those old GPUs. Also fix another build error with CUDA 12.5, caused by an extra `const` in the MOE code path for sm < 53. ### Motivation and Context Fix the nuget packaging pipeline, which uses `CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual`.
1 parent f7619dc commit 76cee36

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ void initialize_moe_routing_kernelLauncher(const T *unpermuted_input, T *permute
10411041
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
10421042
template <typename T, int RESIDUAL_NUM>
10431043
__global__ void finalize_moe_routing_kernel(const T *, T *, const T *, const T *, const T *, const T *, const int *,
1044-
const int *, int, const int) {
1044+
const int *, int, int) {
10451045
// Does not support pre-Kepler architectures
10461046
;
10471047
}
@@ -1168,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const
11681168
template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *,
11691169
const half *, const int *, const int *, int, int, int, cudaStream_t);
11701170

1171-
} // namespace ort_fastertransformer
1171+
} // namespace ort_fastertransformer

onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ __device__ __forceinline__ void AccumulateEightElements8b(
3232
const half* a, // Pointer to 8 half values from A
3333
half* sums) { // Pointer to 8 partial sums (half)
3434

35+
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530)
3536
// --- Dequantization Setup ---
3637
half2 scale_h2 = __half2half2(scale); // Broadcast scale
3738
half zp_h = __ushort2half_rn(zp); // Convert zp to half
@@ -74,6 +75,27 @@ __device__ __forceinline__ void AccumulateEightElements8b(
7475
sums_half2[1] = __hfma2(a_vec1, b_vec1, sums_half2[1]); // {s2+=a2*b2, s3+=a3*b3}
7576
sums_half2[2] = __hfma2(a_vec2, b_vec2, sums_half2[2]); // {s4+=a4*b4, s5+=a5*b5}
7677
sums_half2[3] = __hfma2(a_vec3, b_vec3, sums_half2[3]); // {s6+=a6*b6, s7+=a7*b7}
78+
79+
#else // older GPUs of compute capability < 5.3, which lacks native half support.
80+
float scale_f = __half2float(scale);
81+
float zp_f = static_cast<float>(zp);
82+
83+
float b_dequant[8];
84+
#pragma unroll
85+
for (int i = 0; i < 8; ++i) {
86+
uint8_t q = (values_quant >> (i * 8)) & 0xFF;
87+
b_dequant[i] = (static_cast<float>(q) - zp_f) * scale_f;
88+
}
89+
90+
#pragma unroll
91+
for (int i = 0; i < 8; ++i) {
92+
float a_f = __half2float(a[i]);
93+
float product_f = a_f * b_dequant[i];
94+
// Convert back to half for partial sums. It is not ideal for performance.
95+
half product_h = __float2half_rn(product_f);
96+
sums[i] += product_h;
97+
}
98+
#endif
7799
}
78100

79101
// --- Device Function: Accumulate 8 Elements (float precision) ---

0 commit comments

Comments
 (0)