Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
135 commits
Select commit Hold shift + click to select a range
8099c90
prim version
wds15 Dec 30, 2019
7027c12
make nested stuff work
wds15 Dec 30, 2019
302cd71
extend signature and test nested parallel AD
wds15 Dec 31, 2019
9740b45
add hierarchical example
wds15 Dec 31, 2019
e47ca13
add recover_memory_global
wds15 Dec 31, 2019
e94be85
Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math …
wds15 Dec 31, 2019
9565db9
add file
wds15 Dec 31, 2019
6faa614
Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math …
wds15 Dec 31, 2019
acbb3dc
fix
wds15 Dec 31, 2019
ca762fe
Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math …
wds15 Dec 31, 2019
bc2a601
fix
wds15 Dec 31, 2019
eec7472
remove debugging msg
wds15 Dec 31, 2019
2d696c8
omit recover_memory_global which is not needed
wds15 Jan 1, 2020
8a1d99e
Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math …
wds15 Jan 1, 2020
6d8ad8a
aggregate more efficiently the partial sums
wds15 Jan 1, 2020
1a54bde
simplify how values are copied
wds15 Jan 2, 2020
f689c93
const correctness
wds15 Jan 2, 2020
0da9de0
make code more generic (some meta magic bits are missing)
wds15 Jan 2, 2020
73fb5ae
make parallel reduce sum work with posted example... need more meta-p…
wds15 Jan 3, 2020
e7a83c6
rename to reduce_sum
wds15 Jan 5, 2020
2a370e7
add up to 4 arguments for reduce function
wds15 Jan 5, 2020
bd13907
make arguments optional
wds15 Jan 5, 2020
e1deb4f
more doc and const declares
wds15 Jan 5, 2020
358525c
doc
wds15 Jan 5, 2020
13641ea
Merge remote-tracking branch 'origin/develop' into proto-parallel-v3
wds15 Jan 5, 2020
5f76ca2
generalize possible input data structures
wds15 Jan 6, 2020
d9e5276
refactor such that any data strcuture (contained in an array) can be …
wds15 Jan 7, 2020
5e01bf0
fix looping order error
wds15 Jan 7, 2020
f2ed8a9
still struggling with performance regression
wds15 Jan 10, 2020
a3b4ad4
Merge remote-tracking branch 'origin/develop' into proto-parallel-v3
wds15 Jan 10, 2020
f3b37ae
start going back to better abstracted code
wds15 Jan 10, 2020
940fc82
move to shared_ptr, simplify counstructor call
wds15 Jan 12, 2020
8b66f94
move to partials being stored as flat vector
wds15 Jan 12, 2020
e56cdf2
get rid of obsolete code
wds15 Jan 12, 2020
2cece58
add preliminary support for structured slicing arguments which are no…
wds15 Jan 12, 2020
71c0196
add special case of local_operands_and_partials for non-var in a work…
wds15 Jan 12, 2020
e2b20d7
start cleanup bits of the reduce_sum code. Mostly to get an idea of w…
SteveBronder Jan 14, 2020
c0c118a
start cleanup bits of the reduce_sum code. Mostly to get an idea of w…
SteveBronder Jan 14, 2020
ff0f582
start cleanup bits of the reduce_sum code. Mostly to get an idea of w…
SteveBronder Jan 14, 2020
96e8c8e
start cleanup bits of the reduce_sum code. Mostly to get an idea of w…
SteveBronder Jan 14, 2020
3d62fa8
loose stuff needed removed
SteveBronder Jan 14, 2020
cbcf42f
Added variadic implementation of rev parallel_sum
bbbales2 Jan 15, 2020
8f90d8c
Adds enable_ifs to reduce_sum_impl
SteveBronder Jan 15, 2020
617d6e2
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
fbfa9ba
Fixup little templating things
SteveBronder Jan 15, 2020
008255e
merge to develop
SteveBronder Jan 15, 2020
00e5743
Make accumulate adjoints accept more types
SteveBronder Jan 15, 2020
623ed9e
Catch arithmetics in count_var_impl
SteveBronder Jan 15, 2020
16db0f2
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
9305fb0
move operator_paren checker out one scope
SteveBronder Jan 15, 2020
bb9a1b2
Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math …
SteveBronder Jan 15, 2020
430ba3f
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
0248bcb
add left_fold helper func
SteveBronder Jan 15, 2020
b2c1dec
Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math …
SteveBronder Jan 15, 2020
cf8b530
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
a27047f
Get slicing tests to pass
SteveBronder Jan 15, 2020
2dd8881
Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math …
SteveBronder Jan 15, 2020
5e3e078
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
92fe766
some fixes, but not yet there
weberse2 Jan 15, 2020
2d1d57e
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
906c262
Merging in changes from pull
bbbales2 Jan 15, 2020
747c73a
Allow vars in sliced argument, fixed (at least temporarily) a couple …
bbbales2 Jan 15, 2020
3da42ff
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
70001ac
make things work with proper cleanup
wds15 Jan 16, 2020
be1ff0c
Changed how the sliced gradients work again
bbbales2 Jan 16, 2020
7a6b558
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Jan 16, 2020
d440aab
Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d…
wds15 Jan 16, 2020
89bad42
ensure proper cleaning of vars on all child threads
wds15 Jan 16, 2020
c1ac292
Adds metaprogramming to accept std vectors and eigen vectors
SteveBronder Jan 17, 2020
4ea5f1c
Remove dead code
SteveBronder Jan 17, 2020
6075dcc
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 17, 2020
af7791e
idiot slaps keyboard
SteveBronder Jan 17, 2020
103333c
merge to remote
SteveBronder Jan 17, 2020
ab1e653
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 17, 2020
29f1f71
Merge branch 'develop' into cleanup/proto-parallel-v3
bbbales2 Jan 17, 2020
37b516f
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 17, 2020
b76cd8c
Added tests for reduce_sum
bbbales2 Jan 17, 2020
6b93f93
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Jan 17, 2020
e804157
Fix cpplint errors
SteveBronder Jan 17, 2020
66acef6
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Jan 17, 2020
906d7b0
Fix headers
SteveBronder Jan 17, 2020
71d8c71
Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math …
SteveBronder Jan 17, 2020
d05d452
Added msgs argument to reduce_sum
bbbales2 Jan 19, 2020
658755d
Merge commit '10cc6ba675743f09832c6749fbb1a92d74888bd2' into HEAD
yashikno Jan 19, 2020
4e18f1e
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 19, 2020
d834a4f
Allow arrays of all the Stan types to reduce_sum
bbbales2 Jan 20, 2020
145ae95
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 20, 2020
9957909
Allocate deep copied varis on no-chain stack
bbbales2 Jan 20, 2020
3ecf79e
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Jan 20, 2020
94989a0
Fixed off by one error on start/end indices passed by reduce_sum to u…
bbbales2 Jan 31, 2020
c9ac276
Merge commit 'dd9774dbc03935433b25a00be21763b43a242191' into HEAD
yashikno Jan 31, 2020
70546f4
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 31, 2020
4cfcfe8
Revert "Fixed off by one error on start/end indices passed by reduce_…
bbbales2 Jan 31, 2020
d46c022
Added extra template conditions for accumulate_adjoints_, count_vars_…
bbbales2 Jan 31, 2020
ccc82e7
Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d…
bbbales2 Jan 31, 2020
144bb9c
Changed how arg_adjoints_ is allocated in reduce_sum and added out of…
bbbales2 Feb 13, 2020
4adf05d
Merge commit '8fbce61858ae95745567d8c5f7e054e3d3b20834' into HEAD
yashikno Feb 13, 2020
80ae1bf
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Feb 13, 2020
7a1a27b
Made deep copy in reduce_sum a little more efficient
bbbales2 Feb 24, 2020
9e15292
Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d…
bbbales2 Feb 24, 2020
b6b215b
Merge commit 'b29eff6b027fb2a1bcc5af53d868ce8024235272' into HEAD
yashikno Feb 24, 2020
66ba573
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Feb 24, 2020
a5a50e2
Merge remote-tracking branch 'origin/develop' into cleanup/proto-para…
SteveBronder Mar 5, 2020
55171a4
Adds pf (yes I know but seriously this time it's good I think) to the…
SteveBronder Mar 6, 2020
bc3b7c6
add inline to functions
SteveBronder Mar 6, 2020
a017619
Add docs for templates in reduce_sum
SteveBronder Mar 9, 2020
2561972
Merge remote-tracking branch 'origin/develop' into cleanup/parallel-v5
SteveBronder Mar 9, 2020
3e36ba8
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Mar 9, 2020
38aa2af
Merge remote-tracking branch 'origin/develop' into cleanup/proto-para…
SteveBronder Mar 9, 2020
b99d299
Merge branch 'cleanup/proto-parallel-v3' into cleanup/parallel-v5
SteveBronder Mar 9, 2020
cc37f22
Add includes for test-headers and add threading environment variable …
SteveBronder Mar 9, 2020
b64fe87
remove -fopenmp from CXXFLAGS
SteveBronder Mar 9, 2020
7d27f90
remove old prim/scal.hpp include from apply_test
SteveBronder Mar 10, 2020
c53d29d
Fix broadcast_array to remove vec_partial_
SteveBronder Mar 10, 2020
fb9f25a
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 10, 2020
ee1e0d7
Merge pull request #1768 from stan-dev/cleanup/parallel-v5
bbbales2 Mar 11, 2020
7aa19ce
Changed reduce_sum so sliced argument can be an array of any Stan type
bbbales2 Mar 13, 2020
cb57e75
Merge commit 'fe3a41c3e854fe604841b237fa0475f26d29fe98' into HEAD
yashikno Mar 13, 2020
09d3f52
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Mar 13, 2020
558c4ba
Broke out out deep_copy_vars, save_varis, count_vars, accumulate adjo…
bbbales2 Mar 20, 2020
5dd3481
Merge commit 'c26cda19159ef0f21e8d4354f3652cb2d9fdc1db' into HEAD
yashikno Mar 20, 2020
5089039
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 20, 2020
59d23cb
Added tests for deep_copy_vars and save_varis (stan-dev/design-docs p…
bbbales2 Mar 20, 2020
d6e2c00
Removed left_fold function.
bbbales2 Mar 20, 2020
3d03855
Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d…
bbbales2 Mar 20, 2020
2c3a65d
Merge commit 'a9a8fc2dc8059c3597c4e98fdf0c19182a08f21d' into HEAD
yashikno Mar 20, 2020
5f64b49
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 20, 2020
6025286
Merge remote-tracking branch 'origin/develop' into cleanup/proto-para…
bbbales2 Mar 26, 2020
4c30bc1
Added grainsize check, grainsize tests, and more tests for std::vecto…
bbbales2 Mar 26, 2020
c702bd7
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 26, 2020
c1fd34f
Added reduce_sum_static which uses tbb::simple_partitioner to have mo…
bbbales2 Mar 27, 2020
1dcad2a
Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d…
bbbales2 Mar 27, 2020
564ec5e
Merge commit 'b6134fbf1a75d9bfa4716bafc8ced948b794f4b3' into HEAD
yashikno Mar 27, 2020
6743ff6
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 27, 2020
da8332b
Merge remote-tracking branch 'origin/develop' into cleanup/proto-para…
SteveBronder Mar 31, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ pipeline {
stage('Verify changes') {
agent { label 'linux' }
steps {
script {
script {

retry(3) { checkout scm }
sh 'git clean -xffd'
Expand Down Expand Up @@ -240,6 +240,7 @@ pipeline {
steps {
deleteDir()
unstash 'MathSetup'
sh "export STAN_NUM_THREADS=2"
sh "echo CXX=${env.CXX} -Werror > make/local"
sh "echo CPPFLAGS+=-DSTAN_THREADS >> make/local"
runTests("test/unit -f thread")
Expand All @@ -263,6 +264,7 @@ pipeline {
steps {
deleteDirWin()
unstash 'MathSetup'
bat "setx STAN_NUM_THREADS 2"
bat "echo CXX=${env.CXX} -Werror > make/local"
bat "echo CXXFLAGS+=-DSTAN_THREADS >> make/local"
runTestsWin("test/unit -f thread")
Expand All @@ -272,7 +274,7 @@ pipeline {
}
}
stage('Additional merge tests') {
when {
when {
allOf {
anyOf {
branch 'develop'
Expand Down Expand Up @@ -309,12 +311,12 @@ pipeline {
}
}
stage('Upstream tests') {
when {
when {
allOf {
expression {
env.BRANCH_NAME ==~ /PR-\d+/
expression {
env.BRANCH_NAME ==~ /PR-\d+/
}
expression {
expression {
!skipRemainingStages
}
}
Expand Down
2 changes: 1 addition & 1 deletion make/compiler_flags
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ print-compiler-flags:
@echo ' - SUNDIALS ' $(SUNDIALS)
@echo ' - TBB ' $(TBB)
@echo ' - GTEST ' $(GTEST)
@echo ' - STAN_THREADS ' $(STAN_THREADS)
@echo ' - STAN_THREADS ' $(STAN_THREADS)
@echo ' - STAN_OPENCL ' $(STAN_OPENCL)
@echo ' - STAN_MPI ' $(STAN_MPI)
@echo ' Compiler flags (each can be overriden separately):'
Expand Down
2 changes: 2 additions & 0 deletions stan/math/prim/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@
#include <stan/math/prim/functor/mpi_cluster.hpp>
#include <stan/math/prim/functor/mpi_command.hpp>
#include <stan/math/prim/functor/mpi_distributed_apply.hpp>
#include <stan/math/prim/functor/reduce_sum.hpp>
#include <stan/math/prim/functor/reduce_sum_static.hpp>

#endif
272 changes: 272 additions & 0 deletions stan/math/prim/functor/reduce_sum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP
#define STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP

#include <stan/math/prim/meta.hpp>

#include <tbb/task_arena.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>

#include <tuple>
#include <vector>

namespace stan {
namespace math {

namespace internal {

/**
* reduce_sum_impl implementation for any autodiff type.
*
* @tparam ReduceFunction Type of reducer function
* @tparam ReturnType An arithmetic type
* @tparam Vec Type of sliced argument
* @tparam Args Types of shared arguments
*/
template <typename ReduceFunction, typename Enable, typename ReturnType,
typename Vec, typename... Args>
struct reduce_sum_impl {
/**
* Call an instance of the function `ReduceFunction` on every element
* of an input sequence and sum these terms.
*
* This specialization is not parallelized and works for any autodiff types.
*
* An instance, f, of `ReduceFunction` should have the signature:
* T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
* Args&&... args)
*
* `ReduceFunction` must be default constructible without any arguments
*
* Each call to `ReduceFunction` is responsible for computing the
* start through end terms (inclusive) of the overall sum. All args are
* passed from this function through to the `ReduceFunction` instances.
* However, only the start through end (inclusive) elements of the vmapped
* argument are passed to the `ReduceFunction` instances (as the
* `vmapped_subset` argument).
*
* If auto partitioning is true, do the calculation with one
* ReduceFunction call. If false, break work into pieces strictly smaller
* than grainsize.
*
* grainsize must be greater than or equal to 1
*
* @param vmapped Sliced arguments used only in some sum terms
* @param auto_partitioning Work partitioning style (ignored)
* @param grainsize Suggested grainsize for tbb
* @param[in, out] msgs The print stream for warning messages
* @param args Shared arguments used in every sum term
* @return Summation of all terms
*/
return_type_t<Vec, Args...> operator()(Vec&& vmapped, bool auto_partitioning,
int grainsize, std::ostream* msgs,
Args&&... args) const {
const std::size_t num_jobs = vmapped.size();

if (num_jobs == 0) {
return 0.0;
}

if (auto_partitioning) {
return ReduceFunction()(0, vmapped.size() - 1, std::forward<Vec>(vmapped),
msgs, std::forward<Args>(args)...);
} else {
return_type_t<Vec, Args...> sum = 0.0;
for (size_t i = 0; i < (vmapped.size() + grainsize - 1) / grainsize;
++i) {
size_t start = i * grainsize;
size_t end = std::min((i + 1) * grainsize, vmapped.size()) - 1;

std::decay_t<Vec> sub_slice;
sub_slice.reserve(end - start + 1);
for (int i = start; i <= end; ++i) {
sub_slice.emplace_back(vmapped[i]);
}

sum += ReduceFunction()(start, end, std::forward<Vec>(sub_slice), msgs,
std::forward<Args>(args)...);
}
return sum;
}
}
};

/**
* Specialization of reduce_sum_impl for arithmetic types
*
* @tparam ReduceFunction Type of reducer function
* @tparam ReturnType An arithmetic type
* @tparam Vec Type of sliced argument
* @tparam Args Types of shared arguments
*/
template <typename ReduceFunction, typename ReturnType, typename Vec,
typename... Args>
struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
ReturnType, Vec, Args...> {
/**
* Internal object meeting the Imperative form requirements of
* `tbb::parallel_reduce`
*
* @note see link [here](https://tinyurl.com/vp7xw2t) for requirements.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bbbales2 check out the link here

The parallel_reduce template has two forms. The functional form is designed to be easy to use in conjunction with lambda expressions. The imperative form is designed to minimize copying of data.

So the answer to "can this be a function" is probably yes however they say the imperative style does less copying. We could always try the functional style and benchmark the diff. idk anything about tbb's internals here so can't say whether the performance would be noticable

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with the imperative form for the promised additional efficiency... I don't need to benchmark this for my taste given that we do what is supposed to be faster and it's fine to me style wise to go with a struct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aaah, I meant reduce_sum_impl, not recursive_reducer.

*/
struct recursive_reducer {
Vec vmapped_;
std::ostream* msgs_;
std::tuple<Args...> args_tuple_;
return_type_t<Vec, Args...> sum_{0.0};

recursive_reducer(Vec&& vmapped, std::ostream* msgs, Args&&... args)
: vmapped_(std::forward<Vec>(vmapped)),
msgs_(msgs),
args_tuple_(std::forward<Args>(args)...) {}

/*
* This is the copy operator as required for tbb::parallel_reduce
* Imperative form. This requires sum_ be reset to zero.
*/
recursive_reducer(recursive_reducer& other, tbb::split)
: vmapped_(other.vmapped_),
msgs_(other.msgs_),
args_tuple_(other.args_tuple_) {}

/**
* Compute the value and of `ReduceFunction` over the range defined by r
* and accumulate those in member variable sum_. This function may
* be called multiple times per object instantiation (so the sum_
* must be accumulated, not just assigned).
*
* @param r Range over which to compute `ReduceFunction`
*/
void operator()(const tbb::blocked_range<size_t>& r) {
if (r.empty()) {
return;
}

std::decay_t<Vec> sub_slice;
sub_slice.reserve(r.size());
for (int i = r.begin(); i < r.end(); ++i) {
sub_slice.emplace_back(vmapped_[i]);
}

sum_ += apply(
[&](auto&&... args) {
return ReduceFunction()(r.begin(), r.end() - 1, sub_slice, msgs_,
args...);
},
this->args_tuple_);
}

/**
* Join reducers. Accumuluate the value (sum_) of the other reducer.
*
* @param rhs Another partial sum
*/
void join(const recursive_reducer& child) { this->sum_ += child.sum_; }
};

/**
* Call an instance of the function `ReduceFunction` on every element
* of an input sequence and sum these terms.
*
* This specialization is parallelized using tbb and works only for
* arithmetic types.
*
* An instance, f, of `ReduceFunction` should have the signature:
* double f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
* Args&&... args)
*
* `ReduceFunction` must be default constructible without any arguments
*
* Each call to `ReduceFunction` is responsible for computing the
* start through end (inclusive) terms of the overall sum. All args are
* passed from this function through to the `ReduceFunction` instances.
* However, only the start through end (inclusive) elements of the vmapped
* argument are passed to the `ReduceFunction` instances (as the
* `vmapped_subset` argument).
*
* This function distributes computation of the desired sum
* over multiple threads by coordinating calls to `ReduceFunction`
* instances.
*
* If auto partitioning is true, break work into pieces automatically,
* taking grainsize as a recommended work size (this process
* is not deterministic). If false, break work deterministically
* into pieces smaller than or equal to grainsize. The execution
* order is non-deterministic.
*
* grainsize must be greater than or equal to 1
*
* @param vmapped Sliced arguments used only in some sum terms
* @param auto_partitioning Work partitioning style
* @param grainsize Suggested grainsize for tbb
* @param[in, out] msgs The print stream for warning messages
* @param args Shared arguments used in every sum term
* @return Summation of all terms
*/
ReturnType operator()(Vec&& vmapped, bool auto_partitioning, int grainsize,
std::ostream* msgs, Args&&... args) const {
const std::size_t num_jobs = vmapped.size();
if (num_jobs == 0) {
return 0.0;
}
recursive_reducer worker(std::forward<Vec>(vmapped), msgs,
std::forward<Args>(args)...);

if (auto_partitioning) {
tbb::parallel_reduce(
tbb::blocked_range<std::size_t>(0, num_jobs, grainsize), worker);
} else {
tbb::simple_partitioner partitioner;
tbb::parallel_deterministic_reduce(
tbb::blocked_range<std::size_t>(0, num_jobs, grainsize), worker,
partitioner);
}

return worker.sum_;
}
};

} // namespace internal

/**
* Call an instance of the function `ReduceFunction` on every element
* of an input sequence and sum these terms.
*
* This defers to reduce_sum_impl for the appropriate implementation
*
* An instance, f, of `ReduceFunction` should have the signature:
* T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs, Args&&...
* args)
*
* `ReduceFunction` must be default constructible without any arguments
*
* grainsize must be greater than or equal to 1
*
* @tparam ReduceFunction Type of reducer function
* @tparam ReturnType An arithmetic type
* @tparam Vec Type of sliced argument
* @tparam Args Types of shared arguments
* @param vmapped Sliced arguments used only in some sum terms
* @param grainsize Suggested grainsize for tbb
* @param[in, out] msgs The print stream for warning messages
* @param args Shared arguments used in every sum term
* @return Sum of terms
*/
template <typename ReduceFunction, typename Vec,
typename = require_vector_like_t<Vec>, typename... Args>
auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs,
Args&&... args) {
using return_type = return_type_t<Vec, Args...>;

check_positive("reduce_sum", "grainsize", grainsize);

return internal::reduce_sum_impl<ReduceFunction, void, return_type, Vec,
Args...>()(std::forward<Vec>(vmapped), true,
grainsize, msgs,
std::forward<Args>(args)...);
}

} // namespace math
} // namespace stan

#endif
57 changes: 57 additions & 0 deletions stan/math/prim/functor/reduce_sum_static.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_STATIC_HPP
#define STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_STATIC_HPP

#include <stan/math/prim/meta.hpp>

#include <tbb/task_arena.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>

#include <tuple>
#include <vector>

namespace stan {
namespace math {

/**
* Call an instance of the function `ReduceFunction` on every element
* of an input sequence and sum these terms.
*
* This defers to reduce_sum_impl for the appropriate implementation
*
* An instance, f, of `ReduceFunction` should have the signature:
* T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs, Args&&...
* args)
*
* `ReduceFunction` must be default constructible without any arguments
*
* grainsize must be greater than or equal to 1
*
* @tparam ReduceFunction Type of reducer function
* @tparam ReturnType An arithmetic type
* @tparam Vec Type of sliced argument
* @tparam Args Types of shared arguments
* @param vmapped Sliced arguments used only in some sum terms
* @param grainsize Suggested grainsize for tbb
* @param[in, out] msgs The print stream for warning messages
* @param args Shared arguments used in every sum term
* @return Sum of terms
*/
template <typename ReduceFunction, typename Vec,
typename = require_vector_like_t<Vec>, typename... Args>
auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs,
Args&&... args) {
using return_type = return_type_t<Vec, Args...>;

check_positive("reduce_sum", "grainsize", grainsize);

return internal::reduce_sum_impl<ReduceFunction, void, return_type, Vec,
Args...>()(std::forward<Vec>(vmapped), false,
grainsize, msgs,
std::forward<Args>(args)...);
}

} // namespace math
} // namespace stan

#endif
Loading