Skip to content

Commit

Permalink
Ruy ARM32: Optimize 8bit kernel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 270087717
  • Loading branch information
talumbau authored and tensorflower-gardener committed Sep 19, 2019
1 parent 0d3b70f commit 2359c4e
Showing 1 changed file with 68 additions and 76 deletions.
144 changes: 68 additions & 76 deletions tensorflow/lite/experimental/ruy/kernel_arm32.cc
Expand Up @@ -620,8 +620,9 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {
// /------------------------\ /-----------------------------\
// |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 |
// |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 |
// |d4.b[0-7] ... d5.b[0-7] | | q8 ..... q12 |
// |d6.b[0-7] ... d7.b[0-7] | | q9 ..... q13 |
// (Reload d0, d1, d2, d3)
// |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 |
// |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 |
// \------------------------/ \-----------------------------/
// 128-bit accumulators 4x2 block
//
Expand All @@ -634,24 +635,32 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {

// Load the first 64 bytes of LHS and RHS data.
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
// Clear accumulators.
RUY_MAKE_ZERO(q6)
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
"vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
"vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
"vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
RUY_MAKE_ZERO(q11)
"vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"

"sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"

"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
RUY_MAKE_ZERO(q12)
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"

"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
RUY_MAKE_ZERO(q13)
"str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"

"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
RUY_MAKE_ZERO(q14)
"str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"

"ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
RUY_MAKE_ZERO(q15)
"str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"

"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
Expand All @@ -660,17 +669,6 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {
"ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
"str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"

// Clear accumulators.
RUY_MAKE_ZERO(q6)
RUY_MAKE_ZERO(q7)
RUY_MAKE_ZERO(q8)
RUY_MAKE_ZERO(q9)
RUY_MAKE_ZERO(q10)
RUY_MAKE_ZERO(q11)
RUY_MAKE_ZERO(q12)
RUY_MAKE_ZERO(q13)
RUY_MAKE_ZERO(q14)
RUY_MAKE_ZERO(q15)

// r1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
Expand All @@ -689,50 +687,47 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {

"2:\n"

// Mult, mult-acc in to q14, q15
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"

"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"

"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
"vmlal.s8 q15, d3, d9\n"
"vmlal.s8 q3, d3, d11\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS

// Then pairwise accumulate in to q6, q7
// Then pairwise accumulate in to q6, q7, q10, q11
"vpadal.s16 q6, q14\n"
"vpadal.s16 q7, q15\n"
"vpadal.s16 q10, q2\n"
"vpadal.s16 q11, q3\n"

// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d4, d8\n"
"vmull.s8 q15, d6, d8\n"
"vmlal.s8 q14, d5, d9\n"
"vmlal.s8 q15, d7, d9\n"
"vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"

// Then pairwise accumulate in to q8, q9
"vpadal.s16 q8, q14\n"
"vpadal.s16 q9, q15\n"
"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"

// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d0, d10\n"
"vmull.s8 q15, d2, d10\n"
"vmlal.s8 q14, d1, d11\n"
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
"vmlal.s8 q15, d3, d11\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
"vmlal.s8 q15, d3, d9\n"
"vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
"vmlal.s8 q3, d3, d11\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS

// Then pairwise accumulate in to q8, q9
"vpadal.s16 q10, q14\n"
"vpadal.s16 q11, q15\n"

// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d4, d10\n"
"vmull.s8 q15, d6, d10\n"
"vmlal.s8 q14, d5, d11\n"
"vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
"vmlal.s8 q15, d7, d11\n"
"vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
// Then pairwise accumulate in to q12, q13
"vpadal.s16 q12, q14\n"
// Then pairwise accumulate in to q8, q9, q12, q13
"vpadal.s16 q8, q14\n"
"vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
"vpadal.s16 q13, q15\n"
"vpadal.s16 q9, q15\n"
"vpadal.s16 q12, q2\n"
"vpadal.s16 q13, q3\n"

// Prefetch the next 64 bytes of LHS and RHS data.
RUY_PREFETCH("pld [%[lhs_ptr]]\n")
Expand All @@ -748,45 +743,44 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {

"79:\n"

// Mult, mult-acc in to q14, q15
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"

"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"

"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
"vmlal.s8 q15, d3, d9\n"
"vmlal.s8 q3, d3, d11\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS

// Then pairwise accumulate in to q6, q7
// Then pairwise accumulate in to q6, q7, q10, q11
"vpadal.s16 q6, q14\n"
"vpadal.s16 q7, q15\n"
"vpadal.s16 q10, q2\n"
"vpadal.s16 q11, q3\n"

// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d4, d8\n"
"vmull.s8 q15, d6, d8\n"
"vmlal.s8 q14, d5, d9\n"
"vmlal.s8 q15, d7, d9\n"

// Then pairwise accumulate in to q8, q9
"vpadal.s16 q8, q14\n"
"vpadal.s16 q9, q15\n"
// Mult, mult-acc in to q14, q15, q2, q3
"vmull.s8 q14, d0, d8\n"
"vmull.s8 q2, d0, d10\n"

// Mult, mult-acc in to q14, q15
"vmull.s8 q14, d0, d10\n"
"vmull.s8 q15, d2, d10\n"
"vmlal.s8 q14, d1, d11\n"
"vmlal.s8 q15, d3, d11\n"
"vmull.s8 q15, d2, d8\n"
"vmull.s8 q3, d2, d10\n"

// Then pairwise accumulate in to q10, q11
"vpadal.s16 q10, q14\n"
"vpadal.s16 q11, q15\n"
"vmlal.s8 q14, d1, d9\n"
"vmlal.s8 q2, d1, d11\n"
"vmlal.s8 q15, d3, d9\n"
"vmlal.s8 q3, d3, d11\n"

// Then pairwise accumulate in to q8, q9
"vmull.s8 q14, d4, d10\n"
"vmull.s8 q15, d6, d10\n"
"vmlal.s8 q14, d5, d11\n"
"vmlal.s8 q15, d7, d11\n"
// Then pairwise accumulate in to q8, q9, q12, q13
"vpadal.s16 q8, q14\n"
"vpadal.s16 q9, q15\n"
"vpadal.s16 q12, q2\n"
"vpadal.s16 q13, q3\n"

// Then pairwise accumulate in to q12, q13
"vpadal.s16 q12, q14\n"
"vpadal.s16 q13, q15\n"

// All accumulation over depth done. q6 - q13 contain the 4x32b
// accumulators for the 4x2 final matrix.
Expand Down Expand Up @@ -889,8 +883,6 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {
// in the rest of the work on the current block.
"vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
"vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
"vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
"vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
RUY_PREFETCH("pld [%[lhs_ptr]]\n")
"vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
RUY_PREFETCH("pld [%[rhs_ptr]]\n")
Expand Down

0 comments on commit 2359c4e

Please sign in to comment.