Skip to content

Commit

Permalink
[examples/python/tensorflow] better skeleton for blocksparse
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed May 1, 2019
1 parent 55866f1 commit 70f49a5
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 140 deletions.
237 changes: 117 additions & 120 deletions examples/python/tensorflow/blocksparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#include "tensorflow/core/framework/common_shape_fns.h"

using namespace tensorflow;
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;


Expand All @@ -25,139 +28,133 @@ const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
int32 div = K / GZ;
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
c = dot(a, trans(b), c);
pa = pa + TK*lda;
pb = pb + TK*ldb;
a = *pa;
b = *pb;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
}
)";

// Registers the op "BlockSparseMatMul".
//   inputs : `a` and `b` of type T, `locks` of type int32
//   output : `c` of type T
//   attr   : T, restricted to float
// No shape function is attached to this registration.
// NOTE(review): a second op, "BlocksparseMatmul", is registered further down
// in this file — confirm whether this registration is still needed.
REGISTER_OP("BlockSparseMatMul")
.Input("a: T")
.Input("b: T")
.Input("locks: int32")
.Output("c: T")
.Attr("T: {float}")
;
// Shape-inference function for the blocksparse matmul op.
//
// Output 0 mirrors the shape of input 0, except that the dimension at index
// `axis` is replaced by the value of the "K" attribute (C ==> K). When the
// rank reported for input 0 is not positive, output 0 is left as an unknown
// shape instead. Output 1 is always reported as an unknown shape.
//
// Returns a non-OK Status if the "K" or "axis" attribute cannot be read,
// Status::OK() otherwise.
Status XpropShape(InferenceContext* ctx)
{
    int K;
    TF_RETURN_IF_ERROR(ctx->GetAttr("K", &K));
    int axis;
    TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));

    ShapeHandle in_shape = ctx->input(0);
    const int in_rank = ctx->Rank(in_shape);

    if (in_rank <= 0)
    {
        // Nothing to copy dimensions from.
        ctx->set_output(0, ctx->UnknownShape());
    }
    else
    {
        std::vector<DimensionHandle> out_dims;
        out_dims.reserve(in_rank);
        for (int d = 0; d < in_rank; ++d)
        {
            if (d == axis)
                out_dims.push_back(ctx->MakeDim(K));
            else
                out_dims.push_back(ctx->Dim(in_shape, d));
        }
        ctx->set_output(0, ctx->MakeShape(out_dims));
    }

    // The second output is set regardless of which branch ran above.
    ctx->set_output(1, ctx->UnknownShape());
    return Status::OK();
}


// Registers the op "BlocksparseMatmul".
//   inputs  : `x` and `w` of type T; `lut`, `lut_dx`, `lut_dw` of type int64;
//             `gate`, a list of `ngate` float tensors
//   outputs : `y` of type T and `temp` of type int32
// Shape inference is delegated to XpropShape, which reads the "K" and "axis"
// attributes declared below.
// The op documentation string names the inputs as they are declared here
// ("x" and "w").
REGISTER_OP("BlocksparseMatmul")
    .Input("x: T")
    .Input("w: T")
    .Input("lut: int64")
    .Input("lut_dx: int64")
    .Input("lut_dw: int64")
    .Input("gate: ngate * float")
    .Output("y: T")
    .Output("temp: int32")
    .Attr("T: {half, float, bfloat16}")
    .Attr("blocks: int >=0")
    .Attr("bsize: int")
    .Attr("segments: int = 0")
    .Attr("segments_dx: int = 0")
    .Attr("locks: int = 0")
    .Attr("locks_dx: int = 0")
    .Attr("axis: int = 1")
    .Attr("C: int >=0")
    .Attr("K: int >=0")
    .Attr("shared: int = 0")
    .Attr("shared_dx: int = 0")
    .Attr("alpha: float = 1.0")
    .Attr("beta: float = 0.0")
    .Attr("gated_dw: bool = false")
    .Attr("gate_grad: bool = false")
    .Attr("bench: int = 0")
    .Attr("ngate: int >= 0")
    .SetShapeFn(XpropShape)
    .Doc(R"doc(
Multiply the matrix "x" by the blocksparse matrix "w".
)doc");


// Plain parameter bundle for the blocksparse matmul op; the op kernel class
// in this file keeps one instance as its `params_` member.
//
// Fields marked "attr" are filled from the op attribute of the same name by
// the op kernel constructor. Fields without a note are not assigned anywhere
// in the visible part of this file.
struct bsmm_params
{
    const int*   Lut;
    const float* Gate;
    int*         Lock;
    //float4* Scratch;
    int   blocks;    // attr
    int   bsize;     // attr
    int   segments;  // attr
    int   locks;     // attr
    int   C;         // attr
    int   K;         // attr
    int   N;
    int   shared;    // attr
    int   pcount;    // set to 1 by the op kernel constructor
    uint  blk_a;
    uint  blk_A;     // set to 0 by the op kernel constructor
    uint  blk_b;
    uint  blk_B;
    float alpha;     // attr
    float beta;      // attr
    CUstream stream;
};

class BlockSparseGemmOp : public OpKernel {
class BlocksparseMatmulOp : public OpKernel {
public:
explicit BlockSparseGemmOp(OpKernelConstruction* context) : OpKernel(context) {
explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", &params_.locks ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks", &params_.blocks ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize", &params_.bsize ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("C", &params_.C ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("K", &params_.K ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared", &params_.shared ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &params_.alpha ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &params_.beta ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_ ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_ ));
OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_));
OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536"));
OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536"));
params_.pcount = 1;
params_.blk_A = 0;
is_gpu_ = ctx->device_type() == DEVICE_GPU;
if (bench_) {
repeat_ = bench_;
flops_ = (float)(params_.blocks * params_.bsize*params_.bsize);
const char* op = "FPROP";
sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks);
}
}

void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream sstream(device.stream(), false);
triton::driver::context* ctx = sstream.context();
triton::driver::stream* stream = &sstream;
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& locks = context->input(2);
// get shapes
const int32_t M = a.dim_size(0);
const int32_t N = b.dim_size(0);
const int32_t K = a.dim_size(1);
// allocate output
Tensor* c = nullptr;
TensorShape out_shape({(int64)M, (int64)N});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
// return early if possible
if (out_shape.num_elements() == 0)
return;
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
stream->synchronize();
// just-in-time compile source-code
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// set argument
kernel->setArg(0, *da.cu());
kernel->setArg(1, *db.cu());
kernel->setArg(2, *dc.cu());
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, *dlocks.cu());
kernel->setArg(10, grid[0]);
kernel->setArg(11, grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}

private:
bsmm_params params_;
int axis_, bench_, repeat_, SMs_, major_, grid_n_;
float flops_;
bool gated_dw_, is_gpu_;
char bench_string_[256];
};

REGISTER_KERNEL_BUILDER(Name("BlockSparseMatMul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlockSparseGemmOp);
REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
20 changes: 0 additions & 20 deletions examples/python/tensorflow/blocksparse.py

This file was deleted.

Loading

0 comments on commit 70f49a5

Please sign in to comment.