Skip to content

Commit

Permalink
Fix reference count if array used in JIT operations.
Browse files Browse the repository at this point in the history
Previously when an af::array was used in a jit operation and it was backed by a
buffer, a buffer node was created and the internal shared_ptr was stored in the
Array for future use and returned when getNode was called. This increased the
reference count of the internal buffer. This reference count never decreased
because of the internal reference to the shared_ptr.

This commit changes this behavior by creating new buffer nodes for each
call to getNode. We use the new hash function to ensure the equality of
the buffer node when the jit code is generated. This avoids holding the
call_once flag in the buffer object and simplifies the management of
the buffer node objects. Additionally when a jit node goes out of scope
the reference count decrements as expected.
  • Loading branch information
umar456 committed Aug 17, 2021
1 parent 0465a6b commit 5e4d034
Show file tree
Hide file tree
Showing 21 changed files with 311 additions and 222 deletions.
9 changes: 4 additions & 5 deletions src/backend/common/cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ template<typename To, typename Ti>
struct CastWrapper {
detail::Array<To> operator()(const detail::Array<Ti> &in) {
using cpu::jit::UnaryNode;

Node_ptr in_node = in.getNode();
UnaryNode<To, Ti, af_cast_t> *node =
new UnaryNode<To, Ti, af_cast_t>(in_node);
return detail::createNodeArray<To>(
in.dims(),
common::Node_ptr(reinterpret_cast<common::Node *>(node)));
auto node = std::make_shared<UnaryNode<To, Ti, af_cast_t>>(in_node);

return detail::createNodeArray<To>(in.dims(), move(node));
}
};
#else
Expand Down
10 changes: 7 additions & 3 deletions src/backend/common/jit/BinaryNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <complex.hpp>
#include <types.hpp>

#include <memory>

using af::dim4;
using af::dtype_traits;
using detail::Array;
Expand All @@ -13,6 +15,8 @@ using detail::cdouble;
using detail::cfloat;
using detail::createNodeArray;

using std::make_shared;

