Skip to content

Commit

Permalink
dot, linalg thread safety
Browse files Browse the repository at this point in the history
  • Loading branch information
wkcn committed Jun 28, 2018
1 parent 212b691 commit db016f2
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 73 deletions.
81 changes: 8 additions & 73 deletions mobula_op/inc/func.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "defines.h"
#include "context/context.h"

#include "linalg.h"

namespace mobula {


Expand All @@ -29,59 +31,12 @@ MOBULA_KERNEL binary_kernel(const int n, const T *a, const T *b, T *out, BINARY_
template <typename T>
MOBULA_KERNEL dot_add_kernel(const int n, const T *a, const T *b, const int U, const int K, const int M, T *out) {
parfor(n, [&](int index) {
const int i = index / (K * U);
const int k = (index / U) % K;
const int u = index % U;
for (int m = 0; m < M; ++m) {
out[(i * K + k) * M + m] += a[i * U + u] * b[(k * U + u) * M + m];
}
});
}

// out[i, j] = sum(a[i, :] * b[:, j])
template <typename T>
MOBULA_KERNEL linalg_gemm_ff_kernel(const int n, const T *a, const T *b, const int U, const int J, T *out) {
parfor(n, [&](int index) {
const int i = index / U;
const int u = index % U;
for (int j = 0; j < J; ++j) {
out[i * J + j] += a[i * U + u] * b[u * J + j];
}
});
}

// out[i, j] = sum(a[i, :] * b[j, :])
template <typename T>
MOBULA_KERNEL linalg_gemm_ft_kernel(const int n, const T *a, const T *b, const int U, const int J, T *out) {
parfor(n, [&](int index) {
const int i = index / J;
const int j = index % J;
const int i = index / K;
const int k = index % K;
for (int u = 0; u < U; ++u) {
out[i * J + j] += a[i * U + u] * b[j * U + u];
}
});
}

// out[i, j] = sum(a[:, i] * b[:, j])
template <typename T>
MOBULA_KERNEL linalg_gemm_tf_kernel(const int n, const T *a, const T *b, const int I, const int J, T *out) {
parfor(n, [&](int index) {
const int u = index / I;
const int i = index % I;
for (int j = 0; j < J; ++j) {
out[i * J + j] += a[u * I + i] * b[u * J + j];
}
});
}

// out[i, j] = sum(a[:, i] * b[j, :])
template <typename T>
MOBULA_KERNEL linalg_gemm_tt_kernel(const int n, const T *a, const T *b, const int I, const int U, const int J, T *out) {
parfor(n, [&](int index) {
const int j = index / U;
const int u = index % U;
for (int i = 0; i < I; ++i) {
out[i * J + j] += a[u * I + i] * b[j * U + u];
for (int m = 0; m < M; ++m) {
out[(i * K + k) * M + m] += a[i * U + u] * b[(k * U + u) * M + m];
}
}
});
}
Expand Down Expand Up @@ -141,30 +96,10 @@ REGISTER_BINARY_FUNC(mul, []MOBULA_DEVICE(const DType &a, const DType &b){return
REGISTER_BINARY_FUNC(div_, []MOBULA_DEVICE(const DType &a, const DType &b){return a / b;})

void dot_add(const DType *a, const DType *b, const int I, const int U, const int K, const int M, DType *out) {
const int N = I * K * U;
const int N = I * K;
KERNEL_RUN(dot_add_kernel<DType>, N)(N, a, b, U, K, M, out);
}

void linalg_gemm_ff(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = I * U;
KERNEL_RUN(linalg_gemm_ff_kernel<DType>, N)(N, a, b, U, J, out);
}

void linalg_gemm_ft(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = I * J;
KERNEL_RUN(linalg_gemm_ft_kernel<DType>, N)(N, a, b, U, J, out);
}

void linalg_gemm_tf(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = U * I;
KERNEL_RUN(linalg_gemm_tf_kernel<DType>, N)(N, a, b, I, J, out);
}

void linalg_gemm_tt(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = J * U;
KERNEL_RUN(linalg_gemm_tt_kernel<DType>, N)(N, a, b, I, U, J, out);
}

void print_carray(CArray<DType> ca) {
bool first = true;
for (size_t i = 0; i < ca.size; ++i) {
Expand Down
84 changes: 84 additions & 0 deletions mobula_op/inc/linalg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#ifndef _MOBULA_LINALG_H_
#define _MOBULA_LINALG_H_

#include "defines.h"
#include "context/context.h"

namespace mobula {

// out[i, j] = sum(a[i, :] * b[:, j])
// out[i, j] = sum(a[i, :] * b[:, j])
// a is laid out as (n, U) row-major and b as (U, J); each parallel index
// owns one output row, so no two indices write the same out element.
template <typename T>
MOBULA_KERNEL linalg_gemm_ff_kernel(const int n, const T *a, const T *b, const int U, const int J, T *out) {
    parfor(n, [&](int row) {
        const T *a_row = a + row * U;
        T *out_row = out + row * J;
        // Accumulate row `row` of a*b as a sum of scaled rows of b.
        for (int k = 0; k < U; ++k) {
            const T scale = a_row[k];
            const T *b_row = b + k * J;
            for (int col = 0; col < J; ++col) {
                out_row[col] += scale * b_row[col];
            }
        }
    });
}

// out[i, j] = sum(a[i, :] * b[j, :])
// a is (n, U) and b is (J, U); computes a * b^T, one output row per
// parallel index, so writes to `out` never collide across indices.
template <typename T>
MOBULA_KERNEL linalg_gemm_ft_kernel(const int n, const T *a, const T *b, const int U, const int J, T *out) {
    parfor(n, [&](int row) {
        const T *a_row = a + row * U;
        T *out_row = out + row * J;
        for (int col = 0; col < J; ++col) {
            const T *b_row = b + col * U;
            // Inner product of row `row` of a with row `col` of b.
            for (int k = 0; k < U; ++k) {
                out_row[col] += a_row[k] * b_row[k];
            }
        }
    });
}

// out[i, j] = sum(a[:, i] * b[:, j])
// a is indexed as (U, I) and b as (U, J) row-major; accumulates a^T * b.
// `n` is the reduction length U: each parallel index `u` adds the rank-1
// contribution a[u, :]^T * b[u, :] into the WHOLE of `out`.
// NOTE(review): unlike the ff/ft/tt kernels, which give each parallel index
// a disjoint slice of `out`, here every `u` does `+=` into the same
// out[i, j] locations. If parfor runs indices concurrently this is a data
// race — confirm parfor's execution model, or restructure to parallelize
// over i with the u-reduction inside.
template <typename T>
MOBULA_KERNEL linalg_gemm_tf_kernel(const int n, const T *a, const T *b, const int I, const int J, T *out) {
parfor(n, [&](int u) {
for (int i = 0; i < I; ++i) {
for (int j = 0; j < J; ++j) {
out[i * J + j] += a[u * I + i] * b[u * J + j];
}
}
});
}

// out[i, j] = sum(a[:, i] * b[j, :])
// a is (U, I) and b is (J, U); computes a^T * b^T, one output column per
// parallel index, so distinct indices write disjoint out elements.
template <typename T>
MOBULA_KERNEL linalg_gemm_tt_kernel(const int n, const T *a, const T *b, const int I, const int U, const int J, T *out) {
    parfor(n, [&](int col) {
        const T *b_row = b + col * U;
        for (int k = 0; k < U; ++k) {
            const T scale = b_row[k];  // b[col, k], invariant over the row loop
            const T *a_row = a + k * I;
            // Add a[k, :] scaled by b[col, k] into column `col` of out.
            for (int row = 0; row < I; ++row) {
                out[row * J + col] += a_row[row] * scale;
            }
        }
    });
}

}

extern "C" {
using namespace mobula;

// C-linkage launchers for the gemm kernels. Each computes
// out (I x J) += op(a) * op(b) with U as the shared reduction length;
// N is the number of parallel kernel indices handed to KERNEL_RUN.

// out = a * b, a is (I, U), b is (U, J); one index per output row.
void linalg_gemm_ff(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = I;
KERNEL_RUN(linalg_gemm_ff_kernel<DType>, N)(N, a, b, U, J, out);
}

// out = a * b^T, a is (I, U), b is (J, U); one index per output row.
void linalg_gemm_ft(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = I;
KERNEL_RUN(linalg_gemm_ft_kernel<DType>, N)(N, a, b, U, J, out);
}

// out = a^T * b, a is (U, I), b is (U, J); one index per reduction step u.
// NOTE(review): every index u accumulates into the same out[i, j] cells, so
// this launch is only safe if KERNEL_RUN/parfor serializes or guards the
// += — confirm, since the other three variants partition `out` disjointly.
void linalg_gemm_tf(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = U;
KERNEL_RUN(linalg_gemm_tf_kernel<DType>, N)(N, a, b, I, J, out);
}

// out = a^T * b^T, a is (U, I), b is (J, U); one index per output column.
void linalg_gemm_tt(const DType *a, const DType *b, const int I, const int U, const int J, DType *out) {
const int N = J;
KERNEL_RUN(linalg_gemm_tt_kernel<DType>, N)(N, a, b, I, U, J, out);
}

}

#endif
1 change: 1 addition & 0 deletions mobula_op/maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,4 @@ def linalg_gemm(a, b, out = None, tA = False, tB = False, req = const.req.write)
if req != const.req.add:
out[:] = 0
LINALG_GEMM_FUNC[tA][tB](a, b, I, U, J, out)
return out
34 changes: 34 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,37 @@ def test_tensordot():
c = np.tensordot(a, b, axes = axes)
d = mobula_op.math.tensordot(a, b, axes = axes)
assert_almost_equal(c, d)

def check_math_func(func, target, **kwargs):
    """Check `func(**kwargs)` against `target` for the default call and for
    explicit output buffers under the write and add request modes."""
    # Default call: the op allocates and returns a fresh output array.
    result = func(**kwargs)
    assert_almost_equal(result, target)
    # Explicit output buffer with the default (write) request.
    result = np.zeros_like(result, dtype=np.float32)
    func(out=result, **kwargs)
    assert_almost_equal(result, target)
    # req=add must accumulate on top of the buffer's existing contents.
    offset = np.random.random(result.shape).astype(np.float32)
    result = offset.copy()
    func(out=result, req=mobula_op.const.req.add, **kwargs)
    assert_almost_equal(result, target + offset)
    # req=write must overwrite whatever the buffer held before.
    func(out=result, req=mobula_op.const.req.write, **kwargs)
    assert_almost_equal(result, target)

def test_linalg_gemm():
    """Check linalg_gemm against np.dot for all four transpose combinations.

    Removes the unused `c = np.empty((I, K), ...)` scratch array the original
    allocated and never read.
    """
    I, J, K = 10, 11, 12
    a = np.random.random((I, J)).astype(np.float32)
    b = np.random.random((J, K)).astype(np.float32)
    t = np.dot(a, b)
    check_math_func(mobula_op.math.linalg_gemm, t, a=a, b=b, tA=False, tB=False)

    # Reinterpret b's buffer as (K, J) so b.T is a valid (J, K) operand.
    b = b.reshape((K, J))
    t = np.dot(a, b.T)
    check_math_func(mobula_op.math.linalg_gemm, t, a=a, b=b, tA=False, tB=True)

    # Likewise reinterpret a as (J, I) for the transposed-A cases.
    a = a.reshape((J, I))
    t = np.dot(a.T, b.T)
    check_math_func(mobula_op.math.linalg_gemm, t, a=a, b=b, tA=True, tB=True)

    b = b.T
    t = np.dot(a.T, b)
    check_math_func(mobula_op.math.linalg_gemm, t, a=a, b=b, tA=True, tB=False)

0 comments on commit db016f2

Please sign in to comment.