From 505d77fc99856f12aee83e41dd92a03472f9d1a7 Mon Sep 17 00:00:00 2001
From: Shalini Salomi Bodapati
Date: Fri, 4 Apr 2025 01:49:55 -0500
Subject: [PATCH] Enable MMA for BF16 data type on ppc64le

This patch upstreams llamafile's CPU matrix-multiplication kernels for
ppc64le using MMA builtins for the BF16 data type. The change results in
9x-40x gains in total speed S t/s (i.e., all tokens / total time) across
the batch sizes tested with the llama-batched-bench benchmark. The patch
was tested with the Meta-Llama-3-8B and Mistral-7B models (BF16 models
generated with llama-quantize from the corresponding FP32 models) on an
IBM POWER10 machine.

Signed-off-by: Shalini Salomi Bodapati
---
 ggml/src/ggml-cpu/llamafile/sgemm.cpp | 629 +++++++++++++++++++++++++-
 1 file changed, 627 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index f6374f7894a08..8288a2dd5387a 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -1054,6 +1054,578 @@ class tinyBLAS_Q0_AVX {
         } \
     } \
 
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_BF16_PPC {
+  public:
+    tinyBLAS_BF16_PPC(int64_t k,
+                const TA *A, int64_t lda,
+                const TB *B, int64_t ldb,
+                TC *C, int64_t ldc,
+                int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        unsigned char *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        vector unsigned char c1, c2, c3, c4, c5, c6, c7, c8;
+        vector unsigned char t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char s1, s2, s3, s4, s5, s6, s7, s8;
+        vector unsigned char swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                if (cols == 4) {
+                    aoffset1 = aoffset;
+                    aoffset2 = aoffset1 + lda;
+                    aoffset3 = aoffset2 + lda;
+                    aoffset4 = aoffset3 + lda;
+                    aoffset += 4 * lda;
+
+                    c1 = vec_xl(0, (vector unsigned char*)aoffset1);
+                    c2 = vec_xl(0, (vector unsigned char*)aoffset2);
+                    c3 = vec_xl(0, (vector unsigned char*)aoffset3);
+                    c4 = vec_xl(0, (vector unsigned char*)aoffset4);
+
+                    t1 = vec_perm(c1, c2, swiz1);
+                    t2 = vec_perm(c1, c2, swiz2);
+                    t3 = vec_perm(c3, c4, swiz1);
+                    t4 = vec_perm(c3, c4, swiz2);
+                    s1 = vec_perm(t1, t3, swiz3);
+                    s2 = vec_perm(t1, t3, swiz4);
+                    s3 = vec_perm(t2, t4, swiz3);
+                    s4 = vec_perm(t2, t4, swiz4);
+
+                    vec_xst(s1, 0, (vector unsigned char*)vecOffset);
+                    vec_xst(s2, 0, (vector unsigned char*)(vecOffset+16));
+                    vec_xst(s3, 0, (vector unsigned char*)(vecOffset+32));
+                    vec_xst(s4, 0, (vector unsigned char*)(vecOffset+48));
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 64;
+                }
+
+                i = (cols >> 3);
+                if (i > 0) {
+                    aoffset1 = aoffset;
+                    aoffset2 = aoffset1 + lda;
+                    aoffset3 = aoffset2 + lda;
+                    aoffset4 = aoffset3 + lda;
+                    aoffset5 = aoffset4 + lda;
+                    aoffset6 = aoffset5 + lda;
+                    aoffset7 = aoffset6 + lda;
+                    aoffset8 = aoffset7 + lda;
+
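+                    // Each pass of the loop below transposes one 8x8 block
+                    // of BF16 values into the interleaved layout consumed
+                    // by the MMA kernels.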
+                    aoffset += 8 * lda;
+
+                    do {
+                        c1 = vec_xl(0, (vector unsigned char*)aoffset1);
+                        c2 = vec_xl(0, (vector unsigned char*)aoffset2);
+                        c3 = vec_xl(0, (vector unsigned char*)aoffset3);
+                        c4 = vec_xl(0, (vector unsigned char*)aoffset4);
+                        c5 = vec_xl(0, (vector unsigned char*)aoffset5);
+                        c6 = vec_xl(0, (vector unsigned char*)aoffset6);
+                        c7 = vec_xl(0, (vector unsigned char*)aoffset7);
+                        c8 = vec_xl(0, (vector unsigned char*)aoffset8);
+
+                        t1 = vec_perm(c1, c2, swiz1);
+                        t2 = vec_perm(c1, c2, swiz2);
+                        t3 = vec_perm(c3, c4, swiz1);
+                        t4 = vec_perm(c3, c4, swiz2);
+                        s1 = vec_perm(t1, t3, swiz3);
+                        s2 = vec_perm(t1, t3, swiz4);
+                        s3 = vec_perm(t2, t4, swiz3);
+                        s4 = vec_perm(t2, t4, swiz4);
+                        t5 = vec_perm(c5, c6, swiz1);
+                        t6 = vec_perm(c5, c6, swiz2);
+                        t7 = vec_perm(c7, c8, swiz1);
+                        t8 = vec_perm(c7, c8, swiz2);
+                        s5 = vec_perm(t5, t7, swiz3);
+                        s6 = vec_perm(t5, t7, swiz4);
+                        s7 = vec_perm(t6, t8, swiz3);
+                        s8 = vec_perm(t6, t8, swiz4);
+
+                        vec_xst(s1, 0, (vector unsigned char*)vecOffset);
+                        vec_xst(s2, 0, (vector unsigned char*)(vecOffset+16));
+                        vec_xst(s3, 0, (vector unsigned char*)(vecOffset+32));
+                        vec_xst(s4, 0, (vector unsigned char*)(vecOffset+48));
+                        vec_xst(s5, 0, (vector unsigned char*)(vecOffset+64));
+                        vec_xst(s6, 0, (vector unsigned char*)(vecOffset+80));
+                        vec_xst(s7, 0, (vector unsigned char*)(vecOffset+96));
+                        vec_xst(s8, 0, (vector unsigned char*)(vecOffset+112));
+
+                        aoffset1 += 8*lda;
+                        aoffset2 += 8*lda;
+                        aoffset3 += 8*lda;
+                        aoffset4 += 8*lda;
+                        aoffset5 += 8*lda;
+                        aoffset6 += 8*lda;
+                        aoffset7 += 8*lda;
+                        aoffset8 += 8*lda;
+                        vecOffset += 128;
+                        i--;
+                    } while(i > 0);
+                }
+                j--;
+            } while(j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+            if (cols == 4) {
+                c1 = vec_xl(0, (vector unsigned char*)aoffset1);
+                c2 = vec_xl(0, (vector unsigned char*)aoffset2);
+                c3 = vec_xl(0, (vector unsigned char*)aoffset3);
+                c4 = vec_xl(0, (vector unsigned char*)aoffset4);
+
+                t1 = vec_perm(c1, c2, swiz1);
+                t2 = vec_perm(c3, c4, swiz1);
+                s1 = vec_perm(t1, t2, swiz3);
+                s2 = vec_perm(t1, t2, swiz4);
+                vec_xst(s1, 0, (vector unsigned char*)vecOffset);
+                vec_xst(s2, 0, (vector unsigned char*)(vecOffset+16));
+                aoffset1 += lda;
+                aoffset2 += lda;
+                aoffset3 += lda;
+                aoffset4 += lda;
+                vecOffset += 32;
+            }
+
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    c1 = vec_xl(0, (vector unsigned char*)aoffset1);
+                    c2 = vec_xl(0, (vector unsigned char*)aoffset2);
+                    c3 = vec_xl(0, (vector unsigned char*)aoffset3);
+                    c4 = vec_xl(0, (vector unsigned char*)aoffset4);
+
+                    t1 = vec_perm(c1, c2, swiz1);
+                    t2 = vec_perm(c1, c2, swiz2);
+                    t3 = vec_perm(c3, c4, swiz1);
+                    t4 = vec_perm(c3, c4, swiz2);
+                    s1 = vec_perm(t1, t3, swiz3);
+                    s2 = vec_perm(t1, t3, swiz4);
+                    s3 = vec_perm(t2, t4, swiz3);
+                    s4 = vec_perm(t2, t4, swiz4);
+
+                    vec_xst(s1, 0, (vector unsigned char*)vecOffset);
+                    vec_xst(s2, 0, (vector unsigned char*)(vecOffset+16));
+                    vec_xst(s3, 0, (vector unsigned char*)(vecOffset+32));
+                    vec_xst(s4, 0, (vector unsigned char*)(vecOffset+48));
+
+                    aoffset1 += 8*lda;
+                    aoffset2 += 8*lda;
+                    aoffset3 += 8*lda;
+                    aoffset4 += 8*lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
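+            // 1-3 leftover rows: the switch statements below fall through
+            // intentionally so that every remaining row is loaded before
+            // the block is packed.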
+            if (cols == 4) {
+                switch(rows) {
+                    case 3: c3 = vec_xl(0, (vector unsigned char*)aoffset3);
+                    case 2: c2 = vec_xl(0, (vector unsigned char*)aoffset2);
+                    case 1: c1 = vec_xl(0, (vector unsigned char*)aoffset1);
+                            break;
+                }
+                t1 = vec_perm(c1, c2, swiz1);
+                t2 = vec_perm(c3, c4, swiz1);
+                s1 = vec_perm(t1, t2, swiz3);
+                s2 = vec_perm(t1, t2, swiz4);
+                vec_xst(s1, 0, (vector unsigned char*)vecOffset);
+                vec_xst(s2, 0, (vector unsigned char*)(vecOffset+16));
+                aoffset1 += lda;
+                aoffset2 += lda;
+                aoffset3 += lda;
+                vecOffset += 32;
+            }
+
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c3 = vec_xl(0, (vector unsigned char*)aoffset3);
+                        case 2: c2 = vec_xl(0, (vector unsigned char*)aoffset2);
+                        case 1: c1 = vec_xl(0, (vector unsigned char*)aoffset1);
+                                break;
+                    }
+                    t1 = vec_perm(c1, c2, swiz1);
+                    t2 = vec_perm(c1, c2, swiz2);
+                    t3 = vec_perm(c3, c4, swiz1);
+                    t4 = vec_perm(c3, c4, swiz2);
+                    s1 = vec_perm(t1, t3, swiz3);
+                    s2 = vec_perm(t1, t3, swiz4);
+                    s3 = vec_perm(t2, t4, swiz3);
+                    s4 = vec_perm(t2, t4, swiz4);
+                    vec_xst(s1, 0, (vector unsigned char*)vecOffset);
+                    vec_xst(s2, 0, (vector unsigned char*)(vecOffset+16));
+                    vec_xst(s3, 0, (vector unsigned char*)(vecOffset+32));
+                    vec_xst(s4, 0, (vector unsigned char*)(vecOffset+48));
+
+                    aoffset1 += 8*lda;
+                    aoffset2 += 8*lda;
+                    aoffset3 += 8*lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8, 8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4, 8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8, 4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem >= 8)) {
+            nc = 8;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_Mx8<1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_Mx8<2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_Mx8<3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
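+    // Each MMA accumulator (acc_t) produces a 4x4 tile of FP32 results, and
+    // every __builtin_mma_xvbf16ger2pp call accumulates a rank-2 update
+    // (two BF16 k-steps), so the 4x8 and 8x4 kernels need two accumulators
+    // and the 8x8 kernel needs four.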
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l += 8) {
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4], vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l += 8) {
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l += 8) {
+            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
+            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                __builtin_mma_xvbf16ger2pp(&acc_2, vec_A[x+4], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_3, vec_A[x+4], vec_B[x+4]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[2], vec_B[2];
+            for (int l = 0; l < k; l += 4) {
+                packNormal((A+(ii*lda)+l), lda, RM, 4, (uint8_t*)vec_A);
+                packNormal((B+(jj*ldb)+l), ldb, RN, 4, (uint8_t*)vec_B);
+                for (int x = 0; x < 2; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM>
+    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int RN = 8;
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0, acc_1;
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            vec_t vec_A[4], vec_B[8];
+            for (int l = 0; l < k; l += 8) {
+                packNormal((A+(ii*lda)+l), lda, RM, 8, (uint8_t*)vec_A);
+                packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+                for (int x = 0; x < 4; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_1);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
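+    // Tile shapes are resolved at compile time; RM/RN combinations without
+    // a dedicated kernel are rejected with a static_assert.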
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii, jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii, jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii, jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template<int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_PPC {
   public:
@@ -2875,9 +3485,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             (float *)C, ldc};
         return tb.matmul(m, n);
     }
+#elif defined(__MMA__)
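+        // packNormal consumes eight BF16 columns per step, so fall back
+        // unless k is a multiple of 8.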
+        if ((k % 8))
+            return false;
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc,
+                params->ith, params->nth};
+            tb.matmul(m, n);
+            return true;
+        }
 #endif
         return false;
     }
+
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
         if (Btype == GGML_TYPE_F16) {