Skip to content
Permalink
Browse files

[triton/examples/cpp] removed common.hpp helper

  • Loading branch information...
ptillet committed May 28, 2019
1 parent a9d078c commit 8102efc0643380ba3ac176165b4e49375b56702b
Showing with 55 additions and 62 deletions.
  1. +0 −59 examples/cpp/common.hpp
  2. +1 −2 examples/cpp/dot.cpp
  3. +31 −1 examples/cpp/shift.cpp
  4. +23 −0 include/triton/dnn/gemm.h

This file was deleted.

@@ -1,6 +1,5 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
@@ -67,7 +66,7 @@ int main() {
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc);
simple_gemm<float>(AT, BT, rc, ha, hb, M, N, K);
triton::dnn::gemm::cpu_ref<float>(AT, BT, rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
@@ -1,11 +1,41 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"

// input layout: C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
template<class IN_DTYPE, class OUT_DTYPE>
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
int32_t K,
std::vector<OUT_DTYPE>& O,
const std::vector<IN_DTYPE>& I,
const std::vector<IN_DTYPE>& F,
const std::vector<int32_t> shift_h,
const std::vector<int32_t> shift_w)
{
OUT_DTYPE acc;
for(int32_t p = 0; p < H; ++p)
for(int32_t q = 0; q < W; ++q)
for(int32_t bs = 0; bs < BS; ++bs)
for(int32_t k = 0; k < K; ++k)
{
acc = 0;
for(int32_t c = 0; c < C; ++c){
int32_t h = p + shift_h[c];
int32_t w = q + shift_w[c];
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
IN_DTYPE b = F[k + c*K];
acc = std::fma(a, b, acc);
}
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
}
}

// K = channels
// M = batch * height * width
// N = number of feature maps
@@ -14,6 +14,29 @@ class gemm {
driver::buffer *locks, int32_t grid_0, int32_t grid_1);
static std::vector<unsigned> default_params(bool AT, bool BT);
static std::string src(bool AT, bool BT);

template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
T acc = 0;
for(size_t k = 0; k < K; k++)
acc += (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]);
c[m + n*M] = acc;
}
}

template<class T>
static void cpu_ref(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
if(AT && BT)
gemm::cpu_ref<T, true, true>(c, a, b, M, N, K);
else if(AT && !BT)
gemm::cpu_ref<T, true, false>(c, a, b, M, N, K);
else if(!AT && BT)
gemm::cpu_ref<T, false, true>(c, a, b, M, N, K);
else
gemm::cpu_ref<T, false, false>(c, a, b, M, N, K);
}
};

}

0 comments on commit 8102efc

Please sign in to comment.
You can’t perform that action at this time.