Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stan/math/rev/arr/fun/sum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class sum_v_vari : public vari {

explicit sum_v_vari(const std::vector<var>& v1)
: vari(sum_of_val(v1)),
v_(reinterpret_cast<vari**>(ChainableStack::context().memalloc_.alloc(
v1.size() * sizeof(vari*)))),
v_(reinterpret_cast<vari**>(
chainable_stack.memalloc_.alloc(v1.size() * sizeof(vari*)))),
length_(v1.size()) {
for (size_t i = 0; i < length_; i++)
v_[i] = v1[i].vi_;
Expand Down
13 changes: 0 additions & 13 deletions stan/math/rev/core/autodiffstackstorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,6 @@ namespace math {

template <typename ChainableT, typename ChainableAllocT>
struct AutodiffStackStorage {
typedef AutodiffStackStorage<ChainableT, ChainableAllocT>
AutodiffStackStorage_t;

static AutodiffStackStorage_t& context() {
#ifndef STAN_THREADS
static AutodiffStackStorage_t ad_stack = AutodiffStackStorage_t();
#else
static thread_local AutodiffStackStorage_t ad_stack
= AutodiffStackStorage_t();
#endif
return ad_stack;
}

std::vector<ChainableT*> var_stack_;
std::vector<ChainableT*> var_nochain_stack_;
std::vector<ChainableAllocT*> var_alloc_stack_;
Expand Down
4 changes: 1 addition & 3 deletions stan/math/rev/core/chainable_alloc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ namespace math {
*/
class chainable_alloc {
public:
chainable_alloc() {
ChainableStack::context().var_alloc_stack_.push_back(this);
}
chainable_alloc() { chainable_stack.var_alloc_stack_.push_back(this); }
virtual ~chainable_alloc() {}
};

Expand Down
5 changes: 5 additions & 0 deletions stan/math/rev/core/chainablestack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ class chainable_alloc;

typedef AutodiffStackStorage<vari, chainable_alloc> ChainableStack;

#ifdef STAN_THREADS
thread_local
#endif
static ChainableStack chainable_stack;

} // namespace math
} // namespace stan
#endif
2 changes: 1 addition & 1 deletion stan/math/rev/core/empty_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace math {
* Return true if there is no nested autodiff being executed.
*/
static inline bool empty_nested() {
return ChainableStack::context().nested_var_stack_sizes_.empty();
return chainable_stack.nested_var_stack_sizes_.empty();
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/rev/core/gevv_vvv_vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class gevv_vvv_vari : public vari {
alpha_ = alpha->vi_;
// TODO(carpenter): replace this with array alloc fun call
v1_ = reinterpret_cast<vari**>(
ChainableStack::context().memalloc_.alloc(2 * length_ * sizeof(vari*)));
chainable_stack.memalloc_.alloc(2 * length_ * sizeof(vari*)));
v2_ = v1_ + length_;
for (size_t i = 0; i < length_; i++)
v1_[i] = v1[i * stride1].vi_;
Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/core/grad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ static void grad(vari* vi) {

typedef std::vector<vari*>::reverse_iterator it_t;
vi->init_dependent();
it_t begin = ChainableStack::context().var_stack_.rbegin();
it_t end = empty_nested() ? ChainableStack::context().var_stack_.rend()
it_t begin = chainable_stack.var_stack_.rbegin();
it_t end = empty_nested() ? chainable_stack.var_stack_.rend()
: begin + nested_size();
for (it_t it = begin; it < end; ++it) {
(*it)->chain();
Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/core/nested_size.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace stan {
namespace math {

static inline size_t nested_size() {
return ChainableStack::context().var_stack_.size()
- ChainableStack::context().nested_var_stack_sizes_.back();
return chainable_stack.var_stack_.size()
- chainable_stack.nested_var_stack_sizes_.back();
}

} // namespace math
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/precomputed_gradients.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ class precomputed_gradients_vari : public vari {
const std::vector<double>& gradients)
: vari(val),
size_(vars.size()),
varis_(ChainableStack::context().memalloc_.alloc_array<vari*>(
vars.size())),
gradients_(ChainableStack::context().memalloc_.alloc_array<double>(
vars.size())) {
varis_(chainable_stack.memalloc_.alloc_array<vari*>(vars.size())),
gradients_(chainable_stack.memalloc_.alloc_array<double>(vars.size())) {
check_consistent_sizes("precomputed_gradients_vari", "vars", vars,
"gradients", gradients);
for (size_t i = 0; i < vars.size(); ++i)
Expand Down
13 changes: 5 additions & 8 deletions stan/math/rev/core/print_stack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@ namespace math {
* @param o ostream to modify
*/
inline void print_stack(std::ostream& o) {
o << "STACK, size=" << ChainableStack::context().var_stack_.size()
<< std::endl;
o << "STACK, size=" << chainable_stack.var_stack_.size() << std::endl;
// TODO(carpenter): this shouldn't need to be cast any more
for (size_t i = 0; i < ChainableStack::context().var_stack_.size(); ++i)
o << i << " " << ChainableStack::context().var_stack_[i] << " "
<< (static_cast<vari*>(ChainableStack::context().var_stack_[i]))->val_
<< " : "
<< (static_cast<vari*>(ChainableStack::context().var_stack_[i]))->adj_
<< std::endl;
for (size_t i = 0; i < chainable_stack.var_stack_.size(); ++i)
o << i << " " << chainable_stack.var_stack_[i] << " "
<< (static_cast<vari*>(chainable_stack.var_stack_[i]))->val_ << " : "
<< (static_cast<vari*>(chainable_stack.var_stack_[i]))->adj_ << std::endl;
}

} // namespace math
Expand Down
10 changes: 5 additions & 5 deletions stan/math/rev/core/recover_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ static inline void recover_memory() {
throw std::logic_error(
"empty_nested() must be true"
" before calling recover_memory()");
ChainableStack::context().var_stack_.clear();
ChainableStack::context().var_nochain_stack_.clear();
for (auto &x : ChainableStack::context().var_alloc_stack_) {
chainable_stack.var_stack_.clear();
chainable_stack.var_nochain_stack_.clear();
for (auto &x : chainable_stack.var_alloc_stack_) {
delete x;
}
ChainableStack::context().var_alloc_stack_.clear();
ChainableStack::context().memalloc_.recover_all();
chainable_stack.var_alloc_stack_.clear();
chainable_stack.memalloc_.recover_all();
}

} // namespace math
Expand Down
31 changes: 15 additions & 16 deletions stan/math/rev/core/recover_memory_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,23 @@ static inline void recover_memory_nested() {
"empty_nested() must be false"
" before calling recover_memory_nested()");

ChainableStack::context().var_stack_.resize(
ChainableStack::context().nested_var_stack_sizes_.back());
ChainableStack::context().nested_var_stack_sizes_.pop_back();

ChainableStack::context().var_nochain_stack_.resize(
ChainableStack::context().nested_var_nochain_stack_sizes_.back());
ChainableStack::context().nested_var_nochain_stack_sizes_.pop_back();

for (size_t i
= ChainableStack::context().nested_var_alloc_stack_starts_.back();
i < ChainableStack::context().var_alloc_stack_.size(); ++i) {
delete ChainableStack::context().var_alloc_stack_[i];
chainable_stack.var_stack_.resize(
chainable_stack.nested_var_stack_sizes_.back());
chainable_stack.nested_var_stack_sizes_.pop_back();

chainable_stack.var_nochain_stack_.resize(
chainable_stack.nested_var_nochain_stack_sizes_.back());
chainable_stack.nested_var_nochain_stack_sizes_.pop_back();

for (size_t i = chainable_stack.nested_var_alloc_stack_starts_.back();
i < chainable_stack.var_alloc_stack_.size(); ++i) {
delete chainable_stack.var_alloc_stack_[i];
}
ChainableStack::context().var_alloc_stack_.resize(
ChainableStack::context().nested_var_alloc_stack_starts_.back());
ChainableStack::context().nested_var_alloc_stack_starts_.pop_back();
chainable_stack.var_alloc_stack_.resize(
chainable_stack.nested_var_alloc_stack_starts_.back());
chainable_stack.nested_var_alloc_stack_starts_.pop_back();

ChainableStack::context().memalloc_.recover_nested();
chainable_stack.memalloc_.recover_nested();
}

} // namespace math
Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/core/set_zero_all_adjoints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ namespace math {
* Reset all adjoint values in the stack to zero.
*/
static void set_zero_all_adjoints() {
for (auto &x : ChainableStack::context().var_stack_)
for (auto &x : chainable_stack.var_stack_)
x->set_zero_adjoint();
for (auto &x : ChainableStack::context().var_nochain_stack_)
for (auto &x : chainable_stack.var_nochain_stack_)
x->set_zero_adjoint();
}

Expand Down
13 changes: 6 additions & 7 deletions stan/math/rev/core/set_zero_all_adjoints_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@ static void set_zero_all_adjoints_nested() {
throw std::logic_error(
"empty_nested() must be false before calling"
" set_zero_all_adjoints_nested()");
size_t start1 = ChainableStack::context().nested_var_stack_sizes_.back();
size_t start1 = chainable_stack.nested_var_stack_sizes_.back();
// avoid wrap with unsigned when start1 == 0
for (size_t i = (start1 == 0U) ? 0U : (start1 - 1);
i < ChainableStack::context().var_stack_.size(); ++i)
ChainableStack::context().var_stack_[i]->set_zero_adjoint();
i < chainable_stack.var_stack_.size(); ++i)
chainable_stack.var_stack_[i]->set_zero_adjoint();

size_t start2
= ChainableStack::context().nested_var_nochain_stack_sizes_.back();
size_t start2 = chainable_stack.nested_var_nochain_stack_sizes_.back();
for (size_t i = (start2 == 0U) ? 0U : (start2 - 1);
i < ChainableStack::context().var_nochain_stack_.size(); ++i) {
ChainableStack::context().var_nochain_stack_[i]->set_zero_adjoint();
i < chainable_stack.var_nochain_stack_.size(); ++i) {
chainable_stack.var_nochain_stack_[i]->set_zero_adjoint();
}
}

Expand Down
14 changes: 7 additions & 7 deletions stan/math/rev/core/start_nested.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ namespace math {
* can find it.
*/
static inline void start_nested() {
ChainableStack::context().nested_var_stack_sizes_.push_back(
ChainableStack::context().var_stack_.size());
ChainableStack::context().nested_var_nochain_stack_sizes_.push_back(
ChainableStack::context().var_nochain_stack_.size());
ChainableStack::context().nested_var_alloc_stack_starts_.push_back(
ChainableStack::context().var_alloc_stack_.size());
ChainableStack::context().memalloc_.start_nested();
chainable_stack.nested_var_stack_sizes_.push_back(
chainable_stack.var_stack_.size());
chainable_stack.nested_var_nochain_stack_sizes_.push_back(
chainable_stack.var_nochain_stack_.size());
chainable_stack.nested_var_alloc_stack_starts_.push_back(
chainable_stack.var_alloc_stack_.size());
chainable_stack.memalloc_.start_nested();
}

} // namespace math
Expand Down
8 changes: 4 additions & 4 deletions stan/math/rev/core/vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ class vari {
* @param x Value of the constructed variable.
*/
explicit vari(double x) : val_(x), adj_(0.0) {
ChainableStack::context().var_stack_.push_back(this);
chainable_stack.var_stack_.push_back(this);
}

vari(double x, bool stacked) : val_(x), adj_(0.0) {
if (stacked)
ChainableStack::context().var_stack_.push_back(this);
chainable_stack.var_stack_.push_back(this);
else
ChainableStack::context().var_nochain_stack_.push_back(this);
chainable_stack.var_nochain_stack_.push_back(this);
}

/**
Expand Down Expand Up @@ -123,7 +123,7 @@ class vari {
* @return Pointer to allocated bytes.
*/
static inline void* operator new(size_t nbytes) {
return ChainableStack::context().memalloc_.alloc(nbytes);
return chainable_stack.memalloc_.alloc(nbytes);
}

/**
Expand Down
8 changes: 4 additions & 4 deletions stan/math/rev/mat/fun/cholesky_decompose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ class cholesky_block : public vari {
const Eigen::Matrix<double, -1, -1>& L_A)
: vari(0.0),
M_(A.rows()),
variRefA_(ChainableStack::context().memalloc_.alloc_array<vari*>(
variRefA_(chainable_stack.memalloc_.alloc_array<vari*>(
A.rows() * (A.rows() + 1) / 2)),
variRefL_(ChainableStack::context().memalloc_.alloc_array<vari*>(
variRefL_(chainable_stack.memalloc_.alloc_array<vari*>(
A.rows() * (A.rows() + 1) / 2)) {
size_t pos = 0;
block_size_ = std::max((M_ / 8 / 16) * 16, 8);
Expand Down Expand Up @@ -159,9 +159,9 @@ class cholesky_scalar : public vari {
const Eigen::Matrix<double, -1, -1>& L_A)
: vari(0.0),
M_(A.rows()),
variRefA_(ChainableStack::context().memalloc_.alloc_array<vari*>(
variRefA_(chainable_stack.memalloc_.alloc_array<vari*>(
A.rows() * (A.rows() + 1) / 2)),
variRefL_(ChainableStack::context().memalloc_.alloc_array<vari*>(
variRefL_(chainable_stack.memalloc_.alloc_array<vari*>(
A.rows() * (A.rows() + 1) / 2)) {
size_t accum = 0;
size_t accum_i = accum;
Expand Down
18 changes: 6 additions & 12 deletions stan/math/rev/mat/fun/cov_exp_quad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,11 @@ class cov_exp_quad_vari : public vari {
l_d_(value_of(l)),
sigma_d_(value_of(sigma)),
sigma_sq_d_(sigma_d_ * sigma_d_),
dist_(ChainableStack::context().memalloc_.alloc_array<double>(
size_ltri_)),
dist_(chainable_stack.memalloc_.alloc_array<double>(size_ltri_)),
l_vari_(l.vi_),
sigma_vari_(sigma.vi_),
cov_lower_(
ChainableStack::context().memalloc_.alloc_array<vari*>(size_ltri_)),
cov_diag_(
ChainableStack::context().memalloc_.alloc_array<vari*>(size_)) {
cov_lower_(chainable_stack.memalloc_.alloc_array<vari*>(size_ltri_)),
cov_diag_(chainable_stack.memalloc_.alloc_array<vari*>(size_)) {
double inv_half_sq_l_d = 0.5 / (l_d_ * l_d_);
size_t pos = 0;
for (size_t j = 0; j < size_ - 1; ++j) {
Expand Down Expand Up @@ -164,13 +161,10 @@ class cov_exp_quad_vari<T_x, double, T_l> : public vari {
l_d_(value_of(l)),
sigma_d_(value_of(sigma)),
sigma_sq_d_(sigma_d_ * sigma_d_),
dist_(ChainableStack::context().memalloc_.alloc_array<double>(
size_ltri_)),
dist_(chainable_stack.memalloc_.alloc_array<double>(size_ltri_)),
l_vari_(l.vi_),
cov_lower_(
ChainableStack::context().memalloc_.alloc_array<vari*>(size_ltri_)),
cov_diag_(
ChainableStack::context().memalloc_.alloc_array<vari*>(size_)) {
cov_lower_(chainable_stack.memalloc_.alloc_array<vari*>(size_ltri_)),
cov_diag_(chainable_stack.memalloc_.alloc_array<vari*>(size_)) {
double inv_half_sq_l_d = 0.5 / (l_d_ * l_d_);
size_t pos = 0;
for (size_t j = 0; j < size_ - 1; ++j) {
Expand Down
7 changes: 3 additions & 4 deletions stan/math/rev/mat/fun/determinant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ class determinant_vari : public vari {
: vari(determinant_vari_calc(A)),
rows_(A.rows()),
cols_(A.cols()),
A_(reinterpret_cast<double*>(ChainableStack::context().memalloc_.alloc(
A_(reinterpret_cast<double*>(chainable_stack.memalloc_.alloc(
sizeof(double) * A.rows() * A.cols()))),
adjARef_(
reinterpret_cast<vari**>(ChainableStack::context().memalloc_.alloc(
sizeof(vari*) * A.rows() * A.cols()))) {
adjARef_(reinterpret_cast<vari**>(chainable_stack.memalloc_.alloc(
sizeof(vari*) * A.rows() * A.cols()))) {
size_t pos = 0;
for (size_type j = 0; j < cols_; j++) {
for (size_type i = 0; i < rows_; i++) {
Expand Down
8 changes: 4 additions & 4 deletions stan/math/rev/mat/fun/dot_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class dot_product_vari : public vari {
vari** shared = nullptr) {
if (shared == nullptr) {
mem_v = reinterpret_cast<vari**>(
ChainableStack::context().memalloc_.alloc(length_ * sizeof(vari*)));
chainable_stack.memalloc_.alloc(length_ * sizeof(vari*)));
for (size_t i = 0; i < length_; i++)
mem_v[i] = inv[i].vi_;
} else {
Expand All @@ -97,7 +97,7 @@ class dot_product_vari : public vari {
vari** shared = nullptr) {
if (shared == nullptr) {
mem_v = reinterpret_cast<vari**>(
ChainableStack::context().memalloc_.alloc(length_ * sizeof(vari*)));
chainable_stack.memalloc_.alloc(length_ * sizeof(vari*)));
for (size_t i = 0; i < length_; i++)
mem_v[i] = inv(i).vi_;
} else {
Expand All @@ -109,7 +109,7 @@ class dot_product_vari : public vari {
double* shared = nullptr) {
if (shared == nullptr) {
mem_d = reinterpret_cast<double*>(
ChainableStack::context().memalloc_.alloc(length_ * sizeof(double)));
chainable_stack.memalloc_.alloc(length_ * sizeof(double)));
for (size_t i = 0; i < length_; i++)
mem_d[i] = ind[i];
} else {
Expand All @@ -121,7 +121,7 @@ class dot_product_vari : public vari {
double* shared = nullptr) {
if (shared == nullptr) {
mem_d = reinterpret_cast<double*>(
ChainableStack::context().memalloc_.alloc(length_ * sizeof(double)));
chainable_stack.memalloc_.alloc(length_ * sizeof(double)));
for (size_t i = 0; i < length_; i++)
mem_d[i] = ind(i);
} else {
Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/mat/fun/dot_self.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ class dot_self_vari : public vari {
explicit dot_self_vari(const Eigen::DenseBase<Derived>& v)
: vari(var_dot_self(v)), size_(v.size()) {
v_ = reinterpret_cast<vari**>(
ChainableStack::context().memalloc_.alloc(size_ * sizeof(vari*)));
chainable_stack.memalloc_.alloc(size_ * sizeof(vari*)));
for (size_t i = 0; i < size_; i++)
v_[i] = v[i].vi_;
}
template <int R, int C>
explicit dot_self_vari(const Eigen::Matrix<var, R, C>& v)
: vari(var_dot_self(v)), size_(v.size()) {
v_ = reinterpret_cast<vari**>(
ChainableStack::context().memalloc_.alloc(size_ * sizeof(vari*)));
chainable_stack.memalloc_.alloc(size_ * sizeof(vari*)));
for (size_t i = 0; i < size_; ++i)
v_[i] = v(i).vi_;
}
Expand Down
Loading