
more cleaning of conv

ptillet committed May 6, 2019
1 parent fd91368 commit 615569287e06d7aa39be9d580cc4bae1ad15f8ba
Showing with 120 additions and 57 deletions.
  1. +3 −1 examples/cpp/conv.cpp
  2. +117 −56 include/triton/dnn/conv.h
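
In short, this change folds the free-standing init_cst/build_lut helpers into a conv class: the constructor captures the problem geometry (shapes, upsampling, padding) and computes strides once, and a build_lut member then derives the delta/mask look-up tables from the stored state. A minimal sketch of the new call sequence, with illustrative problem sizes (the constructor and build_lut calls are taken verbatim from the diff below):

    #include <vector>
    #include "triton/dnn/conv.h"

    int main() {
      // illustrative sizes: batch, input channels, image H/W, filter R/S, #filters
      int B = 4, NC = 32, H = 32, W = 32, R = 3, S = 3, NF = 64;
      // geometry, upsampling (1, 1) and padding (0, 0) are captured at construction
      triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
      // the look-up tables are then derived from the stored strides and padding
      std::vector<int> h_delta, h_masks;
      configuration.build_lut(h_delta, h_masks);
      return 0;
    }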
@@ -25,6 +25,7 @@ int main() {
   int32_t M = B*RD*RH*RW;
   int32_t N = NF;
   int32_t K = NC*T*R*S;
+  // convolution configuration
   std::vector<float> hc(B*RH*RW*NF);
   std::vector<float> rc(B*RH*RW*NF);
   std::vector<float> ha(B*NC*H*W);
@@ -57,8 +58,9 @@ int main() {
   int32_t stride_o_k = RD*stride_o_m;
   int32_t stride_o_n = NF*stride_o_k;
   // look-up table
+  triton::dnn::conv configuration(B, NC, H, W, R, S, NF, 1, 1, 0, 0);
   std::vector<int> h_delta, h_masks;
-  triton::dnn::conv::init_cst(stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
+  configuration.build_lut(h_delta, h_masks);
   // benchmark a given convolution kernel
   auto benchmark = [&](triton::driver::kernel* kernel,
                        triton::jit::launch_information info) {
@@ -12,81 +12,106 @@ class conv {
     WGRAD
   };
 
-  static void build_lut(int TK,
-                        int stride_d, int stride_h, int stride_w, int stride_c,
-                        int pad_d, int pad_h, int pad_w,
-                        int T, int R, int S,
-                        std::vector<int>& res, std::vector<int>& masks) {
-    /* convolution parameters */
-    int F = T * R * S;
-    int Nlut = (TK + F - 1) / F * F;
-    int upsample_w = 1;
-    int upsample_h = 1;
-    int upsample_d = 1;
+
+  conv(int B, int NC, int H, int W, int R, int S, int NF,
+       int upsample_h, int upsample_w,
+       int pad_h, int pad_w)
+    : B_(B), NC_(NC), D_(1), H_(H), W_(W), T_(1), R_(R), S_(S), NF_(NF),
+      upsample_d_(1), upsample_h_(upsample_h), upsample_w_(upsample_w),
+      pad_d_(0), pad_h_(pad_h), pad_w_(pad_w)
+  {
+    RD_ = (D_*upsample_d_ - T_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
+    RH_ = (H_*upsample_h_ - R_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
+    RW_ = (W_*upsample_w_ - S_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
+    M_ = B*RD_*RH_*RW_;
+    N_ = NF;
+    K_ = NC*T_*R_*S_;
+    Fs_ = T_*R_*S_;
+    TK_ = 8;
+    Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
+    // memory strides for data
+    stride_a_w_ = 1;
+    stride_a_h_ = W_*stride_a_w_;
+    stride_a_d_ = H_*stride_a_h_;
+    stride_a_c_ = D_*stride_a_d_;
+    stride_a_n_ = NC_*stride_a_c_;
+    // memory strides for activations
+    stride_c_q_ = 1;
+    stride_c_p_ = RW_*stride_c_q_;
+    stride_c_m_ = RH_*stride_c_p_;
+    stride_c_k_ = RD_*stride_c_m_;
+    stride_c_n_ = NF_*stride_c_k_;
+  }
+
+
+  void build_lut(std::vector<int>& delta, std::vector<int>& masks) {
+    delta.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
+    masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
+
     /* unpack index wrt filters */
     auto unpack = [&](int32_t trs){
-      int32_t tr = trs / S;
-      int32_t s = trs - tr*S;
-      int32_t t = tr / R;
-      int32_t r = tr - t*R;
+      int32_t tr = trs / S_;
+      int32_t s = trs - tr*S_;
+      int32_t t = tr / R_;
+      int32_t r = tr - t*R_;
       return std::make_tuple(t, r, s);
     };
     /* increments */
-    for(size_t i = 0; i < Nlut; ++i)
-      res[i] = (((i + TK) % Nlut) - i);
+    for(size_t i = 0; i < Luts_; ++i)
+      delta[i] = (((i + TK_) % Luts_) - i);
     /* deltas */
-    size_t Ds0 = Nlut;
-    size_t Ds1 = upsample_w;
-    size_t Ds2 = upsample_h;
-    size_t Ds3 = upsample_d;
+    size_t Ds0 = Luts_;
+    size_t Ds1 = upsample_w_;
+    size_t Ds2 = upsample_h_;
+    size_t Ds3 = upsample_d_;
     for(size_t pd = 0; pd < Ds3; ++pd)
     for(size_t ph = 0; ph < Ds2; ++ph)
     for(size_t pw = 0; pw < Ds1; ++pw){
-      int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
+      int32_t* deltas_ptr = &delta[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
       // cumulative increments
       for(size_t i = 0; i < Ds0; ++i){
         int32_t ctrs = i;
-        int32_t c = ctrs / F;
+        int32_t c = ctrs / Fs_;
         int32_t t, r, s;
-        std::tie(t, r, s) = unpack(ctrs % F);
+        std::tie(t, r, s) = unpack(ctrs % Fs_);
         // next indices
-        int32_t nextctrs = ctrs + TK;
-        int32_t nextc = nextctrs / F;
+        int32_t nextctrs = ctrs + TK_;
+        int32_t nextc = nextctrs / Fs_;
         int32_t nextt, nextr, nexts;
-        std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
+        std::tie(nextt, nextr, nexts) = unpack(nextctrs % Fs_);
         // diffs
         int32_t cdiff = nextc - c;
-        int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
-        int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
-        int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
+        int32_t tdiff = (nextt + pd)/upsample_d_ - (t + pd)/upsample_d_;
+        int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
+        int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
         // delta pointers
-        deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
+        deltas_ptr[i] = cdiff*stride_a_c_ + sdiff*stride_a_w_ + rdiff*stride_a_h_ + tdiff*stride_a_d_;
       }
     }
 
     /* Masks */
-    size_t Ms0 = Nlut;
-    size_t Ms1 = 2*pad_w + 1;
-    size_t Ms2 = 2*pad_h + 1;
-    size_t Ms3 = 2*pad_d + 1;
+    size_t Ms0 = Luts_;
+    size_t Ms1 = 2*pad_w_ + 1;
+    size_t Ms2 = 2*pad_h_ + 1;
+    size_t Ms3 = 2*pad_d_ + 1;
     for(size_t pd = 0; pd < Ms3; ++pd)
     for(size_t ph = 0; ph < Ms2; ++ph)
     for(size_t pw = 0; pw < Ms1; ++pw){
-      int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
+      int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
       for(size_t i = 0; i < Ms0; ++i){
         int32_t t, r, s;
         int32_t mask = 0x0;
-        for(size_t j = 0; j < TK; ++j){
-          std::tie(t, r, s) = unpack((i + j) % F);
-          bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
-          bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
-          bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
+        for(size_t j = 0; j < TK_; ++j){
+          std::tie(t, r, s) = unpack((i + j) % Fs_);
+          bool in_bounds_d = (t + pd) >= pad_d_ && (t + pd) < (T_ + pad_d_);
+          bool in_bounds_h = (r + ph) >= pad_h_ && (r + ph) < (R_ + pad_h_);
+          bool in_bounds_w = (s + pw) >= pad_w_ && (s + pw) < (S_ + pad_w_);
           mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
         }
         masks_ptr[i] = mask;
       }
     }
-    for(size_t i = 0; i < Nlut; ++i)
+    for(size_t i = 0; i < Luts_; ++i)
       masks[i] = 0x0;
 
   }
@@ -95,20 +120,6 @@ class conv {
     return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
   }
 
-  static void init_cst(int stride_d, int stride_h, int stride_w, int stride_c,
-                       int pad_d, int pad_h, int pad_w,
-                       int T, int R, int S,
-                       std::vector<int> &h_delta, std::vector<int> &h_masks) {
-    int upsample_d = 1;
-    int upsample_h = 1;
-    int upsample_w = 1;
-    int TK = 8;
-    int F = T * R * S;
-    int nlut = (TK + F - 1) / F * F;
-    h_delta.resize(nlut + upsample_d*upsample_h*upsample_w*nlut);
-    h_masks.resize(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
-    build_lut(TK, stride_d, stride_h, stride_w, stride_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
-  }
 
   static std::string src(type ty = FPROP) {

@@ -191,6 +202,56 @@ class conv {
})";
return res;
}

private:
// image size
int B_;
int NC_;
int D_;
int H_;
int W_;
// filter size
int T_;
int R_;
int S_;
int NF_;
// activation size
int RD_;
int RH_;
int RW_;
// upsampling
int upsample_d_;
int upsample_h_;
int upsample_w_;
// padding
int pad_d_;
int pad_h_;
int pad_w_;
// striding
int stride_d_;
int stride_h_;
int stride_w_;
// equivalent matmul
int M_;
int N_;
int K_;
// helpers
int Fs_;
int TK_;
int Luts_;
// memory strides for data
int32_t stride_a_w_;
int32_t stride_a_h_;
int32_t stride_a_d_;
int32_t stride_a_c_;
int32_t stride_a_n_;
// memory stride for activations
int32_t stride_c_q_;
int32_t stride_c_p_;
int32_t stride_c_m_;
int32_t stride_c_k_;
int32_t stride_c_n_;

};

}
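
For intuition about the tables built above: delta holds, for each flattened filter index, the pointer increment that advances the kernel by TK_ taps (wrapping at Luts_), and each mask word records which of the next TK_ taps stay in bounds at a given padding offset. A self-contained sketch that mirrors those two loops under assumed sizes (a 1x3x3 filter, TK = 8, padding 1 in h and w; the local names T, R, S, Fs, Luts stand in for the members above):

    #include <cstdio>
    #include <tuple>

    int main() {
      // hypothetical sizes: 1x3x3 filter, TK = 8 (hard-coded as TK_ in this commit)
      int T = 1, R = 3, S = 3, TK = 8;
      int pad_d = 0, pad_h = 1, pad_w = 1;
      int Fs = T * R * S;                  // 9 filter taps
      int Luts = (TK + Fs - 1) / Fs * Fs;  // TK rounded up to a multiple of Fs -> 9

      // base increments, same formula as build_lut: entry i is the jump that
      // advances the flattened filter index by TK, wrapping at Luts
      for (int i = 0; i < Luts; ++i)
        printf("delta[%d] = %d\n", i, ((i + TK) % Luts) - i);  // 8, then -1 x8

      // one mask word: bit j says whether tap (i + j) % Fs is in bounds for a
      // given padding offset (pd, ph, pw), exactly as in the masks loop above
      auto unpack = [&](int trs) {
        int tr = trs / S, s = trs - tr * S;
        int t = tr / R, r = tr - t * R;
        return std::make_tuple(t, r, s);
      };
      int i = 0, pd = 0, ph = 1, pw = 1;   // centered offset: nothing clipped
      int mask = 0;
      for (int j = 0; j < TK; ++j) {
        int t, r, s;
        std::tie(t, r, s) = unpack((i + j) % Fs);
        bool ok = (t + pd) >= pad_d && (t + pd) < T + pad_d
               && (r + ph) >= pad_h && (r + ph) < R + pad_h
               && (s + pw) >= pad_w && (s + pw) < S + pad_w;
        mask |= ok << j;
      }
      printf("mask = 0x%x\n", mask);       // 0xff: all 8 taps valid at the center
      return 0;
    }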
