Skip to content

Commit

Permalink
[dnn/conv] minor cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed May 15, 2019
1 parent be2ba03 commit 15a967c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/cpp/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240;
Expand Down Expand Up @@ -66,7 +66,7 @@ int main() {
return configuration.get_nflops() / ts * 1e-3;
};
std::string src = configuration.src();
jit.autotune("conv", src.c_str(), benchmark);
// jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
Expand Down
18 changes: 13 additions & 5 deletions include/triton/dnn/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,18 @@ class conv {
kernel->setArg(5, K_);
kernel->setArg(6, AH_);
kernel->setArg(7, AW_);
kernel->setArg(8, BH_);
kernel->setArg(9, BW_);
kernel->setArg(10, CH_);
kernel->setArg(11, CW_);
if(ty_ == WGRAD){
kernel->setArg(8, CH_);
kernel->setArg(9, CW_);
kernel->setArg(10, BH_);
kernel->setArg(11, BW_);
}
else{
kernel->setArg(8, BH_);
kernel->setArg(9, BW_);
kernel->setArg(10, CH_);
kernel->setArg(11, CW_);
}
kernel->setArg(12, ld_a_[0]);
kernel->setArg(13, ld_a_[1]);
kernel->setArg(14, ld_a_[2]);
Expand Down Expand Up @@ -360,8 +368,8 @@ class conv {
fp32 *c,
int32 M, int32 N, int32 K,
int32 AH, int32 AW,
int32 CH, int32 CW,
int32 BH, int32 BW,
int32 CH, int32 CW,
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,
Expand Down

0 comments on commit 15a967c

Please sign in to comment.