[triton/dnn/conv] merged optimizations branch

- Added forward/backward support for strided convolution
- Added support for bias
- Added support for reduction splitting
ptillet committed May 28, 2019
1 parent e526ffc commit a9d078c06f660f9d6aa353080206d69725528a00
Showing with 732 additions and 31,339 deletions.
  1. +4 −1 cmake/FindTorch.cmake
  2. +2 −240 examples/cpp/common.hpp
  3. +18 −15 examples/cpp/conv.cpp
  4. +3 −2 examples/cpp/dot.cpp
  5. +3 −2 examples/cpp/shift.cpp
  6. +52 −19 examples/python/pytorch/conv.cpp
  7. +11 −7 examples/python/pytorch/test.py
  8. +4 −3 examples/python/pytorch/triton.py
  9. +18 −149 include/triton/dnn/conv.h
  10. +4 −123 include/triton/dnn/gemm.h
  11. +2 −2 include/triton/driver/buffer.h
  12. +0 −73 include/triton/driver/dispatch.h
  13. +0 −64 include/triton/external/CUDA/builtin_types.h
  14. +0 −412 include/triton/external/CUDA/channel_descriptor.h
  15. +0 −266 include/triton/external/CUDA/crt/host_config.h
  16. +0 −216 include/triton/external/CUDA/crt/host_defines.h
  17. +0 −338 include/triton/external/CUDA/cuComplex.h
  18. +0 −565 include/triton/external/CUDA/cublas.h
  19. +0 −2,977 include/triton/external/CUDA/cublas_api.h
  20. +0 −274 include/triton/external/CUDA/cublas_v2.h
  21. +0 −248 include/triton/external/CUDA/cuda_device_runtime_api.h
  22. +0 −1,969 include/triton/external/CUDA/cuda_fp16.h
  23. +0 −1,797 include/triton/external/CUDA/cuda_fp16.hpp
  24. +0 −2,040 include/triton/external/CUDA/cuda_runtime.h
  25. +0 −7,422 include/triton/external/CUDA/cuda_runtime_api.h
  26. +0 −1,805 include/triton/external/CUDA/cudnn.h
  27. +0 −6,257 include/triton/external/CUDA/cusparse.h
  28. +0 −69 include/triton/external/CUDA/device_types.h
  29. +0 −145 include/triton/external/CUDA/driver_functions.h
  30. +0 −1,610 include/triton/external/CUDA/driver_types.h
  31. +0 −50 include/triton/external/CUDA/host_config.h
  32. +0 −50 include/triton/external/CUDA/host_defines.h
  33. +0 −80 include/triton/external/CUDA/library_types.h
  34. +0 −525 include/triton/external/CUDA/nvrtc.h
  35. +0 −119 include/triton/external/CUDA/surface_types.h
  36. +0 −217 include/triton/external/CUDA/texture_types.h
  37. +0 −177 include/triton/external/CUDA/vector_functions.h
  38. +0 −318 include/triton/external/CUDA/vector_functions.hpp
  39. +0 −425 include/triton/external/CUDA/vector_types.h
  40. +6 −1 include/triton/runtime/jit.h
  41. +50 −0 include/triton/tools/bench.hpp
  42. +408 −108 lib/dnn/conv.cpp
  43. +137 −0 lib/dnn/gemm.cpp
  44. +2 −2 lib/driver/buffer.cpp
  45. +0 −114 lib/driver/dispatch.cpp
  46. +0 −39 lib/driver/error.cpp
  47. +8 −4 lib/runtime/jit.cpp
@@ -4,7 +4,10 @@ execute_process(COMMAND python -c "import torch; import os; print(os.path.dirnam

find_package_handle_standard_args(TORCH DEFAULT_MSG TORCH_INSTALL_PREFIX)
if(TORCH_INSTALL_PREFIX)
-set(TORCH_INCLUDE_DIRS ${TORCH_INSTALL_PREFIX}/lib/include/ ${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include)
+set(TORCH_INCLUDE_DIRS ${TORCH_INSTALL_PREFIX}/lib/include/
+${TORCH_INSTALL_PREFIX}/lib/include/torch/csrc/api/include
+${TORCH_INSTALL_PREFIX}/include/
+${TORCH_INSTALL_PREFIX}/include/torch/csrc/api/include/)
set(TORCH_LIBRARY_DIRS ${TORCH_INSTALL_PREFIX}/lib/)
endif()

@@ -1,5 +1,6 @@
#include <vector>
#include <chrono>
#include <cmath>
#include "triton/driver/device.h"
#include <algorithm>

@@ -26,245 +27,6 @@ void simple_gemm(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, c
simple_gemm<T, false, false>(c, a, b, M, N, K);
}

class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;

public:
explicit timer(bool run = false)
{ if (run) start(); }

void start()
{ _start = high_resolution_clock::now(); }

nanoseconds get() const
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }

private:
high_resolution_clock::time_point _start;
};
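// Usage sketch (illustrative): time a block of work in nanoseconds.
//   timer tmr(true);                  // construct and start immediately
//   do_work();                        // hypothetical workload
//   int64_t ns = tmr.get().count();   // elapsed time since start()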

template<class T>
T min(std::vector<T> x)
{ return *std::min_element(x.begin(), x.end()); }


template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
{
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(&device))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return min(times);
}
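// Usage sketch, mirroring the call sites replaced below (`kernel`, `grid`,
// `nthreads` and `context` are assumed to be in scope):
//   double ns = bench([&](){ stream->enqueue(kernel, grid, {nthreads, 1, 1}); },
//                     [&](){ stream->synchronize(); },
//                     *context->device());
// The return value is the fastest clock-normalized iteration, in nanoseconds.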

//

void build_conv_lut(int TK,
int stride_d, int stride_h, int stride_w, int stride_c,
int pad_d, int pad_h, int pad_w,
int T, int R, int S,
std::vector<int>& res, std::vector<int>& masks) {
/* convolution parameters */
int F = T * R * S;
int Nlut = (TK + F - 1) / F * F;
int upsample_w = 1;
int upsample_h = 1;
int upsample_d = 1;
/* unpack index wrt filters */
auto unpack = [&](int32_t trs){
int32_t tr = trs / S;
int32_t s = trs - tr*S;
int32_t t = tr / R;
int32_t r = tr - t*R;
return std::make_tuple(t, r, s);
};
/* increments */
for(size_t i = 0; i < Nlut; ++i)
res[i] = (((i + TK) % Nlut) - i);
/* deltas */
size_t Ds0 = Nlut;
size_t Ds1 = upsample_w;
size_t Ds2 = upsample_h;
size_t Ds3 = upsample_d;
for(size_t pd = 0; pd < Ds3; ++pd)
for(size_t ph = 0; ph < Ds2; ++ph)
for(size_t pw = 0; pw < Ds1; ++pw){
int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
// cumulative increments
for(size_t i = 0; i < Ds0; ++i){
int32_t ctrs = i;
int32_t c = ctrs / F;
int32_t t, r, s;
std::tie(t, r, s) = unpack(ctrs % F);
// next indices
int32_t nextctrs = ctrs + TK;
int32_t nextc = nextctrs / F;
int32_t nextt, nextr, nexts;
std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
// diffs
int32_t cdiff = nextc - c;
int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
// delta pointers
deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
}
}

/* Masks */
size_t Ms0 = Nlut;
size_t Ms1 = 2*pad_w + 1;
size_t Ms2 = 2*pad_h + 1;
size_t Ms3 = 2*pad_d + 1;

for(size_t pd = 0; pd < Ms3; ++pd)
for(size_t ph = 0; ph < Ms2; ++ph)
for(size_t pw = 0; pw < Ms1; ++pw){
int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
for(size_t i = 0; i < Ms0; ++i){
int32_t t, r, s;
int32_t mask = 0x0;
for(size_t j = 0; j < TK; ++j){
std::tie(t, r, s) = unpack((i + j) % F);
bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
}
masks_ptr[i] = mask;
}
}
for(size_t i = 0; i < Nlut; ++i)
masks[i] = 0x0;
}
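// Worked example: with TK = 16 and a 1x3x3 filter (T = 1, R = 3, S = 3),
// F = 9 and Nlut = (16 + 9 - 1) / 9 * 9 = 18, i.e. the table is padded to
// the next multiple of the filter size. `res` holds the Nlut cyclic
// increments followed by the delta table(s); `masks` holds one bitmask per
// entry and padding offset, where bit j records whether the j-th element of
// the TK-wide slice is in bounds.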


// Index computation
inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u,
int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4)
{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; }
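// Row-major offset into a 5D tensor of shape (s0, s1, s2, s3, s4); s0 is
// unused because the outermost extent never enters the offset.
// Example: for shape (2, 3, 4, 5, 6),
//   idx(1, 2, 3, 4, 5, 2, 3, 4, 5, 6) = 5 + 4*6 + 3*30 + 2*120 + 1*360 = 719.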


// Pack

template <class T> T clamp(T x, T lo, T hi){
return std::max<T>(lo, std::min<T>(x, hi));
}


template<class T, class U>
T pack(U* tmp, U scale);

template<>
double pack<double, double>(double* tmp, double scale)
{ return tmp[0]*scale; }

template<>
float pack<float, float>(float* tmp, float scale)
{ return tmp[0]*scale; }

template<>
int pack<int, float>(float* tmp, float scale)
{
int res = 0;
for(int i = 0; i < 4; i++){
int8_t clamped = std::round(clamp(tmp[i]*scale, (float)-128, (float)127));
res |= (clamped & 0xFF) << (8*i);
}
return res;
}

template<class T> struct pack_increment
{ enum{ VALUE = 1}; };

template<> struct pack_increment<int>
{ enum{ VALUE = 4}; };

// Dot
template<class T>
inline T dot(T x, T y, T z)
{
return std::fma(x, y, z);
}

inline int dot(int x, int y, int z){
int res = 0;
for(int i = 0; i < 4; i++){
int32_t a = ((x >> (8*i)) & 0x000000FF);
int32_t b = ((y >> (8*i)) & 0x000000FF);
res += (*(int8_t*)(&a)) * (*(int8_t*)(&b));
}
return res + z;
}
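// Worked example: the int overload emulates a 4-way int8 dot product with
// 32-bit accumulation (analogous to CUDA's dp4a), pairing with pack<int>:
//   float a[4] = {1, 2, 3, 4}, b[4] = {5, 6, 7, 8};
//   int pa = pack<int>(a, 1.f);   // 0x04030201
//   int pb = pack<int>(b, 1.f);   // 0x08070605
//   dot(pa, pb, 0);               // 1*5 + 2*6 + 3*7 + 4*8 = 70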



template<class IN_DTYPE, class OUT_DTYPE>
void cpp_conv_nchw(int32_t C, int32_t N, int32_t K,
int32_t D, int32_t H, int32_t W,
int32_t T, int32_t R, int32_t S,
int32_t pad_d, int32_t pad_h, int32_t pad_w,
int32_t stride_d, int32_t stride_h, int32_t stride_w,
int32_t M, int32_t P, int32_t Q,
std::vector<OUT_DTYPE>& O,
const std::vector<IN_DTYPE>& I,
const std::vector<IN_DTYPE>& F)
{
static const int PACK_IN = pack_increment<IN_DTYPE>::VALUE;
static const int PACK_OUT = pack_increment<OUT_DTYPE>::VALUE;
if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4");
if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4");
C /= PACK_IN;
K /= PACK_OUT;
int32_t Kout = K;
IN_DTYPE accs[PACK_OUT];
float tmp[PACK_OUT];
for(int32_t m = 0 ; m < M; ++m)
for(int32_t p = 0 ; p < P; ++p)
for(int32_t q = 0; q < Q; ++q)
for(int32_t n = 0; n < N; ++n)
for(int32_t k = 0; k < Kout ; ++k)
{
for(int32_t i = 0; i < PACK_OUT; ++i)
accs[i] = 0;
int32_t mm = m*stride_d - pad_d;
int32_t pp = p*stride_h - pad_h;
int32_t qq = q*stride_w - pad_w;
for(int32_t kk = 0; kk < PACK_OUT; ++kk)
for(int32_t c = 0; c < C; ++c)
for(int32_t t = 0; t < T; ++t)
for(int32_t r = 0; r < R; ++r)
for(int32_t s = 0; s < S; ++s){
int32_t d = mm + t;
int32_t h = pp + r;
int32_t w = qq + s;
bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W);
IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0;
IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)];
accs[kk] = dot(i, f, accs[kk]);
}
for(int32_t kk = 0; kk < PACK_OUT; ++kk){
tmp[kk] = accs[kk];
}
O[idx(n, k, m, p, q, N, K, M, P, Q)] = tmp[0];
}
}
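// Usage sketch (float path, shapes illustrative): M, P, Q are the output
// depth/height/width implied by the padding and strides.
//   std::vector<float> O(N*K*M*P*Q), I(N*C*D*H*W), F(C*T*R*S*K);
//   cpp_conv_nchw<float, float>(C, N, K, D, H, W, T, R, S,
//                               pad_d, pad_h, pad_w,
//                               stride_d, stride_h, stride_w,
//                               M, P, Q, O, I, F);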


// input layout: C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
@@ -290,7 +52,7 @@ void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
IN_DTYPE b = F[k + c*K];
-acc = dot(a, b, acc);
+acc = std::fma(a, b, acc);
}
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
}
@@ -1,24 +1,26 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
+#include <sstream>
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/conv.h"
#include "triton/tools/bench.hpp"

int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
-int32_t B = 4, NF = 32;
-int32_t D = 1, H = 56, W = 56;
-int32_t NC = 16, T = 1, R = 3, S = 3;
+int32_t B = 64, NF = 64;
+int32_t D = 1, H = 8, W = 8;
+int32_t NC = 3, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
+triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, triton::dnn::conv::FPROP, 0);
+// triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
// convolution configuration
std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size());
@@ -43,22 +45,23 @@ int main() {
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
-configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
-std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
-stream->synchronize();
-configuration.set_arg(kernel, da, db, dc, nullptr);
-stream->enqueue(kernel, grid, {nthreads, 1, 1});
+configuration.init(stream, (triton::driver::cu_module*)kernel->module());
+unsigned GZ = jit.get_int("GZ");
+configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads);
stream->synchronize();
-double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
-[&](){ stream->synchronize(); }, *context->device());
+double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads); },
+[&](){ stream->synchronize(); }, nullptr);
return configuration.get_nflops() / ts * 1e-3;
};
-std::string src = configuration.src();
-// jit.autotune("conv", src.c_str(), benchmark);
-jit.add_module("conv", src.c_str(), configuration.default_params());
+std::ostringstream oss;
+configuration.src(oss);
+std::string src = oss.str();
+triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
+jit.add_module("conv", src.c_str(), best.params);
+// jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
@@ -5,6 +5,7 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/gemm.h"
#include "triton/tools/bench.hpp"


int main() {
@@ -52,8 +53,8 @@ int main() {
triton::dnn::gemm::set_arg(kernel, da, db, dc, M, N, K, dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
-double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
-[&](){ stream->synchronize(); }, *context->device());
+double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
+[&](){ stream->synchronize(); }, context->device());
return 2.*M*N*K / ts * 1e-3;
};

@@ -4,6 +4,7 @@
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"

// K = channels
// M = batch * height * width
@@ -180,8 +181,8 @@ int main() {
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
// benchmark
-double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
-[&](){ stream->synchronize(); }, *context->device());
+double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
+[&](){ stream->synchronize(); }, context->device());
ts = ts * 1e-9;
double tflops = 2.*M*N*K / ts * 1e-12;
return tflops;
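// (Above: 2*M*N*K flops per GEMM; ts is converted from nanoseconds to
// seconds before scaling by 1e-12, so the result is in TFLOPS.)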
