From 42d6d582c708d144c1df968e64bb7e6a79d4d2ba Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Mon, 24 Nov 2025 20:52:32 +0800
Subject: [PATCH] optimize

---
 ggml/src/ggml-cuda/solve_tri.cu | 58 +++++++++++++++++++--------------
 1 file changed, 34 insertions(+), 24 deletions(-)

diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 93b804fa717..b6369b6cc49 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -13,7 +13,7 @@ static __global__ void solve_tri_f32_fast(
     const float* __restrict__ A,
     const float* __restrict__ B,
     float* __restrict__ X,
-    const int64_t ne02,
+    const uint3 ne02,
     const size_t nb02, const size_t nb03,
     const size_t nb12, const size_t nb13,
     const size_t nb2, const size_t nb3) {
@@ -26,8 +26,9 @@ static __global__ void solve_tri_f32_fast(
         return;
     }
 
-    const int64_t i03 = batch_idx / ne02;
-    const int64_t i02 = batch_idx % ne02;
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
 
     const float* const A_batch = (const float*)((const char *)A + i02 * nb02 + i03 * nb03);
     const float* const B_batch = (const float*)((const char *)B + i02 * nb12 + i03 * nb13);
@@ -37,14 +38,19 @@ static __global__ void solve_tri_f32_fast(
     float* const X_batch = (float*)((char *)X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
     __shared__ float sX[MAX_N_FAST * MAX_K_FAST];
 
+    const int offset = threadIdx.x + threadIdx.y * blockDim.x;
     // Load A into shared memory (coalesced)
-    for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * n; i += blockDim.x * blockDim.y) {
-        sA[i] = A_batch[i];
+#pragma unroll
+    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
+        int i0 = i + offset;
+        sA[i0] = A_batch[i0];
     }
     // Load B into shared memory (coalesced)
-    for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * k; i += blockDim.x * blockDim.y) {
-        sX[i] = B_batch[i];
+#pragma unroll
+    for (int i = 0; i < n * k; i += k * WARP_SIZE) {
+        int i0 = i + threadIdx.x + threadIdx.y * blockDim.x;
+        sX[i0] = B_batch[i0];
     }
     __syncthreads();
@@ -74,8 +80,10 @@ static __global__ void solve_tri_f32_fast(
     }
 
     // Write results from shared memory to global memory (coalesced)
-    for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * k; i += blockDim.x * blockDim.y) {
-        X_batch[i] = sX[i];
+#pragma unroll
+    for (int i = 0; i < n * k; i += k * WARP_SIZE) {
+        const int i0 = i + threadIdx.x + threadIdx.y*blockDim.x;
+        X_batch[i0] = sX[i0];
     }
 }
 
@@ -83,7 +91,7 @@ static __global__ void solve_tri_f32_fast_general(
     const float* __restrict__ A,
     const float* __restrict__ B,
     float* __restrict__ X,
-    const int64_t ne02,
+    const uint3 ne02,
     const size_t nb02, const size_t nb03,
     const size_t nb12, const size_t nb13,
     const size_t nb2, const size_t nb3,
@@ -97,8 +105,9 @@ static __global__ void solve_tri_f32_fast_general(
         return;
     }
 
-    const int64_t i03 = batch_idx / ne02;
-    const int64_t i02 = batch_idx % ne02;
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
 
     const float* const A_batch = (const float*)((const char *)A + i02 * nb02 + i03 * nb03);
     const float* const B_batch = (const float*)((const char *)B + i02 * nb12 + i03 * nb13);
@@ -164,44 +173,45 @@ static void solve_tri_f32_cuda(
     cudaStream_t stream) {
     // n <= 64, k <= 32
+    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
     dim3 threads(WARP_SIZE, k);
     dim3 grid(ne02 * ne03);
 
     if (n == 64) {
         if (k == 32) {
             solve_tri_f32_fast<64, 32><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 16) {
             solve_tri_f32_fast<64, 16><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 14) {
             solve_tri_f32_fast<64, 14><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 12) {
             solve_tri_f32_fast<64, 12><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 10) {
             solve_tri_f32_fast<64, 10><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 8) {
             solve_tri_f32_fast<64, 8><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 6) {
             solve_tri_f32_fast<64, 6><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 4) {
             solve_tri_f32_fast<64, 4><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 2) {
             solve_tri_f32_fast<64, 2><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else if (k == 1) {
             solve_tri_f32_fast<64, 1><<<grid, threads, 0, stream>>>(
-                A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
+                A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
         } else {
-            solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+            solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
         }
     } else {
        // run general case
-        solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+        solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
     }
 }
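
The init_fastdiv_values()/fast_div_modulo() helpers the patch switches to are declared in ggml-cuda's shared header rather than in this file; the per-block batch index split (batch_idx / ne02 and % ne02) is replaced by a precomputed magic multiply. The standalone sketch below only illustrates that technique, assuming a Barrett-style magic-multiplier scheme; the *_sketch names and the small verification harness are hypothetical stand-ins, not the upstream implementation.

// Illustrative sketch of division by a runtime-invariant divisor.
// The real init_fastdiv_values()/fast_div_modulo() live in ggml-cuda's
// common header; everything suffixed _sketch here is a hypothetical stand-in.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Host: precompute (magic multiplier, shift, divisor) so the kernel can
// replace an integer div/mod by a mulhi, an add and a shift.
static uint3 init_fastdiv_values_sketch(uint32_t d) {
    uint32_t L = 0;                              // L = ceil(log2(d))
    while (L < 32 && (uint32_t(1) << L) < d) {
        L++;
    }
    const uint32_t mp = (uint32_t)(((uint64_t(1) << 32) * ((uint64_t(1) << L) - d)) / d + 1);
    return make_uint3(mp, L, d);                 // x = magic, y = shift, z = divisor
}

// Device: quotient in .x, remainder in .y, matching how the patch unpacks
// fast_div_modulo(batch_idx, ne02) into i03 (= .x) and i02 (= .y).
static __device__ __forceinline__ uint2 fast_div_modulo_sketch(uint32_t n, uint3 fd) {
    const uint32_t q = (__umulhi(n, fd.x) + n) >> fd.y;
    return make_uint2(q, n - q * fd.z);
}

// Tiny check against the plain '/' and '%' that the patch removes.
static __global__ void check(uint32_t n, uint3 fd, int * ok) {
    const uint2 qr = fast_div_modulo_sketch(n, fd);
    *ok = (qr.x == n / fd.z) && (qr.y == n % fd.z);
}

int main() {
    const uint32_t ne02 = 12;                    // example batch-dimension divisor
    const uint3 fd = init_fastdiv_values_sketch(ne02);
    int * ok;
    cudaMallocManaged(&ok, sizeof(int));
    check<<<1, 1>>>(1234567u, fd, ok);
    cudaDeviceSynchronize();
    printf("fastdiv matches / and %%: %s\n", *ok ? "yes" : "no");
    cudaFree(ok);
    return 0;
}

The payoff inside the kernels is that the per-block division and modulo become a __umulhi, an add, a shift and one multiply-subtract, with the constants computed once per launch on the host, as solve_tri_f32_cuda() does above with ne02_fd.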