Skip to content
Permalink
Browse files

[codegen] added leading dimension padding for transposition in shared

memory
  • Loading branch information...
ptillet committed May 5, 2019
1 parent 4813bb0 commit f80441017c5d6c72fd8327ce2137c3f41a8891d3
@@ -14,6 +14,17 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
}
}

template<class T>
void simple_gemm(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
if(AT && BT)
simple_gemm<T, true, true>(c, a, b, M, N, K);
else if(AT && !BT)
simple_gemm<T, true, false>(c, a, b, M, N, K);
else if(!AT && BT)
simple_gemm<T, false, true>(c, a, b, M, N, K);
else
simple_gemm<T, false, false>(c, a, b, M, N, K);
}

class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
@@ -5,63 +5,104 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"

const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};

void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
int32 div = K / GZ;
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
c = dot(a, trans(b), c);
pa = pa + TK*lda;
pb = pb + TK*ldb;
a = *pa;
b = *pb;
std::string triton_source(bool AT, bool BT) {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b";
if(AT){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
if(BT){
std::swap(BS0, BS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c;
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
int32 div = K / GZ;
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
fp32* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
fp32 a[)" + AS0 + ", " + AS1 + R"(] = *pa;
fp32 b[)" + BS0 + ", " + BS1 + R"(] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
c = dot()" + usea + ", " + useb + R"(, c);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
fp32* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
}
)";
return res;
}
)";


int main() {
bool AT = false;
bool BT = true;

// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
@@ -128,16 +169,16 @@ int main() {


// just-in-time compile source-code
std::vector<unsigned> params = {
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1
};
// jit.autotune("matmul",src, benchmark);
jit.add_module("matmul", src, params);
std::string src = triton_source(AT, BT);
// jit.autotune("matmul",src.c_str(), benchmark);
jit.add_module("matmul", src.c_str(), {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1});
// jit.add_module("matmul", src.c_str(), {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
simple_gemm<float,false,true>(rc, ha, hb, M, N, K);
simple_gemm<float>(AT, BT, rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;

0 comments on commit f804410

Please sign in to comment.
You can’t perform that action at this time.