Skip to content
Permalink
Browse files

[dnn] added Triton-C derivative computations in conv

  • Loading branch information...
ptillet committed May 13, 2019
1 parent f6fe949 commit 5941501f70a00346f7065f54217ee97c980c423d
Showing with 246 additions and 281 deletions.
  1. +2 −2 examples/cpp/conv.cpp
  2. +125 −258 examples/python/pytorch/conv.cpp
  3. +44 −5 examples/python/pytorch/main.py
  4. +64 −13 include/triton/dnn/conv.h
  5. +11 −3 include/triton/driver/dispatch.h
@@ -10,7 +10,7 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240;
@@ -77,7 +77,7 @@ int main() {
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
}
}
std::cout << "Pass!" << std::endl;
}

0 comments on commit 5941501

Please sign in to comment.
You can’t perform that action at this time.