Skip to content

Commit

Permalink
Update HPLAI_blas.cc
Browse files (browse the repository at this point in the history)
  • Loading branch information
WuK authored Mar 24, 2021
1 parent 26dbe47 commit 5b2f262
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/blas/HPLAI_blas.cc
Original file line number | Diff line number | Diff line change
@@ -220,18 +220,25 @@ void blas::gemm<HPLAI_T_AFLOAT, HPLAI_T_AFLOAT, HPLAI_T_AFLOAT>(
return;
}

int64_t padding_size = 128 / sizeof(HPLAI_T_AFLOAT);
if (padding_size < 1)
padding_size = 1;

blas::Op TRANSC = blas::Op::NoTrans;

int64_t rC = TRANSC == blas::Op::NoTrans ? M : N;
int64_t cC = TRANSC == blas::Op::NoTrans ? N : M;
int64_t dLDC = rC + 127 >> 7 << 7;
int64_t dLDC = (rC + padding_size - 1) / padding_size * padding_size;
int64_t dsC = cC * dLDC;

int64_t rB = TRANSB == blas::Op::NoTrans ? K : N;
int64_t cB = TRANSB == blas::Op::NoTrans ? N : K;
int64_t dLDB = rB + 127 >> 7 << 7;
int64_t dLDB = (rB + padding_size - 1) / padding_size * padding_size;
int64_t dsB = cB * dLDB;

int64_t rA = TRANSA == blas::Op::NoTrans ? M : K;
int64_t cA = TRANSA == blas::Op::NoTrans ? K : M;
int64_t dLDA = rA + 127 >> 7 << 7;
int64_t dLDA = (rA + padding_size - 1) / padding_size * padding_size;
int64_t dsA = cA * dLDA;

if (HPLAI_DEVICE_BLASPP_BUFFER_SIZE < dsC + dsB + dsA)
@@ -254,12 +261,12 @@ void blas::gemm<HPLAI_T_AFLOAT, HPLAI_T_AFLOAT, HPLAI_T_AFLOAT>(
K,
ALPHA,
dA,
LDA,
dLDA,
dB,
LDB,
dLDB,
BETA,
dC,
LDC,
dLDC,
*HPLAI_DEVICE_BLASPP_QUEUE);

blas::device_getmatrix<HPLAI_T_AFLOAT>(rC, cC, dC, dLDC, C, LDC, *HPLAI_DEVICE_BLASPP_QUEUE);
@@ -708,14 +715,18 @@ void blas::trsm<HPLAI_T_AFLOAT, HPLAI_T_AFLOAT>(
return;
}

int64_t padding_size = 128 / sizeof(HPLAI_T_AFLOAT);
if (padding_size < 1)
padding_size = 1;

int64_t rB = M;
int64_t cB = N;
int64_t dLDB = rB + 127 >> 7 << 7;
int64_t dLDB = (rB + padding_size - 1) / padding_size * padding_size;
int64_t dsB = cB * dLDB;

int64_t rA = SIDE == blas::Side::Left ? M : N;
int64_t cA = SIDE == blas::Side::Left ? M : N;
int64_t dLDA = rA + 127 >> 7 << 7;
int64_t dLDA = (rA + padding_size - 1) / padding_size * padding_size;
int64_t dsA = cA * dLDA;

if (HPLAI_DEVICE_BLASPP_BUFFER_SIZE < dsB + dsA)
Expand Down

0 comments on commit 5b2f262

Please sign in to comment.