Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 174 additions & 13 deletions ggml/src/ggml-cpu/llamafile/sgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2673,7 +2673,6 @@ class tinyBLAS_PPC {

private:

void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);

template<typename VA>
void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
Expand Down Expand Up @@ -2920,7 +2919,72 @@ class tinyBLAS_PPC {
}
}

void KERNEL_4x4(int64_t ii, int64_t jj) {
void KERNEL_4x4(vec_t* vec_A, vec_t* vec_B, acc_t* acc){
for (int l = 0; l < k; l+=4) {
__builtin_mma_xvf32gerpp(&acc[0], vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc[0], vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc[0], vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc[0], vec_A[3], vec_B[3]);
}
}

void KERNEL_4x8(vec_t* vec_A, vec_t* vec_B, acc_t* acc ) {
__builtin_mma_xvf32gerpp(&acc[0], vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc[1], vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc[0], vec_A[1], (vec_t)vec_B[2]);
__builtin_mma_xvf32gerpp(&acc[1], vec_A[1], (vec_t)vec_B[3]);
__builtin_mma_xvf32gerpp(&acc[0], vec_A[2], (vec_t)vec_B[4]);
__builtin_mma_xvf32gerpp(&acc[1], vec_A[2], (vec_t)vec_B[5]);
__builtin_mma_xvf32gerpp(&acc[0], vec_A[3], (vec_t)vec_B[6]);
__builtin_mma_xvf32gerpp(&acc[1], vec_A[3], (vec_t)vec_B[7]);

}

void KERNEL_8x4(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
__builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[2], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[3], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[4], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[5], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[6], vec_B[3]);
__builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[7], vec_B[3]);
}

void KERNEL_8x8(vec_t * vec_A, vec_t* vec_B, acc_t* acc) {
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x], vec_B[x+1]);
__builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x+1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x+1], vec_B[x+1]);
}
}
void KERNEL_8x16(vec_t * vec_A, vec_t* vec_B, acc_t* acc) {
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp (&acc[0], (vec_t)vec_A[x], (vec_t)vec_B[x]);
__builtin_mma_xvf32gerpp (&acc[1], (vec_t)vec_A[x], (vec_t)vec_B[x+1]);
__builtin_mma_xvf32gerpp (&acc[2], (vec_t)vec_A[x], (vec_t)vec_B[x+16]);
__builtin_mma_xvf32gerpp (&acc[3], (vec_t)vec_A[x], (vec_t)vec_B[x+17]);
__builtin_mma_xvf32gerpp (&acc[4], (vec_t)vec_A[x+1], (vec_t)vec_B[x]);
__builtin_mma_xvf32gerpp (&acc[5], (vec_t)vec_A[x+1], (vec_t)vec_B[x+1]);
__builtin_mma_xvf32gerpp (&acc[6], (vec_t)vec_A[x+1], (vec_t)vec_B[x+16]);
__builtin_mma_xvf32gerpp (&acc[7], (vec_t)vec_A[x+1], (vec_t)vec_B[x+17]);
}
}

void KERNEL_16x8(vec_t * vec_A, vec_t* vec_B, acc_t* acc) {
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp (&acc[0], (vec_t)vec_A[x+0], (vec_t)vec_B[x]);
__builtin_mma_xvf32gerpp (&acc[1], (vec_t)vec_A[x+0], (vec_t)vec_B[x+1]);
__builtin_mma_xvf32gerpp (&acc[2], (vec_t)vec_A[x+1], (vec_t)vec_B[x]);
__builtin_mma_xvf32gerpp (&acc[3], (vec_t)vec_A[x+1], (vec_t)vec_B[x+1]);
__builtin_mma_xvf32gerpp (&acc[4], (vec_t)vec_A[x+16], (vec_t)vec_B[x]);
__builtin_mma_xvf32gerpp (&acc[5], (vec_t)vec_A[x+16], (vec_t)vec_B[x+1]);
__builtin_mma_xvf32gerpp (&acc[6], (vec_t)vec_A[x+17], (vec_t)vec_B[x]);
__builtin_mma_xvf32gerpp (&acc[7], (vec_t)vec_A[x+17], (vec_t)vec_B[x+1]);
}
}
/*void KERNEL_4x4(int64_t ii, int64_t jj) {
vec_t vec_A[4], vec_B[4], vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
Expand Down Expand Up @@ -2998,7 +3062,7 @@ class tinyBLAS_PPC {
SAVE_ACC(&acc_1, ii, jj+4);
SAVE_ACC(&acc_2, ii+4, jj);
SAVE_ACC(&acc_3, ii+4, jj+4);
}
}*/

