From 179a4d283b1ffe7e0adc2ebd2ac0e2d94ad217ca Mon Sep 17 00:00:00 2001
From: Shalini Salomi Bodapati
Date: Mon, 9 Jun 2025 05:23:27 -0500
Subject: [PATCH] Move packing out of the MMA kernels

This patch moves the calls to the packing routines from inside the MMA
kernels one level up in the call chain.

Current call stack : matmul->mnpack->gemm->kernel->packTranspose + MMA instructions
Changed call stack : matmul->mnpack->gemm->packTranspose->kernel->MMA instructions

No significant performance difference was observed with this change.

Signed-off-by: Shalini Salomi Bodapati
---
 ggml/src/ggml-cpu/llamafile/sgemm.cpp | 187 ++++++++++++++++++++++++--
 1 file changed, 174 insertions(+), 13 deletions(-)

diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 1d46158f928c4..94cd58c2d05ca 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -2673,7 +2673,6 @@ class tinyBLAS_PPC {
 
   private:
 
-    void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
 
     template<typename VA>
     void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
@@ -2920,7 +2919,72 @@ class tinyBLAS_PPC {
         }
     }
 
-    void KERNEL_4x4(int64_t ii, int64_t jj) {
+    void KERNEL_4x4(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[0], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[1], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[2], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[3], vec_B[3]);
+    }
+
+    void KERNEL_4x8(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[0], (vec_t)vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[0], (vec_t)vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[1], (vec_t)vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[1], (vec_t)vec_B[3]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[2], (vec_t)vec_B[4]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[2], (vec_t)vec_B[5]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[3], (vec_t)vec_B[6]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[3], (vec_t)vec_B[7]);
+    }
+
+    void KERNEL_8x4(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[0], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[1], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[2], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[3], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[4], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[5], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[6], vec_B[3]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[7], vec_B[3]);
+    }
+
+    void KERNEL_8x8(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x], vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x+1], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x+1], vec_B[x+1]);
+        }
+    }
+
+    void KERNEL_8x16(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x], (vec_t)vec_B[x+16]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x], (vec_t)vec_B[x+17]);
+            __builtin_mma_xvf32gerpp(&acc[4], (vec_t)vec_A[x+1], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[5], (vec_t)vec_A[x+1], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[6], (vec_t)vec_A[x+1], (vec_t)vec_B[x+16]);
+            __builtin_mma_xvf32gerpp(&acc[7], (vec_t)vec_A[x+1], (vec_t)vec_B[x+17]);
+        }
+    }
+
+    void KERNEL_16x8(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x+0], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x+0], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x+1], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x+1], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[4], (vec_t)vec_A[x+16], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[5], (vec_t)vec_A[x+16], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[6], (vec_t)vec_A[x+17], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[7], (vec_t)vec_A[x+17], (vec_t)vec_B[x+1]);
+        }
+    }
+
+    /*void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
@@ -2998,7 +3062,7 @@ class tinyBLAS_PPC {
         SAVE_ACC(&acc_1, ii, jj+4);
         SAVE_ACC(&acc_2, ii+4, jj);
         SAVE_ACC(&acc_3, ii+4, jj+4);
-    }
+    }*/
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t mc, nc, mp, np;
@@ -3204,6 +3268,112 @@ class tinyBLAS_PPC {
             }
         }
     }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 4) {
+            acc_t acc[1];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            vec_t vec_A[4], vec_B[4], vec_C[4];
+            for (int l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                KERNEL_4x4(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+        } else if constexpr(RM == 4 && RN == 8) {
+            vec_t vec_A[4], vec_B[8], vec_C[4];
+            acc_t acc[2];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            for (int64_t l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+                KERNEL_4x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+        } else if constexpr(RM == 8 && RN == 4) {
+            vec_t vec_A[8], vec_B[4], vec_C[4];
+            acc_t acc[2];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            for (int64_t l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                KERNEL_8x4(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii+4, jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            vec_t vec_A[16], vec_B[16], vec_C[4];
+            acc_t acc[4];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+                KERNEL_8x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii+4, jj);
+            SAVE_ACC(&acc[3], ii+4, jj+4);
+        } else if constexpr(RM == 8 && RN == 16) {
+            vec_t vec_A[16], vec_B[32], vec_C[4];
+            acc_t acc[8];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            __builtin_mma_xxsetaccz(&acc[4]);
+            __builtin_mma_xxsetaccz(&acc[5]);
+            __builtin_mma_xxsetaccz(&acc[6]);
+            __builtin_mma_xxsetaccz(&acc[7]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 16, (float*)vec_B);
+                KERNEL_8x16(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii, jj+8);
+            SAVE_ACC(&acc[3], ii, jj+12);
+            SAVE_ACC(&acc[4], ii+4, jj);
+            SAVE_ACC(&acc[5], ii+4, jj+4);
+            SAVE_ACC(&acc[6], ii+4, jj+8);
+            SAVE_ACC(&acc[7], ii+4, jj+12);
+        } else if constexpr(RM == 16 && RN == 8) {
+            vec_t vec_A[32], vec_B[16], vec_C[4];
+            acc_t acc[8];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            __builtin_mma_xxsetaccz(&acc[4]);
+            __builtin_mma_xxsetaccz(&acc[5]);
+            __builtin_mma_xxsetaccz(&acc[6]);
+            __builtin_mma_xxsetaccz(&acc[7]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 16, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+                KERNEL_16x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii+4, jj);
+            SAVE_ACC(&acc[3], ii+4, jj+4);
+            SAVE_ACC(&acc[4], ii+8, jj);
+            SAVE_ACC(&acc[5], ii+8, jj+4);
+            SAVE_ACC(&acc[6], ii+12, jj);
+            SAVE_ACC(&acc[7], ii+12, jj+4);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
 
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@@ -3213,21 +3383,12 @@ class tinyBLAS_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        if (RM == 4 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x4;
-        } else if (RM == 4 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x8;
-        } else if (RM == 8 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x4;
-        } else if (RM == 8 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x8;
-        }
         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            (this->*kernel)(ii, jj);
+            kernel<RM, RN>(ii, jj);
         }
     }
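Illustrative note (not part of the diff): the sketch below mirrors the restructuring described in the commit message in plain, portable C++ so the control-flow change is easy to follow. The caller packs/transposes the A and B slices for every k-step, and the kernel itself only performs the multiply-accumulate, which is the role the __builtin_mma_xvf32gerpp sequences play in the real code. All names here (pack_tile, kernel_4x4, gemm_tile) are hypothetical, and the scalar arithmetic merely stands in for the POWER MMA builtins.

// Scalar stand-in for the "pack outside, MMA inside" structure of the patch.
// Hypothetical example; it is not the PowerPC-specific implementation.
#include <cstdio>

constexpr int RM = 4, RN = 4, KC = 4;   // 4x4 output tile, 4-wide k step

// Pack a rows x cols sub-block of a row-major matrix into a contiguous,
// transposed buffer; this is the job packTranspose() does in sgemm.cpp.
static void pack_tile(const float* src, int ld, int rows, int cols, float* dst) {
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            dst[c * rows + r] = src[r * ld + c];
}

// The kernel sees only packed buffers and accumulates into acc; in the patch
// this body is a fixed sequence of __builtin_mma_xvf32gerpp calls.
static void kernel_4x4(const float* pa, const float* pb, float acc[RM][RN]) {
    for (int l = 0; l < KC; ++l)
        for (int i = 0; i < RM; ++i)
            for (int j = 0; j < RN; ++j)
                acc[i][j] += pa[l * RM + i] * pb[l * RN + j];
}

// Caller packs each k-slice first, then invokes the multiply-only kernel:
// matmul -> ... -> pack_tile -> kernel (the "Changed call stack" above).
static void gemm_tile(const float* A, const float* Bt, float* C,
                      int lda, int ldb, int ldc, int k) {
    float acc[RM][RN] = {};
    float pa[KC * RM], pb[KC * RN];
    for (int l = 0; l < k; l += KC) {
        pack_tile(A  + l, lda, RM, KC, pa);   // RM x KC slice of A
        pack_tile(Bt + l, ldb, RN, KC, pb);   // RN x KC slice of B^T
        kernel_4x4(pa, pb, acc);
    }
    for (int i = 0; i < RM; ++i)
        for (int j = 0; j < RN; ++j)
            C[i * ldc + j] = acc[i][j];
}

int main() {
    constexpr int K = 8;                      // must be a multiple of KC
    float A[RM * K], Bt[RN * K], C[RM * RN];
    for (float& x : A)  x = 1.0f;
    for (float& x : Bt) x = 2.0f;
    gemm_tile(A, Bt, C, K, K, RN, K);
    std::printf("C[0][0] = %g (expected %g)\n", C[0], 2.0f * K);
}

The point of the split is that each KERNEL_MxN body becomes a pure sequence of MMA instructions operating on already-packed vectors, while all addressing and packing logic lives in one place in kernel<RM, RN>().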