From 364e5d80f76e21c4623da135092968a2c62b6b1c Mon Sep 17 00:00:00 2001 From: Shalini Salomi Bodapati Date: Mon, 2 Jun 2025 02:02:03 -0500 Subject: [PATCH] Refactor: Move matrix packing outside GEMM kernels In class tinyBLAS_PPC, previously, packing of input matrices A and B was performed on-the-fly within each GEMM microkernel. This patch refactors the code to decouple packing from kernel by introducing a preprocessing step that packs matrices once before any kernel is invoked. Benefits: - Enables better memory locality and data reuse - Simplifies the kernel logic by focusing purely on computation - Improves overall GEMM performance, especially for large matrix sizes Signed-off-by: Shalini Salomi Bodapati --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 759 ++++++++++++-------------- 1 file changed, 359 insertions(+), 400 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 1d46158f928c4..d4ec30f0c5124 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2664,258 +2664,203 @@ class tinyBLAS_PPC { const TB *B, int64_t ldb, TC *C, int64_t ldc, int ith, int nth) - : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth), At(nullptr), Bt(nullptr) { } void matmul(int64_t m, int64_t n) { + At = new float[m*k]; + Bt = new float[n*k]; + m_orig = m; + n_orig = n; + + transpose_matrix(A, At, k, m); + transpose_matrix(B, Bt, k, n); mnpack(0, m, 0, n); } private: - void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); - - template - void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) { - int64_t i, j; - TA *aoffset = NULL, *boffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VA t1, t2, t3, t4, t5, t6, t7, t8; - aoffset = const_cast(a); - boffset = vec; - j = (rows >> 3); - if (j > 0) { - - do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset += 8 * lda; - i = (cols >> 3); - if (i > 0) { - do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - 
vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergeh(c5[1], c6[1]); - t4 = vec_mergeh(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+32); - vec_xst(t6, 0, boffset+36); - vec_xst(t7, 0, boffset+40); - vec_xst(t8, 0, boffset+44); - - t1 = vec_mergel(c1[1], c2[1]); - t2 = vec_mergel(c3[1], c4[1]); - t3 = vec_mergel(c5[1], c6[1]); - t4 = vec_mergel(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+48); - vec_xst(t6, 0, boffset+52); - vec_xst(t7, 0, boffset+56); - vec_xst(t8, 0, boffset+60); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; - boffset += 64; - i--; - } while(i > 0); - } - if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - c5[0] = vec_xl(0, aoffset5); - c6[0] = vec_xl(0, aoffset6); - c7[0] = vec_xl(0, aoffset7); - c8[0] = vec_xl(0, aoffset8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - } - j--; - } while(j > 0); + void transpose_tile_8x8(const float* src, float* dst, int ld_src, int ld_dst) { + vector float row[16]; + vector float t0, t1, t2, t3, t4, t5, t6, t7; + for (int i = 0; i < 8; i++) { + row[i] = vec_xl(0, src + i * ld_src); + row[i+8] = vec_xl(16, src + i * ld_src); } - - if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset += 4 * lda; - i = (cols >> 3); - if (i > 0) { - do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergel(c1[0], c2[0]); - t4 = vec_mergel(c3[0], c4[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = 
vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergel(c1[1], c2[1]); - t4 = vec_mergel(c3[1], c4[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; - boffset += 32; - i--; - } while(i > 0); - } - - if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); - } + t0 = vec_mergeh(row[0], row[1]); + t1 = vec_mergeh(row[2], row[3]); + t2 = vec_mergeh(row[4], row[5]); + t3 = vec_mergeh(row[6], row[7]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t2, t3, 0); + t6 = vec_xxpermdi(t0, t1, 3); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 0* ld_dst); + vec_xst(t5, 16, dst + 0* ld_dst); + vec_xst(t6, 0, dst + 1* ld_dst); + vec_xst(t7, 16, dst + 1* ld_dst); + + t0 = vec_mergel(row[0], row[1]); + t1 = vec_mergel(row[2], row[3]); + t2 = vec_mergel(row[4], row[5]); + t3 = vec_mergel(row[6], row[7]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t2, t3, 0); + t6 = vec_xxpermdi(t0, t1, 3); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 2* ld_dst); + vec_xst(t5, 16, dst + 2* ld_dst); + vec_xst(t6, 0, dst + 3* ld_dst); + vec_xst(t7, 16, dst + 3* ld_dst); + + t0 = vec_mergeh(row[8], row[9]); + t1 = vec_mergeh(row[10], row[11]); + t2 = vec_mergeh(row[12], row[13]); + t3 = vec_mergeh(row[14], row[15]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t2, t3, 0); + t6 = vec_xxpermdi(t0, t1, 3); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 4* ld_dst); + vec_xst(t5, 16, dst + 4* ld_dst); + vec_xst(t6, 0, dst + 5* ld_dst); + vec_xst(t7, 16, dst + 5* ld_dst); + + t0 = vec_mergel(row[8], row[9]); + t1 = vec_mergel(row[10], row[11]); + t2 = vec_mergel(row[12], row[13]); + t3 = vec_mergel(row[14], row[15]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t2, t3, 0); + t6 = vec_xxpermdi(t0, t1, 3); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 6* ld_dst); + vec_xst(t5, 16, dst + 6* ld_dst); + vec_xst(t6, 0, dst + 7* ld_dst); + vec_xst(t7, 16, dst + 7* ld_dst); + } + + void transpose_tile_8x4(const float* src, float* dst, int ld_src, int ld_dst) { + vector float row[8]; + vector float t0, t1, t2, t3, t4, t5, t6, t7; + for (int i = 0; i < 4; i++) { + row[i] = vec_xl(0, src + i * ld_src); + row[i+4] = vec_xl(16, src + i * ld_src); } - if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); 
- t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + // Transpose 8x8 block using pairwise merges + t0 = vec_mergeh(row[0], row[1]); + t1 = vec_mergeh(row[2], row[3]); + t2 = vec_mergel(row[0], row[1]); + t3 = vec_mergel(row[2], row[3]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t0, t1, 3); + t6 = vec_xxpermdi(t2, t3, 0); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 0* ld_dst); + vec_xst(t5, 0, dst + 1* ld_dst); + vec_xst(t6, 0, dst + 2* ld_dst); + vec_xst(t7, 0, dst + 3* ld_dst); + + t0 = vec_mergeh(row[4], row[5]); + t1 = vec_mergeh(row[6], row[7]); + t2 = vec_mergel(row[4], row[5]); + t3 = vec_mergel(row[6], row[7]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t0, t1, 3); + t6 = vec_xxpermdi(t2, t3, 0); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 4* ld_dst); + vec_xst(t5, 0, dst + 5* ld_dst); + vec_xst(t6, 0, dst + 6* ld_dst); + vec_xst(t7, 0, dst + 7* ld_dst); + } + + void transpose_tile_4x8(const float* src, float* dst, int ld_src, int ld_dst) { + vector float row[8]; + vector float t0, t1, t2, t3, t4, t5, t6, t7; + for (int i = 0; i < 4; i++) { + row[i] = vec_xl(0, src + i * ld_src); + row[i+4] = vec_xl(16, src + i * ld_src); + } + t0 = vec_mergeh(row[0], row[1]); + t1 = vec_mergeh(row[2], row[3]); + t2 = vec_mergel(row[0], row[1]); + t3 = vec_mergel(row[2], row[3]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t0, t1, 3); + t6 = vec_xxpermdi(t2, t3, 0); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 0* ld_dst); + vec_xst(t5, 0, dst + 1* ld_dst); + vec_xst(t6, 0, dst + 2* ld_dst); + vec_xst(t7, 0, dst + 3* ld_dst); + + t0 = vec_mergeh(row[4], row[5]); + t1 = vec_mergeh(row[6], row[7]); + t2 = vec_mergel(row[4], row[5]); + t3 = vec_mergel(row[6], row[7]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t0, t1, 3); + t6 = vec_xxpermdi(t2, t3, 0); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 4* ld_dst); + vec_xst(t5, 0, dst + 5* ld_dst); + vec_xst(t6, 0, dst + 6* ld_dst); + vec_xst(t7, 0, dst + 7* ld_dst); + + } + void transpose_tile_4x4(const float* src, float* dst, int ld_src, int ld_dst) { + vector float row[4]; + vector float t0, t1, t2, t3, t4, t5, t6, t7; + for (int i = 0; i < 4; i++) { + row[i] = vec_xl(0, src + i * ld_src); + } + t0 = vec_mergeh(row[0], row[1]); + t1 = vec_mergeh(row[2], row[3]); + t2 = vec_mergel(row[0], row[1]); + t3 = vec_mergel(row[2], row[3]); + t4 = vec_xxpermdi(t0, t1, 0); + t5 = vec_xxpermdi(t0, t1, 3); + t6 = vec_xxpermdi(t2, t3, 0); + t7 = vec_xxpermdi(t2, t3, 3); + + vec_xst(t4, 0, dst + 0* ld_dst); + vec_xst(t5, 0, dst + 1* ld_dst); + vec_xst(t6, 0, dst + 2* ld_dst); + vec_xst(t7, 0, dst + 3* ld_dst); + } + + void transpose_matrix(const float* A, float* B, int rows, int cols) { + for (int i = 0; i < rows; i += 8) { + for (int j = 0; j < cols; j += 8) { + int tile_h = std::min(8, rows - i); + int tile_w = std::min(8, cols - j); + if (tile_h == 8 && tile_w == 8) { + transpose_tile_8x8(A + j * rows + i, B + i * cols + j, rows, cols); + } else if (tile_h == 4 && tile_w == 8) { + transpose_tile_4x8(A + j * rows + i, B + i * cols + j, rows, cols); + } else if (tile_h == 8 && tile_w == 4) { + transpose_tile_8x4(A + j * rows + i, B + i * cols + j, rows, cols); + } else if (tile_h == 4 && tile_w == 4){ + transpose_tile_4x4(A + j * rows + i, B + i * cols + j, rows, cols); + } + else { + for (int ti = 0; ti < tile_h; ++ti) { + for (int tj = 0; tj < tile_w; ++tj) { + int a_index = (i 
+ ti) + (j + tj) * rows; + int b_index = (j + tj) + (i + ti) * cols; + B[b_index] = A[a_index]; + } + } + } } } } @@ -2925,8 +2870,10 @@ class tinyBLAS_PPC { acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); for (int l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + for (int x = 0; x< 4; x++){ + vec_A[x] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii)); + vec_B[x] = (vec_t)vec_xl(0, (float*)(Bt + (l + x) * n_orig + jj)); + } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); @@ -2941,8 +2888,11 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B); + for (int x = 0; x< 4; x++){ + vec_A[x] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii)); + vec_B[2*x] = (vec_t)vec_xl(0, (float*)(Bt + (l + x) * n_orig + jj)); + vec_B[2*x+1] = (vec_t)vec_xl(0, (float*)(Bt + (l + x) * n_orig + jj+4)); + } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -2951,7 +2901,7 @@ class tinyBLAS_PPC { __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]); - } + } SAVE_ACC(&acc_0, ii, jj); SAVE_ACC(&acc_1, ii, jj+4); } @@ -2962,8 +2912,11 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + for (int x = 0; x< 4; x++){ + vec_A[2*x] = (vec_t)vec_xl(0, (float*)(At + (l + x) * m_orig + ii)); + vec_A[2*x+1] = (vec_t)vec_xl(0, (float*)(At + (l + x) * m_orig + ii + 4)); + vec_B[x] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj)); + } __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -2985,8 +2938,12 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B); + for (int x = 0; x < 8; x++) { + vec_A[2*x] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii)); + vec_A[2*x + 1] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii + 4)); + vec_B[2*x] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj)); + vec_B[2*x + 1] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj + 4)); + } for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); @@ -3000,156 +2957,130 @@ class tinyBLAS_PPC { SAVE_ACC(&acc_3, ii+4, jj+4); } + void KERNEL_8x16(int64_t ii, int64_t jj) { + vec_t vec_A[16], vec_B[32], vec_C[4]; + acc_t acc[8]; + + for (int i = 0; i < 8; i++) + __builtin_mma_xxsetaccz(&acc[i]); + + for (int l = 0; l < k; l += 8) { + for (int x = 0; x < 8; x++) { + vec_A[2*x] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii)); + vec_A[2*x + 1] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii + 4)); 
+ vec_B[4*x] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj)); + vec_B[4*x + 1] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj + 4)); + vec_B[4*x + 2] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj + 8)); + vec_B[4*x + 3] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj + 12)); + } + + for (int x = 0; x < 16; x += 2) { + __builtin_mma_xvf32gerpp(&acc[0], vec_A[x], vec_B[2*x]); + __builtin_mma_xvf32gerpp(&acc[1], vec_A[x], vec_B[2*x + 1]); + __builtin_mma_xvf32gerpp(&acc[2], vec_A[x], vec_B[2*x + 2]); + __builtin_mma_xvf32gerpp(&acc[3], vec_A[x], vec_B[2*x + 3]); + __builtin_mma_xvf32gerpp(&acc[4], vec_A[x + 1], vec_B[2*x]); + __builtin_mma_xvf32gerpp(&acc[5], vec_A[x + 1], vec_B[2*x + 1]); + __builtin_mma_xvf32gerpp(&acc[6], vec_A[x + 1], vec_B[2*x + 2]); + __builtin_mma_xvf32gerpp(&acc[7], vec_A[x + 1], vec_B[2*x + 3]); + } + } + + SAVE_ACC(&acc[0], ii, jj); + SAVE_ACC(&acc[1], ii, jj+4); + SAVE_ACC(&acc[2], ii, jj+8); + SAVE_ACC(&acc[3], ii, jj+12); + SAVE_ACC(&acc[4], ii+4, jj); + SAVE_ACC(&acc[5], ii+4, jj+4); + SAVE_ACC(&acc[6], ii+4, jj+8); + SAVE_ACC(&acc[7], ii+4, jj+12); + } + + void KERNEL_16x8(int64_t ii, int64_t jj) { + vec_t vec_A[32], vec_B[16], vec_C[4]; + acc_t acc[8]; + + for (int i = 0; i < 8; i++) + __builtin_mma_xxsetaccz(&acc[i]); + + for (int l = 0; l < k; l += 8) { + for (int x = 0; x < 8; x++) { + vec_A[4*x] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii)); + vec_A[4*x+1] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii + 4)); + vec_A[4*x+2] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii + 8)); + vec_A[4*x+3] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii + 12)); + vec_B[2*x] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj)); + vec_B[2*x+1] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj + 4)); + } + + for (int x = 0; x < 32; x += 4) { + __builtin_mma_xvf32gerpp(&acc[0], vec_A[x], vec_B[x / 2]); + __builtin_mma_xvf32gerpp(&acc[1], vec_A[x], vec_B[x / 2 + 1]); + __builtin_mma_xvf32gerpp(&acc[2], vec_A[x + 1], vec_B[x / 2]); + __builtin_mma_xvf32gerpp(&acc[3], vec_A[x + 1], vec_B[x / 2 + 1]); + __builtin_mma_xvf32gerpp(&acc[4], vec_A[x + 2], vec_B[x / 2]); + __builtin_mma_xvf32gerpp(&acc[5], vec_A[x + 2], vec_B[x / 2 + 1]); + __builtin_mma_xvf32gerpp(&acc[6], vec_A[x + 3], vec_B[x / 2]); + __builtin_mma_xvf32gerpp(&acc[7], vec_A[x + 3], vec_B[x / 2 + 1]); + } + } + SAVE_ACC(&acc[0], ii, jj); + SAVE_ACC(&acc[1], ii, jj + 4); + SAVE_ACC(&acc[2], ii+4, jj); + SAVE_ACC(&acc[3], ii+4, jj + 4); + SAVE_ACC(&acc[4], ii+8, jj); + SAVE_ACC(&acc[5], ii+8, jj + 4); + SAVE_ACC(&acc[6], ii+12, jj); + SAVE_ACC(&acc[7], ii+12, jj+4); + } + + void handle_small_tile(int64_t m0, int64_t m, int64_t n0, int64_t n, int m_rem, int n_rem) { + int mc = m_rem > 4 ? 4 : m_rem; + int nc = n_rem > 4 ? 
4 : n_rem; + + if (mc > 0 && nc > 0) { + gemm_small(m0, m, n0, n, mc, nc); + // Recurse over remaining blocks + if (m0 + mc < m) { + handle_small_tile(m0 + mc, m, n0, n0 + nc, m - (m0 + mc), nc); + } + if (n0 + nc < n) { + handle_small_tile(m0, m0 + mc, n0 + nc, n, mc, n - (n0 + nc)); + } + } + } + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; int m_rem = MIN(m - m0, 16); int n_rem = MIN(n - n0, 16); + int mc = 0, nc = 0; if (m_rem >= 16 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; + mc = 16; nc = 8; + gemm<16,8>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >= 16) { + mc = 8; nc = 16; + gemm<8,16>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >=8) { + mc = 8; nc = 8; gemm<8,8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { - mc = 4; - nc = 8; + mc = 4; nc = 8; gemm<4,8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { - mc = 8; - nc = 4; + mc = 8; nc = 4; gemm<8,4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { - mc = 4; - nc = 4; + mc = 4; nc = 4; gemm<4,4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - mc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - mc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + handle_small_tile(m0, m, n0, n, m_rem, n_rem); + return; } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); - } + } void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; @@ -3174,23 +3105,39 @@ class tinyBLAS_PPC { * matrix elements. 
*/ if (RM == 1) { + //printf("RM =1 case\n"); TA* a = const_cast(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + for (int x = 0; x < 4; x++) { + vec_B[x] = (vec_t)vec_xl(0, (float*)(Bt + (l + x)* n_orig + jj)); + } vec_A[0] = (vec_t)vec_xl(0,a); vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); } else if (RN == 1) { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); + //printf("RN =1 case\n"); + for (int x = 0; x < 4; x++) { + vec_A[x] = (vec_t)vec_xl(0, (float*)(At + (l + x)* m_orig + ii)); + } TB* b = const_cast(B+(jj)*ldb+l); vec_B[0] = (vec_t)vec_xl(0,b); vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); } else { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + //printf("Third case\n"); + + for (int x = 0; x < 4; x++) { + vec_A[x] = (vec_t)vec_xl(0, (float *)(At + (l + x) * m_orig + ii)); + vec_B[x] = (vec_t)vec_xl(0, (float *)(Bt + (l + x) * n_orig + jj)); + } } + /*printf("In gemm small\n"); + for (int i = 0; i< 4; i++){ + dump_vec_f("A", (vector float)vec_A[i]); + dump_vec_f("B", (vector float)vec_B[i]); + }*/ + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); @@ -3204,6 +3151,24 @@ class tinyBLAS_PPC { } } } + template + inline void kernel(int64_t ii, int64_t jj, int64_t m, int64_t n) { + if constexpr(RM == 4 && RN == 4) { + KERNEL_4x4(ii,jj); + } else if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else if constexpr(RM == 8 && RN == 16) { + KERNEL_8x16(ii,jj); + }else if constexpr(RM == 16 && RN == 8) { + KERNEL_16x8(ii,jj); + }else { + static_assert(false, "RN/RM values not supported"); + } + } template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { @@ -3213,21 +3178,13 @@ class tinyBLAS_PPC { int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; - if (RM == 4 && RN == 4) { - kernel = &tinyBLAS_PPC::KERNEL_4x4; - } else if (RM == 4 && RN == 8) { - kernel = &tinyBLAS_PPC::KERNEL_4x8; - } else if (RM == 8 && RN == 4) { - kernel = &tinyBLAS_PPC::KERNEL_8x4; - } else if (RM == 8 && RN == 8) { - kernel = &tinyBLAS_PPC::KERNEL_8x8; - } + if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - (this->*kernel)(ii, jj); + kernel(ii, jj, m, n); } } @@ -3237,6 +3194,8 @@ class tinyBLAS_PPC { TA *At; TB *Bt; const int64_t k; + int64_t m_orig; + int64_t n_orig; const int64_t lda; const int64_t ldb; const int64_t ldc;
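
Note: the sketch below illustrates, in plain scalar C++, the pack-then-compute flow this patch introduces. Both inputs are transposed once into dense buffers (At, Bt), after which every kernel only streams contiguous k-panels instead of repacking on the fly. It is a minimal illustration assuming row-major float inputs with lda == k and ldb == k, and the helper names pack_transpose and gemm_packed are hypothetical stand-ins for the VSX/MMA code in tinyBLAS_PPC, not part of the patch. Each C entry is the dot product of one row of A with one row of B, matching the access pattern used in sgemm.cpp.

// Minimal scalar sketch of the pack-then-compute flow (illustrative only;
// the real implementation uses VSX/MMA intrinsics inside tinyBLAS_PPC).
#include <cstdint>
#include <cstdio>
#include <vector>

// Pack: store src (rows x cols, row-major, leading dimension ld) transposed
// into dst (cols x rows, dense), so kernels later read contiguous panels.
static void pack_transpose(const float *src, float *dst,
                           int64_t rows, int64_t cols, int64_t ld) {
    for (int64_t r = 0; r < rows; ++r)
        for (int64_t c = 0; c < cols; ++c)
            dst[c * rows + r] = src[r * ld + c];
}

// Compute on the packed buffers; the kernel no longer repacks anything.
// C[j*ldc + i] = dot(row i of A, row j of B), read as unit-stride loads
// from At (k x m) and Bt (k x n).
static void gemm_packed(const float *At, const float *Bt, float *C,
                        int64_t m, int64_t n, int64_t k, int64_t ldc) {
    for (int64_t i = 0; i < m; ++i)
        for (int64_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int64_t l = 0; l < k; ++l)
                acc += At[l * m + i] * Bt[l * n + j]; // contiguous panel reads
            C[j * ldc + i] = acc;
        }
}

int main() {
    const int64_t m = 8, n = 8, k = 16;
    std::vector<float> A(m * k), B(n * k), At(k * m), Bt(k * n), C(m * n);
    for (size_t i = 0; i < A.size(); ++i) A[i] = float(i % 7);
    for (size_t i = 0; i < B.size(); ++i) B[i] = float(i % 5);

    pack_transpose(A.data(), At.data(), m, k, k);            // pack once ...
    pack_transpose(B.data(), Bt.data(), n, k, k);
    gemm_packed(At.data(), Bt.data(), C.data(), m, n, k, m); // ... then compute

    std::printf("C[0] = %f\n", C[0]);
    return 0;
}

Packing once up front is what turns the innermost loads in the kernels into unit-stride reads of At and Bt, which in turn is what makes the wider 8x16 and 16x8 MMA tiles added in this patch practical.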