void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t mc, nc, mp, np;
Expand Down Expand Up @@ -3204,6 +3268,112 @@ class tinyBLAS_PPC {
}
}
}
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 4) {
acc_t acc[1];
__builtin_mma_xxsetaccz(&acc[0]);
vec_t vec_A[4], vec_B[4], vec_C[4];
for (int l = 0; l < k; l += 4) {
packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
KERNEL_4x4(vec_A, vec_B, acc);
}
SAVE_ACC(&acc[0], ii, jj);
} else if constexpr(RM == 4 && RN == 8) {
vec_t vec_A[4], vec_B[8], vec_C[4];
acc_t acc[2];
__builtin_mma_xxsetaccz(&acc[0]);
__builtin_mma_xxsetaccz(&acc[1]);
for (int64_t l = 0; l < k; l+=4) {
packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
KERNEL_4x8(vec_A, vec_B, acc);
}
SAVE_ACC(&acc[0], ii, jj);
SAVE_ACC(&acc[1], ii, jj+4);

} else if constexpr(RM == 8 && RN == 4) {
vec_t vec_A[8], vec_B[4], vec_C[4];
acc_t acc[2];
__builtin_mma_xxsetaccz(&acc[0]);
__builtin_mma_xxsetaccz(&acc[1]);
for (int64_t l = 0; l < k; l+=4) {
packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
KERNEL_8x4(vec_A, vec_B, acc);
}
SAVE_ACC(&acc[0], ii, jj);
SAVE_ACC(&acc[1], ii+4, jj);
} else if constexpr(RM == 8 && RN == 8) {
vec_t vec_A[16], vec_B[16], vec_C[4];
acc_t acc[4];
__builtin_mma_xxsetaccz(&acc[0]);
__builtin_mma_xxsetaccz(&acc[1]);
__builtin_mma_xxsetaccz(&acc[2]);
__builtin_mma_xxsetaccz(&acc[3]);
for (int l = 0; l < k; l+=8) {
packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
KERNEL_8x8(vec_A, vec_B, acc);
}
SAVE_ACC(&acc[0], ii, jj);
SAVE_ACC(&acc[1], ii, jj+4);
SAVE_ACC(&acc[2], ii+4, jj);
SAVE_ACC(&acc[3], ii+4, jj+4);
} else if constexpr(RM == 8 && RN == 16) {
vec_t vec_A[16], vec_B[32], vec_C[4];
acc_t acc[8];
__builtin_mma_xxsetaccz(&acc[0]);
__builtin_mma_xxsetaccz(&acc[1]);
__builtin_mma_xxsetaccz(&acc[2]);
__builtin_mma_xxsetaccz(&acc[3]);
__builtin_mma_xxsetaccz(&acc[4]);
__builtin_mma_xxsetaccz(&acc[5]);
__builtin_mma_xxsetaccz(&acc[6]);
__builtin_mma_xxsetaccz(&acc[7]);
for (int l = 0; l < k; l+=8) {
packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 16, (float*)vec_B);
KERNEL_8x16(vec_A, vec_B, acc);
}
SAVE_ACC(&acc[0], ii, jj);
SAVE_ACC(&acc[1], ii, jj+4);
SAVE_ACC(&acc[2], ii, jj+8);
SAVE_ACC(&acc[3], ii, jj+12);
SAVE_ACC(&acc[4], ii+4, jj);
SAVE_ACC(&acc[5], ii+4, jj+4);
SAVE_ACC(&acc[6], ii+4, jj+8);
SAVE_ACC(&acc[7], ii+4, jj+12);

} else if constexpr(RM == 16 && RN == 8){
vec_t vec_A[32], vec_B[16], vec_C[4];
acc_t acc[8];
__builtin_mma_xxsetaccz(&acc[0]);
__builtin_mma_xxsetaccz(&acc[1]);
__builtin_mma_xxsetaccz(&acc[2]);
__builtin_mma_xxsetaccz(&acc[3]);
__builtin_mma_xxsetaccz(&acc[4]);
__builtin_mma_xxsetaccz(&acc[5]);
__builtin_mma_xxsetaccz(&acc[6]);
__builtin_mma_xxsetaccz(&acc[7]);
for (int l = 0; l < k; l+=8) {
packTranspose<vector float>(A+(ii*lda)+l, lda, 16, 8, (float*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
KERNEL_16x8(vec_A, vec_B, acc);
}
SAVE_ACC(&acc[0], ii, jj);
SAVE_ACC(&acc[1], ii, jj+4);
SAVE_ACC(&acc[2], ii+4, jj);
SAVE_ACC(&acc[3], ii+4, jj+4);
SAVE_ACC(&acc[4], ii+8, jj);
SAVE_ACC(&acc[5], ii+8, jj+4);
SAVE_ACC(&acc[6], ii+12, jj);
SAVE_ACC(&acc[7], ii+12, jj+4);
}else {
static_assert(false, "RN/RM values not supported");
}
}

template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
Expand All @@ -3213,21 +3383,12 @@ class tinyBLAS_PPC {
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (RM == 4 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_4x4;
} else if (RM == 4 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_4x8;
} else if (RM == 8 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_8x4;
} else if (RM == 8 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_8x8;
}
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
(this->*kernel)(ii, jj);
kernel<RM,RN>(ii, jj);
}
}

Expand Down
Loading