namespace common {
#ifdef AF_CPU
template<typename To, typename Ti, af_op_t op>
Expand All @@ -21,10 +25,10 @@ Array<To> createBinaryNode(const Array<Ti> &lhs, const Array<Ti> &rhs,
common::Node_ptr lhs_node = lhs.getNode();
common::Node_ptr rhs_node = rhs.getNode();

detail::jit::BinaryNode<To, Ti, op> *node =
new detail::jit::BinaryNode<To, Ti, op>(lhs_node, rhs_node);
auto node =
make_shared<detail::jit::BinaryNode<To, Ti, op>>(lhs_node, rhs_node);

return createNodeArray<To>(odims, common::Node_ptr(node));
return createNodeArray<To>(odims, move(node));
}

#else
Expand Down
19 changes: 6 additions & 13 deletions src/backend/common/jit/BufferNodeBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
#include <common/jit/Node.hpp>
#include <jit/kernel_generators.hpp>

#include <iomanip>
#include <mutex>
#include <sstream>

namespace common {
Expand All @@ -24,25 +22,20 @@ class BufferNodeBase : public common::Node {
DataType m_data;
ParamType m_param;
unsigned m_bytes;
std::once_flag m_set_data_flag;
bool m_linear_buffer;

public:
BufferNodeBase(af::dtype type) : Node(type, 0, {}) {
// This class is not movable because of std::once_flag
}
BufferNodeBase(af::dtype type)
: Node(type, 0, {}), m_bytes(0), m_linear_buffer(true) {}

bool isBuffer() const final { return true; }

void setData(ParamType param, DataType data, const unsigned bytes,
bool is_linear) {
std::call_once(m_set_data_flag,
[this, param, data, bytes, is_linear]() {
m_param = param;
m_data = data;
m_bytes = bytes;
m_linear_buffer = is_linear;
});
m_param = param;
m_data = data;
m_bytes = bytes;
m_linear_buffer = is_linear;
}

bool isLinear(dim_t dims[4]) const final {
Expand Down
48 changes: 38 additions & 10 deletions src/backend/common/jit/Node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
#include <types.hpp>
#include <af/defines.h>

#include <algorithm>
#include <array>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

enum class kJITHeuristics {
Expand All @@ -34,6 +35,17 @@ namespace common {
class Node;
}

#ifdef AF_CPU
namespace cpu {
namespace kernel {

template<typename T>
void evalMultiple(std::vector<Param<T>> arrays,
std::vector<std::shared_ptr<common::Node>> output_nodes_);
}
} // namespace cpu
#endif

namespace std {
template<>
struct hash<common::Node *> {
Expand Down Expand Up @@ -107,15 +119,6 @@ class Node {
template<typename T>
friend class NodeIterator;

void swap(Node &other) noexcept {
using std::swap;
for (int i = 0; i < kMaxChildren; i++) {
swap(m_children[i], other.m_children[i]);
}
swap(m_type, other.m_type);
swap(m_height, other.m_height);
}

public:
Node() = default;
Node(const af::dtype type, const int height,
Expand All @@ -125,6 +128,15 @@ class Node {
"Node is not move assignable");
}

void swap(Node &other) noexcept {
using std::swap;
for (int i = 0; i < kMaxChildren; i++) {
swap(m_children[i], other.m_children[i]);
}
swap(m_type, other.m_type);
swap(m_height, other.m_height);
}

/// Default move constructor operator
Node(Node &&node) noexcept = default;

Expand Down Expand Up @@ -266,6 +278,22 @@ class Node {
virtual bool operator==(const Node &other) const noexcept {
return this == &other;
}

#ifdef AF_CPU
/// Replaces a child node pointer in the cpu::jit::BinaryNode<T> or the
/// cpu::jit::UnaryNode classes at \p id with *ptr. Used only in the CPU
/// backend and does not modify the m_children pointers in the
/// common::Node_ptr class.
virtual void replaceChild(int id, void *ptr) noexcept {
UNUSED(id);
UNUSED(ptr);
}

template<typename U>
friend void cpu::kernel::evalMultiple(
std::vector<cpu::Param<U>> arrays,
std::vector<common::Node_ptr> output_nodes_);
#endif
};

struct Node_ids {
Expand Down
76 changes: 28 additions & 48 deletions src/backend/cpu/Array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,19 @@ using common::Node_map_t;
using common::Node_ptr;
using common::NodeIterator;
using cpu::jit::BufferNode;

using std::adjacent_find;
using std::copy;
using std::is_standard_layout;
using std::make_shared;
using std::move;
using std::vector;

namespace cpu {

template<typename T>
Node_ptr bufferNodePtr() {
return Node_ptr(reinterpret_cast<Node *>(new BufferNode<T>()));
shared_ptr<BufferNode<T>> bufferNodePtr() {
return std::make_shared<BufferNode<T>>();
}

template<typename T>
Expand All @@ -62,8 +64,7 @@ Array<T>::Array(dim4 dims)
static_cast<af_dtype>(dtype_traits<T>::af_type))
, data(memAlloc<T>(dims.elements()).release(), memFree<T>)
, data_dims(dims)
, node(bufferNodePtr<T>())
, ready(true)
, node()
, owner(true) {}

template<typename T>
Expand All @@ -75,8 +76,7 @@ Array<T>::Array(const dim4 &dims, T *const in_data, bool is_device,
: memAlloc<T>(dims.elements()).release(),
memFree<T>)
, data_dims(dims)
, node(bufferNodePtr<T>())
, ready(true)
, node()
, owner(true) {
static_assert(is_standard_layout<Array<T>>::value,
"Array<T> must be a standard layout type");
Expand All @@ -101,7 +101,6 @@ Array<T>::Array(const af::dim4 &dims, Node_ptr n)
, data()
, data_dims(dims)
, node(move(n))
, ready(false)
, owner(true) {}

template<typename T>
Expand All @@ -111,8 +110,7 @@ Array<T>::Array(const Array<T> &parent, const dim4 &dims, const dim_t &offset_,
static_cast<af_dtype>(dtype_traits<T>::af_type))
, data(parent.getData())
, data_dims(parent.getDataDims())
, node(bufferNodePtr<T>())
, ready(true)
, node()
, owner(false) {}

template<typename T>
Expand All @@ -123,8 +121,7 @@ Array<T>::Array(const dim4 &dims, const dim4 &strides, dim_t offset_,
, data(is_device ? in_data : memAlloc<T>(info.total()).release(),
memFree<T>)
, data_dims(dims)
, node(bufferNodePtr<T>())
, ready(true)
, node()
, owner(true) {
if (!is_device) {
// Ensure the memory being written to isnt used anywhere else.
Expand All @@ -135,40 +132,27 @@ Array<T>::Array(const dim4 &dims, const dim4 &strides, dim_t offset_,

template<typename T>
void Array<T>::eval() {
if (isReady()) { return; }
if (getQueue().is_worker()) {
AF_ERROR("Array not evaluated", AF_ERR_INTERNAL);
}

this->setId(getActiveDeviceId());

data = shared_ptr<T>(memAlloc<T>(elements()).release(), memFree<T>);

getQueue().enqueue(kernel::evalArray<T>, *this, this->node);
// Reset shared_ptr
this->node = bufferNodePtr<T>();
ready = true;
evalMultiple<T>({this});
}

template<typename T>
void Array<T>::eval() const {
if (isReady()) { return; }
const_cast<Array<T> *>(this)->eval();
}

template<typename T>
T *Array<T>::device() {
getQueue().sync();
if (!isOwner() || getOffset() || data.use_count() > 1) {
*this = copyArray<T>(*this);
}
getQueue().sync();
return this->get();
}

template<typename T>
void evalMultiple(vector<Array<T> *> array_ptrs) {
vector<Array<T> *> outputs;
vector<Node_ptr> nodes;
vector<common::Node_ptr> nodes;
vector<Param<T>> params;
if (getQueue().is_worker()) {
AF_ERROR("Array not evaluated", AF_ERR_INTERNAL);
Expand All @@ -187,41 +171,39 @@ void evalMultiple(vector<Array<T> *> array_ptrs) {
}

for (Array<T> *array : array_ptrs) {
if (array->ready) { continue; }
if (array->isReady()) { continue; }

array->setId(getActiveDeviceId());
array->data =
shared_ptr<T>(memAlloc<T>(array->elements()).release(), memFree<T>);

outputs.push_back(array);
params.push_back(*array);
params.emplace_back(array->getData().get(), array->dims(),
array->strides());
nodes.push_back(array->node);
}

if (!outputs.empty()) {
getQueue().enqueue(kernel::evalMultiple<T>, params, nodes);
for (Array<T> *array : outputs) {
array->ready = true;
array->node = bufferNodePtr<T>();
}
}
if (params.empty()) return;

getQueue().enqueue(cpu::kernel::evalMultiple<T>, params, nodes);

for (Array<T> *array : outputs) { array->node.reset(); }
}

template<typename T>
Node_ptr Array<T>::getNode() {
if (node->isBuffer()) {
auto *bufNode = reinterpret_cast<BufferNode<T> *>(node.get());
unsigned bytes = this->getDataDims().elements() * sizeof(T);
bufNode->setData(data, bytes, getOffset(), dims().get(),
strides().get(), isLinear());
}
return node;
if (node) { return node; }

std::shared_ptr<BufferNode<T>> out = bufferNodePtr<T>();
unsigned bytes = this->getDataDims().elements() * sizeof(T);
out->setData(data, bytes, getOffset(), dims().get(), strides().get(),
isLinear());
return out;
}

template<typename T>
Node_ptr Array<T>::getNode() const {
if (node->isBuffer()) { return const_cast<Array<T> *>(this)->getNode(); }
return node;
return const_cast<Array<T> *>(this)->getNode();
}

template<typename T>
Expand All @@ -236,8 +218,7 @@ Array<T> createDeviceDataArray(const dim4 &dims, void *data) {

template<typename T>
Array<T> createValueArray(const dim4 &dims, const T &value) {
auto *node = new jit::ScalarNode<T>(value);
return createNodeArray<T>(dims, Node_ptr(node));
return createNodeArray<T>(dims, make_shared<jit::ScalarNode<T>>(value));
}

template<typename T>
Expand Down Expand Up @@ -337,7 +318,6 @@ template<typename T>
void Array<T>::setDataDims(const dim4 &new_dims) {
modDims(new_dims);
data_dims = new_dims;
if (node->isBuffer()) { node = bufferNodePtr<T>(); }
}

#define INSTANTIATE(T) \
Expand Down
Loading

0 comments on commit 5e4d034

Please sign in to comment.