Skip to content

Commit

Permalink
Use getArray instead of castArray if types are the same in arithOp
Browse files Browse the repository at this point in the history
  • Loading branch information
umar456 committed Aug 17, 2021
1 parent a1de6ac commit da6518e
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 5 deletions.
15 changes: 12 additions & 3 deletions src/api/c/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <common/half.hpp>

using af::dim4;
using af::dtype;
using common::half;
using detail::arithOp;
using detail::arithOpD;
Expand All @@ -41,9 +42,17 @@ using detail::ushort;
template<typename T, af_op_t op>
static inline af_array arithOp(const af_array lhs, const af_array rhs,
const dim4 &odims) {
af_array res =
getHandle(arithOp<T, op>(castArray<T>(lhs), castArray<T>(rhs), odims));
return res;
const ArrayInfo &linfo = getInfo(lhs);
const ArrayInfo &rinfo = getInfo(rhs);

dtype type = static_cast<af::dtype>(af::dtype_traits<T>::af_type);

const detail::Array<T> &l =
linfo.getType() == type ? getArray<T>(lhs) : castArray<T>(lhs);
const detail::Array<T> &r =
rinfo.getType() == type ? getArray<T>(rhs) : castArray<T>(rhs);

return getHandle(arithOp<T, op>(l, r, odims));
}

template<typename T, af_op_t op>
Expand Down
4 changes: 3 additions & 1 deletion src/backend/common/jit/NaryNode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ common::Node_ptr createNaryNode(
const af::dim4 &odims, FUNC createNode,
std::array<const detail::Array<Ti> *, N> &&children) {
std::array<common::Node_ptr, N> childNodes;
for (int i = 0; i < N; i++) { childNodes[i] = children[i]->getNode(); }
for (int i = 0; i < N; i++) {
childNodes[i] = move(children[i]->getNode());
}

common::Node_ptr ptr = createNode(childNodes);

Expand Down
2 changes: 1 addition & 1 deletion src/backend/cpu/Array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ kJITHeuristics passesJitHeuristics(Node *root_node) {

template<typename T>
Array<T> createNodeArray(const dim4 &dims, Node_ptr node) {
Array<T> out = Array<T>(dims, node);
Array<T> out(dims, node);
return out;
}

Expand Down
6 changes: 6 additions & 0 deletions src/backend/cpu/arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@

namespace cpu {

template<typename T, af_op_t op>
Array<T> arithOp(const Array<T> &&lhs, const Array<T> &&rhs,
const af::dim4 &odims) {
return common::createBinaryNode<T, T, op>(lhs, rhs, odims);
}

template<typename T, af_op_t op>
Array<T> arithOp(const Array<T> &lhs, const Array<T> &rhs,
const af::dim4 &odims) {
Expand Down
7 changes: 7 additions & 0 deletions src/backend/cuda/arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
#include <af/dim4.hpp>

namespace cuda {

template<typename T, af_op_t op>
Array<T> arithOp(const Array<T> &&lhs, const Array<T> &&rhs,
const af::dim4 &odims) {
return common::createBinaryNode<T, T, op>(lhs, rhs, odims);
}

template<typename T, af_op_t op>
Array<T> arithOp(const Array<T> &lhs, const Array<T> &rhs,
const af::dim4 &odims) {
Expand Down
7 changes: 7 additions & 0 deletions src/backend/opencl/arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
#include <af/dim4.hpp>

namespace opencl {

template<typename T, af_op_t op>
Array<T> arithOp(const Array<T> &&lhs, const Array<T> &&rhs,
const af::dim4 &odims) {
return common::createBinaryNode<T, T, op>(lhs, rhs, odims);
}

template<typename T, af_op_t op>
Array<T> arithOp(const Array<T> &lhs, const Array<T> &rhs,
const af::dim4 &odims) {
Expand Down

0 comments on commit da6518e

Please sign in to comment.