Skip to content
Permalink
Browse files

[testing] re-arranged util.h

  • Loading branch information...
ptillet committed Sep 12, 2019
1 parent f4beb71 commit 7f2bc5bb6624d313557901f67510bb9fc9e163bf
Showing with 73 additions and 47 deletions.
  1. +71 −20 tests/common/util.h
  2. +1 −1 tests/unit/dot.cc
  3. +1 −26 tests/unit/reduce.cc
@@ -9,6 +9,10 @@
namespace drv = triton::driver;
namespace rt = triton::runtime;

/* ------------------------
* Launch Grid
* ------------------------ */

inline size_t ceil(size_t x, size_t y) {
return (x + y - 1) / y;
};
@@ -26,10 +30,10 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
};
}

enum order_t {
ROWMAJOR,
COLMAJOR
};

/* ------------------------
* Tensor Initialization
* ------------------------ */

template<class T>
void init_rand(std::vector<T>& x) {
@@ -43,6 +47,49 @@ void init_zeros(std::vector<T>& x) {
x[i] = 0;
}

/* ------------------------
* Loop Nests
* ------------------------ */

void _loop_nest(std::vector<int> const & ranges,
std::function<void(std::vector<int> const &)> const & f){
int D = ranges.size();
std::vector<int> values(D, 0);
// Start with innermost loop
int i = D - 1;
while(true){
// Execute function
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}

/* -----------------------
* TENSOR INDEXING
* ----------------------- */

enum order_t {
ROWMAJOR,
COLMAJOR
};


int offset(const std::vector<int>& idx, const std::vector<int>& shapes) {
int result = idx[0];
for(int i = 1; i < idx.size(); i++)
result += idx[i]*shapes[i-1];
return result;
}

/* -----------------------
* REDUCTION HELPERS
* ----------------------- */

enum reduce_op_t {
ADD,
MAX,
@@ -73,6 +120,26 @@ std::function<T(T,T)> get_accumulator(reduce_op_t op) {
}


/* -----------------------
* TENSOR COMPARISON
* ----------------------- */

template<class T>
bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
if(hc.size() != rc.size())
return false;
for(size_t i = 0; i < hc.size(); i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
return false;
}
return true;
}

/* -----------------------
* PRETTY PRINTING
* ----------------------- */

namespace aux{
template<std::size_t...> struct seq{};

@@ -116,21 +183,5 @@ std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, reduce_op
}


namespace testing {

template<class T>
bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
if(hc.size() != rc.size())
return false;
for(size_t i = 0; i < hc.size(); i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;

return false;
}
return true;
}

}

#endif
@@ -91,7 +91,7 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
stream->read(&*dc, true, 0, hc);
std::vector<NumericT> rc(hc.size());
cpu_ref(AT, BT, M, N, K, rc, ha, hb);
return testing::diff(hc, rc);
return diff(hc, rc);
}

int main() {
@@ -15,31 +15,6 @@
namespace drv = triton::driver;
namespace rt = triton::runtime;

void _loop_nest(std::vector<int> const & ranges,
std::function<void(std::vector<int> const &)> const & f){
int D = ranges.size();
std::vector<int> values(D, 0);
// Start with innermost loop
int i = D - 1;
while(true){
// Execute function
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}

int offset(const std::vector<int>& idx, const std::vector<int>& shapes) {
int result = idx[0];
for(int i = 1; i < idx.size(); i++)
result += idx[i]*shapes[i-1];
return result;
}

template<class T>
void reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
assert(axis <= shapes.size() - 1);
@@ -101,7 +76,7 @@ bool do_test(drv::stream* stream, std::vector<int> shape, int axis, reduce_op_t
stream->synchronize();
stream->read(&*dy, true, 0, hy);
reduce_nd(ry, hx, op, axis, shape);
return testing::diff(hy, ry);
return diff(hy, ry);
}

int main() {

0 comments on commit 7f2bc5b

Please sign in to comment.
You can’t perform that action at this time.