Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions ggml/src/ggml-cuda/solve_tri.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ static __global__ void solve_tri_f32_fast(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ X,
const int64_t ne02,
const uint3 ne02,
const size_t nb02, const size_t nb03,
const size_t nb12, const size_t nb13,
const size_t nb2, const size_t nb3) {
Expand All @@ -26,8 +26,9 @@ static __global__ void solve_tri_f32_fast(
return;
}

const int64_t i03 = batch_idx / ne02;
const int64_t i02 = batch_idx % ne02;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;

const float* const A_batch = (const float*)((const char *)A + i02 * nb02 + i03 * nb03);
const float* const B_batch = (const float*)((const char *)B + i02 * nb12 + i03 * nb13);
Expand All @@ -37,14 +38,19 @@ static __global__ void solve_tri_f32_fast(
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sX[MAX_N_FAST * MAX_K_FAST];

const int offset = threadIdx.x + threadIdx.y * blockDim.x;
// Load A into shared memory (coalesced)
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * n; i += blockDim.x * blockDim.y) {
sA[i] = A_batch[i];
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
int i0 = i + offset;
sA[i0] = A_batch[i0];
}

// Load B into shared memory (coalesced)
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * k; i += blockDim.x * blockDim.y) {
sX[i] = B_batch[i];
#pragma unroll
for (int i = 0; i < n * k; i += k * WARP_SIZE) {
int i0 = i + threadIdx.x + threadIdx.y * blockDim.x;
sX[i0] = B_batch[i0];
}
__syncthreads();

Expand Down Expand Up @@ -74,16 +80,18 @@ static __global__ void solve_tri_f32_fast(
}

// Write results from shared memory to global memory (coalesced)
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < n * k; i += blockDim.x * blockDim.y) {
X_batch[i] = sX[i];
#pragma unroll
for (int i = 0; i < n * k; i += k * WARP_SIZE) {
const int i0 = i + threadIdx.x + threadIdx.y*blockDim.x;
X_batch[i0] = sX[i0];
}
}

static __global__ void solve_tri_f32_fast_general(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ X,
const int64_t ne02,
const uint3 ne02,
const size_t nb02, const size_t nb03,
const size_t nb12, const size_t nb13,
const size_t nb2, const size_t nb3,
Expand All @@ -97,8 +105,9 @@ static __global__ void solve_tri_f32_fast_general(
return;
}

const int64_t i03 = batch_idx / ne02;
const int64_t i02 = batch_idx % ne02;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;

const float* const A_batch = (const float*)((const char *)A + i02 * nb02 + i03 * nb03);
const float* const B_batch = (const float*)((const char *)B + i02 * nb12 + i03 * nb13);
Expand Down Expand Up @@ -164,44 +173,45 @@ static void solve_tri_f32_cuda(
cudaStream_t stream)
{
// n <= 64, k <= 32
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
dim3 threads(WARP_SIZE, k);
dim3 grid(ne02 * ne03);
if (n == 64) {
if (k == 32) {
solve_tri_f32_fast<64, 32><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 16) {
solve_tri_f32_fast<64, 16><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 14) {
solve_tri_f32_fast<64, 14><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 12) {
solve_tri_f32_fast<64, 12><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 10) {
solve_tri_f32_fast<64, 10><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 8) {
solve_tri_f32_fast<64, 8><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 6) {
solve_tri_f32_fast<64, 6><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 4) {
solve_tri_f32_fast<64, 4><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 2) {
solve_tri_f32_fast<64, 2><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else if (k == 1) {
solve_tri_f32_fast<64, 1><<<grid, threads, 0, stream>>>(
A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3);
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
} else {
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else { // run general case
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
solve_tri_f32_fast_general<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}

Expand Down
Loading