[WIP] Parallel Prototype #1616
Closed
135 commits
Commit | Author | Message |
---|---|---|
8099c90 | wds15 | prim version |
7027c12 | wds15 | make nested stuff work |
302cd71 | wds15 | extend signature and test nested parallel AD |
9740b45 | wds15 | add hierarchical example |
e47ca13 | wds15 | add recover_memory_global |
e94be85 | wds15 | Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math … |
9565db9 | wds15 | add file |
6faa614 | wds15 | Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math … |
acbb3dc | wds15 | fix |
ca762fe | wds15 | Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math … |
bc2a601 | wds15 | fix |
eec7472 | wds15 | remove debugging msg |
2d696c8 | wds15 | omit recover_memory_global which is not needed |
8a1d99e | wds15 | Merge branch 'proto-parallel-v3' of https://github.com/stan-dev/math … |
6d8ad8a | wds15 | aggregate more efficiently the partial sums |
1a54bde | wds15 | simplify how values are copied |
f689c93 | wds15 | const correctness |
0da9de0 | wds15 | make code more generic (some meta magic bits are missing) |
73fb5ae | wds15 | make parallel reduce sum work with posted example... need more meta-p… |
e7a83c6 | wds15 | rename to reduce_sum |
2a370e7 | wds15 | add up to 4 arguments for reduce function |
bd13907 | wds15 | make arguments optional |
e1deb4f | wds15 | more doc and const declares |
358525c | wds15 | doc |
13641ea | wds15 | Merge remote-tracking branch 'origin/develop' into proto-parallel-v3 |
5f76ca2 | wds15 | generalize possible input data structures |
d9e5276 | wds15 | refactor such that any data strcuture (contained in an array) can be … |
5e01bf0 | wds15 | fix looping order error |
f2ed8a9 | wds15 | still struggling with performance regression |
a3b4ad4 | wds15 | Merge remote-tracking branch 'origin/develop' into proto-parallel-v3 |
f3b37ae | wds15 | start going back to better abstracted code |
940fc82 | wds15 | move to shared_ptr, simplify counstructor call |
8b66f94 | wds15 | move to partials being stored as flat vector |
e56cdf2 | wds15 | get rid of obsolete code |
2cece58 | wds15 | add preliminary support for structured slicing arguments which are no… |
71c0196 | wds15 | add special case of local_operands_and_partials for non-var in a work… |
e2b20d7 | SteveBronder | start cleanup bits of the reduce_sum code. Mostly to get an idea of w… |
c0c118a | SteveBronder | start cleanup bits of the reduce_sum code. Mostly to get an idea of w… |
ff0f582 | SteveBronder | start cleanup bits of the reduce_sum code. Mostly to get an idea of w… |
96e8c8e | SteveBronder | start cleanup bits of the reduce_sum code. Mostly to get an idea of w… |
3d62fa8 | SteveBronder | loose stuff needed removed |
cbcf42f | bbbales2 | Added variadic implementation of rev parallel_sum |
8f90d8c | SteveBronder | Adds enable_ifs to reduce_sum_impl |
617d6e2 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
fbfa9ba | SteveBronder | Fixup little templating things |
008255e | SteveBronder | merge to develop |
00e5743 | SteveBronder | Make accumulate adjoints accept more types |
623ed9e | SteveBronder | Catch arithmetics in count_var_impl |
16db0f2 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
9305fb0 | SteveBronder | move operator_paren checker out one scope |
bb9a1b2 | SteveBronder | Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math … |
430ba3f | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
0248bcb | SteveBronder | add left_fold helper func |
b2c1dec | SteveBronder | Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math … |
cf8b530 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
a27047f | SteveBronder | Get slicing tests to pass |
2dd8881 | SteveBronder | Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math … |
5e3e078 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
92fe766 | weberse2 | some fixes, but not yet there |
2d1d57e | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
906c262 | bbbales2 | Merging in changes from pull |
747c73a | bbbales2 | Allow vars in sliced argument, fixed (at least temporarily) a couple … |
3da42ff | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
70001ac | wds15 | make things work with proper cleanup |
be1ff0c | bbbales2 | Changed how the sliced gradients work again |
7a6b558 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e… |
d440aab | wds15 | Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d… |
89bad42 | wds15 | ensure proper cleaning of vars on all child threads |
c1ac292 | SteveBronder | Adds metaprogramming to accept std vectors and eigen vectors |
4ea5f1c | SteveBronder | Remove dead code |
6075dcc | stan-buildbot | [Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/… |
af7791e | SteveBronder | idiot slaps keyboard |
103333c | SteveBronder | merge to remote |
ab1e653 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/… |
29f1f71 | bbbales2 | Merge branch 'develop' into cleanup/proto-parallel-v3 |
37b516f | stan-buildbot | [Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/… |
b76cd8c | bbbales2 | Added tests for reduce_sum |
6b93f93 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e… |
e804157 | SteveBronder | Fix cpplint errors |
66acef6 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e… |
906d7b0 | SteveBronder | Fix headers |
71d8c71 | SteveBronder | Merge branch 'cleanup/proto-parallel-v3' of github.com:stan-dev/math … |
d05d452 | bbbales2 | Added msgs argument to reduce_sum |
658755d | yashikno | Merge commit '10cc6ba675743f09832c6749fbb1a92d74888bd2' into HEAD |
4e18f1e | stan-buildbot | [Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/… |
d834a4f | bbbales2 | Allow arrays of all the Stan types to reduce_sum |
145ae95 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/… |
9957909 | bbbales2 | Allocate deep copied varis on no-chain stack |
3ecf79e | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e… |
94989a0 | bbbales2 | Fixed off by one error on start/end indices passed by reduce_sum to u… |
c9ac276 | yashikno | Merge commit 'dd9774dbc03935433b25a00be21763b43a242191' into HEAD |
70546f4 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
4cfcfe8 | bbbales2 | Revert "Fixed off by one error on start/end indices passed by reduce_… |
d46c022 | bbbales2 | Added extra template conditions for accumulate_adjoints_, count_vars_… |
ccc82e7 | bbbales2 | Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d… |
144bb9c | bbbales2 | Changed how arg_adjoints_ is allocated in reduce_sum and added out of… |
4adf05d | yashikno | Merge commit '8fbce61858ae95745567d8c5f7e054e3d3b20834' into HEAD |
80ae1bf | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
7a1a27b | bbbales2 | Made deep copy in reduce_sum a little more efficient |
9e15292 | bbbales2 | Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d… |
b6b215b | yashikno | Merge commit 'b29eff6b027fb2a1bcc5af53d868ce8024235272' into HEAD |
66ba573 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
a5a50e2 | SteveBronder | Merge remote-tracking branch 'origin/develop' into cleanup/proto-para… |
55171a4 | SteveBronder | Adds pf (yes I know but seriously this time it's good I think) to the… |
bc3b7c6 | SteveBronder | add inline to functions |
a017619 | SteveBronder | Add docs for templates in reduce_sum |
2561972 | SteveBronder | Merge remote-tracking branch 'origin/develop' into cleanup/parallel-v5 |
3e36ba8 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e… |
38aa2af | SteveBronder | Merge remote-tracking branch 'origin/develop' into cleanup/proto-para… |
b99d299 | SteveBronder | Merge branch 'cleanup/proto-parallel-v3' into cleanup/parallel-v5 |
cc37f22 | SteveBronder | Add includes for test-headers and add threading environment variable … |
b64fe87 | SteveBronder | remove -fopenmp from CXXFLAGS |
7d27f90 | SteveBronder | remove old prim/scal.hpp include from apply_test |
c53d29d | SteveBronder | Fix broadcast_array to remove vec_partial_ |
fb9f25a | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
ee1e0d7 | bbbales2 | Merge pull request #1768 from stan-dev/cleanup/parallel-v5 |
7aa19ce | bbbales2 | Changed reduce_sum so sliced argument can be an array of any Stan type |
cb57e75 | yashikno | Merge commit 'fe3a41c3e854fe604841b237fa0475f26d29fe98' into HEAD |
09d3f52 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e… |
558c4ba | bbbales2 | Broke out out deep_copy_vars, save_varis, count_vars, accumulate adjo… |
5dd3481 | yashikno | Merge commit 'c26cda19159ef0f21e8d4354f3652cb2d9fdc1db' into HEAD |
5089039 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/… |
59d23cb | bbbales2 | Added tests for deep_copy_vars and save_varis (stan-dev/design-docs p… |
d6e2c00 | bbbales2 | Removed left_fold function. |
3d03855 | bbbales2 | Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d… |
2c3a65d | yashikno | Merge commit 'a9a8fc2dc8059c3597c4e98fdf0c19182a08f21d' into HEAD |
5f64b49 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
6025286 | bbbales2 | Merge remote-tracking branch 'origin/develop' into cleanup/proto-para… |
4c30bc1 | bbbales2 | Added grainsize check, grainsize tests, and more tests for std::vecto… |
c702bd7 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
c1fd34f | bbbales2 | Added reduce_sum_static which uses tbb::simple_partitioner to have mo… |
1dcad2a | bbbales2 | Merge branch 'cleanup/proto-parallel-v3' of https://github.com/stan-d… |
564ec5e | yashikno | Merge commit 'b6134fbf1a75d9bfa4716bafc8ced948b794f4b3' into HEAD |
6743ff6 | stan-buildbot | [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta… |
da8332b | SteveBronder | Merge remote-tracking branch 'origin/develop' into cleanup/proto-para… |
@@ -0,0 +1,272 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP
#define STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP

#include <stan/math/prim/meta.hpp>

#include <tbb/task_arena.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>

#include <algorithm>
#include <tuple>
#include <vector>

namespace stan {
namespace math {

namespace internal {

/**
 * reduce_sum_impl implementation for any autodiff type.
 *
 * @tparam ReduceFunction Type of reducer function
 * @tparam Enable Placeholder used to select specializations via SFINAE
 * @tparam ReturnType Return type of the sum (any autodiff type)
 * @tparam Vec Type of sliced argument
 * @tparam Args Types of shared arguments
 */
template <typename ReduceFunction, typename Enable, typename ReturnType,
          typename Vec, typename... Args>
struct reduce_sum_impl {
  /**
   * Call an instance of the function `ReduceFunction` on every element
   * of an input sequence and sum these terms.
   *
   * This primary template is not parallelized and works for any autodiff type.
   *
   * An instance, f, of `ReduceFunction` should have the signature:
   *   T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
   *       Args&&... args)
   *
   * `ReduceFunction` must be default constructible without any arguments
   *
   * Each call to `ReduceFunction` is responsible for computing the
   * start through end terms (inclusive) of the overall sum. All args are
   * passed from this function through to the `ReduceFunction` instances.
   * However, only the start through end (inclusive) elements of the vmapped
   * argument are passed to the `ReduceFunction` instances (as the
   * `vmapped_subset` argument).
   *
   * If auto partitioning is true, do the calculation with one ReduceFunction
   * call. If false, break work into pieces no larger than grainsize.
   *
   * grainsize must be greater than or equal to 1
   *
   * @param vmapped Sliced arguments used only in some sum terms
   * @param auto_partitioning Work partitioning style
   * @param grainsize Maximum number of terms per `ReduceFunction` call
   * @param[in, out] msgs The print stream for warning messages
   * @param args Shared arguments used in every sum term
   * @return Summation of all terms
   */
  return_type_t<Vec, Args...> operator()(Vec&& vmapped, bool auto_partitioning,
                                         int grainsize, std::ostream* msgs,
                                         Args&&... args) const {
    const std::size_t num_jobs = vmapped.size();

    if (num_jobs == 0) {
      return 0.0;
    }

    if (auto_partitioning) {
      return ReduceFunction()(0, vmapped.size() - 1, std::forward<Vec>(vmapped),
                              msgs, std::forward<Args>(args)...);
    } else {
      return_type_t<Vec, Args...> sum = 0.0;
      for (size_t i = 0; i < (vmapped.size() + grainsize - 1) / grainsize;
           ++i) {
        size_t start = i * grainsize;
        size_t end = std::min((i + 1) * grainsize, vmapped.size()) - 1;

        std::decay_t<Vec> sub_slice;
        sub_slice.reserve(end - start + 1);
        for (size_t j = start; j <= end; ++j) {
          sub_slice.emplace_back(vmapped[j]);
        }

        sum += ReduceFunction()(start, end, std::forward<Vec>(sub_slice), msgs,
                                std::forward<Args>(args)...);
      }
      return sum;
    }
  }
};

/**
 * Specialization of reduce_sum_impl for arithmetic types
 *
 * @tparam ReduceFunction Type of reducer function
 * @tparam ReturnType An arithmetic type
 * @tparam Vec Type of sliced argument
 * @tparam Args Types of shared arguments
 */
template <typename ReduceFunction, typename ReturnType, typename Vec,
          typename... Args>
struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
                       ReturnType, Vec, Args...> {
  /**
   * Internal object meeting the Imperative form requirements of
   * `tbb::parallel_reduce`
   *
   * @note see link [here](https://tinyurl.com/vp7xw2t) for requirements.
   */
  struct recursive_reducer {
    Vec vmapped_;
    std::ostream* msgs_;
    std::tuple<Args...> args_tuple_;
    return_type_t<Vec, Args...> sum_{0.0};

    recursive_reducer(Vec&& vmapped, std::ostream* msgs, Args&&... args)
        : vmapped_(std::forward<Vec>(vmapped)),
          msgs_(msgs),
          args_tuple_(std::forward<Args>(args)...) {}

    /**
     * Splitting constructor required by the imperative form of
     * tbb::parallel_reduce. The new reducer starts with sum_ at zero.
     */
    recursive_reducer(recursive_reducer& other, tbb::split)
        : vmapped_(other.vmapped_),
          msgs_(other.msgs_),
          args_tuple_(other.args_tuple_) {}

    /**
     * Compute the value of `ReduceFunction` over the range defined by r
     * and accumulate it in member variable sum_. This function may
     * be called multiple times per object instantiation (so the sum_
     * must be accumulated, not just assigned).
     *
     * @param r Range over which to compute `ReduceFunction`
     */
    void operator()(const tbb::blocked_range<size_t>& r) {
      if (r.empty()) {
        return;
      }

      std::decay_t<Vec> sub_slice;
      sub_slice.reserve(r.size());
      for (size_t i = r.begin(); i < r.end(); ++i) {
        sub_slice.emplace_back(vmapped_[i]);
      }

      sum_ += apply(
          [&](auto&&... args) {
            return ReduceFunction()(r.begin(), r.end() - 1, sub_slice, msgs_,
                                    args...);
          },
          this->args_tuple_);
    }

    /**
     * Join reducers. Accumulate the value (sum_) of the other reducer.
     *
     * @param child Another partial sum
     */
    void join(const recursive_reducer& child) { this->sum_ += child.sum_; }
  };

  /**
   * Call an instance of the function `ReduceFunction` on every element
   * of an input sequence and sum these terms.
   *
   * This specialization is parallelized using tbb and works only for
   * arithmetic types.
   *
   * An instance, f, of `ReduceFunction` should have the signature:
   *   double f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
   *            Args&&... args)
   *
   * `ReduceFunction` must be default constructible without any arguments
   *
   * Each call to `ReduceFunction` is responsible for computing the
   * start through end (inclusive) terms of the overall sum. All args are
   * passed from this function through to the `ReduceFunction` instances.
   * However, only the start through end (inclusive) elements of the vmapped
   * argument are passed to the `ReduceFunction` instances (as the
   * `vmapped_subset` argument).
   *
   * This function distributes computation of the desired sum
   * over multiple threads by coordinating calls to `ReduceFunction`
   * instances.
   *
   * If auto partitioning is true, break work into pieces automatically,
   * taking grainsize as a recommended work size (this process
   * is not deterministic). If false, break work deterministically
   * into pieces smaller than or equal to grainsize. The execution
   * order is non-deterministic.
   *
   * grainsize must be greater than or equal to 1
   *
   * @param vmapped Sliced arguments used only in some sum terms
   * @param auto_partitioning Work partitioning style
   * @param grainsize Suggested grainsize for tbb
   * @param[in, out] msgs The print stream for warning messages
   * @param args Shared arguments used in every sum term
   * @return Summation of all terms
   */
  ReturnType operator()(Vec&& vmapped, bool auto_partitioning, int grainsize,
                        std::ostream* msgs, Args&&... args) const {
    const std::size_t num_jobs = vmapped.size();
    if (num_jobs == 0) {
      return 0.0;
    }
    recursive_reducer worker(std::forward<Vec>(vmapped), msgs,
                             std::forward<Args>(args)...);

    if (auto_partitioning) {
      tbb::parallel_reduce(
          tbb::blocked_range<std::size_t>(0, num_jobs, grainsize), worker);
    } else {
      tbb::simple_partitioner partitioner;
      tbb::parallel_deterministic_reduce(
          tbb::blocked_range<std::size_t>(0, num_jobs, grainsize), worker,
          partitioner);
    }

    return worker.sum_;
  }
};

}  // namespace internal

/**
 * Call an instance of the function `ReduceFunction` on every element
 * of an input sequence and sum these terms.
 *
 * This defers to reduce_sum_impl for the appropriate implementation
 *
 * An instance, f, of `ReduceFunction` should have the signature:
 *   T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
 *       Args&&... args)
 *
 * `ReduceFunction` must be default constructible without any arguments
 *
 * grainsize must be greater than or equal to 1
 *
 * @tparam ReduceFunction Type of reducer function
 * @tparam Vec Type of sliced argument
 * @tparam Args Types of shared arguments
 * @param vmapped Sliced arguments used only in some sum terms
 * @param grainsize Suggested grainsize for tbb
 * @param[in, out] msgs The print stream for warning messages
 * @param args Shared arguments used in every sum term
 * @return Sum of terms
 */
template <typename ReduceFunction, typename Vec,
          typename = require_vector_like_t<Vec>, typename... Args>
auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs,
                Args&&... args) {
  using return_type = return_type_t<Vec, Args...>;

  check_positive("reduce_sum", "grainsize", grainsize);

  return internal::reduce_sum_impl<ReduceFunction, void, return_type, Vec,
                                   Args...>()(std::forward<Vec>(vmapped), true,
                                              grainsize, msgs,
                                              std::forward<Args>(args)...);
}

}  // namespace math
}  // namespace stan

#endif
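The serial branch of `reduce_sum_impl` above walks the input in blocks and passes inclusive `start`/`end` indices to each `ReduceFunction` call. The following is a minimal standalone sketch of that index arithmetic, assuming only the standard library. `serial_chunks`, `slice_sum`, and `chunked_sum` are hypothetical names introduced for illustration and are not part of Stan Math; the `msgs` and shared arguments are left out to keep it short.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper, not part of Stan Math: lists the inclusive
// [start, end] index pairs produced by the serial loop's arithmetic.
inline std::vector<std::pair<std::size_t, std::size_t>> serial_chunks(
    std::size_t num_terms, int grainsize) {
  std::vector<std::pair<std::size_t, std::size_t>> chunks;
  if (num_terms == 0 || grainsize < 1) {
    return chunks;  // nothing to do, or an invalid grainsize
  }
  const std::size_t grain = static_cast<std::size_t>(grainsize);
  const std::size_t num_chunks = (num_terms + grain - 1) / grain;  // ceiling
  for (std::size_t i = 0; i < num_chunks; ++i) {
    const std::size_t start = i * grain;
    const std::size_t end = std::min((i + 1) * grain, num_terms) - 1;
    chunks.emplace_back(start, end);
  }
  return chunks;
}

// Hypothetical reducer with the documented call shape, minus msgs and the
// shared arguments: it sums the slice it is handed.
struct slice_sum {
  double operator()(std::size_t /* start */, std::size_t /* end */,
                    const std::vector<double>& slice) const {
    double total = 0.0;
    for (double x : slice) {
      total += x;
    }
    return total;
  }
};

// Copies each chunk into its own slice and adds up the partial results.
template <typename ReduceFunction>
double chunked_sum(const std::vector<double>& terms, int grainsize) {
  double sum = 0.0;
  for (const auto& chunk : serial_chunks(terms.size(), grainsize)) {
    const std::vector<double> slice(terms.begin() + chunk.first,
                                    terms.begin() + chunk.second + 1);
    sum += ReduceFunction()(chunk.first, chunk.second, slice);
  }
  return sum;
}
```

With 10 terms and a grainsize of 4 this gives the index pairs (0, 3), (4, 7), (8, 9): every full block holds exactly grainsize terms, and only the last one can be shorter.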
@@ -0,0 +1,57 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_STATIC_HPP
#define STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_STATIC_HPP

#include <stan/math/prim/meta.hpp>

#include <tbb/task_arena.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>

#include <tuple>
#include <vector>

namespace stan {
namespace math {

/**
 * Call an instance of the function `ReduceFunction` on every element
 * of an input sequence and sum these terms.
 *
 * This defers to reduce_sum_impl with auto partitioning disabled, so the
 * work is split deterministically into pieces no larger than grainsize.
 *
 * An instance, f, of `ReduceFunction` should have the signature:
 *   T f(int start, int end, Vec&& vmapped_subset, std::ostream* msgs,
 *       Args&&... args)
 *
 * `ReduceFunction` must be default constructible without any arguments
 *
 * grainsize must be greater than or equal to 1
 *
 * @tparam ReduceFunction Type of reducer function
 * @tparam Vec Type of sliced argument
 * @tparam Args Types of shared arguments
 * @param vmapped Sliced arguments used only in some sum terms
 * @param grainsize Maximum number of terms per `ReduceFunction` call
 * @param[in, out] msgs The print stream for warning messages
 * @param args Shared arguments used in every sum term
 * @return Sum of terms
 */
template <typename ReduceFunction, typename Vec,
          typename = require_vector_like_t<Vec>, typename... Args>
auto reduce_sum_static(Vec&& vmapped, int grainsize, std::ostream* msgs,
                       Args&&... args) {
  using return_type = return_type_t<Vec, Args...>;

  check_positive("reduce_sum_static", "grainsize", grainsize);

  return internal::reduce_sum_impl<ReduceFunction, void, return_type, Vec,
                                   Args...>()(std::forward<Vec>(vmapped), false,
                                              grainsize, msgs,
                                              std::forward<Args>(args)...);
}

}  // namespace math
}  // namespace stan

#endif
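For arithmetic types, `reduce_sum_static` reaches the `tbb::simple_partitioner` branch, which the doc comment describes as splitting work deterministically into pieces no larger than grainsize. The sketch below imitates that splitting with the standard library only. It assumes the halving rule TBB documents for `tbb::blocked_range` (a range is divisible while it holds more than grainsize elements, and it is split at its midpoint). `split_range` is a hypothetical name, and nothing here calls TBB itself.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for the leaf ranges a simple partitioner produces:
// a half-open range [begin, end) is halved while it holds more than
// `grainsize` elements. Leaves are appended in left-to-right order.
inline void split_range(
    std::size_t begin, std::size_t end, std::size_t grainsize,
    std::vector<std::pair<std::size_t, std::size_t>>& leaves) {
  if (begin >= end) {
    return;  // empty range, nothing to record
  }
  // A single element can never be split, whatever the grainsize.
  if (end - begin <= grainsize || end - begin == 1) {
    leaves.emplace_back(begin, end);
    return;
  }
  const std::size_t middle = begin + (end - begin) / 2;
  split_range(begin, middle, grainsize, leaves);
  split_range(middle, end, grainsize, leaves);
}
```

For 10 terms and a grainsize of 4 the leaves are [0, 2), [2, 5), [5, 7), [7, 10). The split points depend only on the input size and the grainsize, which is what makes the partitioning repeatable, but the pieces are not all of size grainsize as they are in the serial loop.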
@bbbales2 check out the link here
So the answer to "can this be a function" is probably yes; however, they say the imperative style does less copying. We could always try the functional style and benchmark the diff. idk anything about tbb's internals here, so I can't say whether the performance difference would be noticeable.
I went with the imperative form for the promised additional efficiency... For my taste this doesn't need benchmarking, given that we do what is supposed to be faster, and style-wise it's fine with me to go with a struct.
Aaah, I meant reduce_sum_impl, not recursive_reducer.
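To make the two styles compared in this thread concrete, here is a rough sketch of the shape each one takes. It is a serial, standard-library-only illustration and does not use TBB: `split_tag` stands in for `tbb::split`, and `sum_body`, `mini_reduce`, and `mini_reduce_functional` are hypothetical names. It says nothing about which form copies less or runs faster.

```cpp
#include <cstddef>
#include <vector>

struct split_tag {};  // hypothetical stand-in for tbb::split

// Imperative-form body: state lives in the object, results merge via join.
struct sum_body {
  const std::vector<double>* data_;
  double sum_ = 0.0;

  explicit sum_body(const std::vector<double>& data) : data_(&data) {}

  // Splitting constructor: shares the input but starts a fresh partial sum.
  sum_body(sum_body& other, split_tag) : data_(other.data_) {}

  // Accumulates with +=, because one body may be handed several subranges.
  void operator()(std::size_t begin, std::size_t end) {
    for (std::size_t i = begin; i < end; ++i) {
      sum_ += (*data_)[i];
    }
  }

  void join(const sum_body& rhs) { sum_ += rhs.sum_; }
};

// Hypothetical serial driver with the same split / run / join sequence.
template <typename Body>
void mini_reduce(std::size_t begin, std::size_t end, std::size_t grainsize,
                 Body& body) {
  if (begin >= end) {
    return;
  }
  if (end - begin <= grainsize || end - begin == 1) {
    body(begin, end);
    return;
  }
  const std::size_t middle = begin + (end - begin) / 2;
  Body right(body, split_tag{});
  mini_reduce(begin, middle, grainsize, body);
  mini_reduce(middle, end, grainsize, right);
  body.join(right);
}

// Functional-form counterpart: each subrange returns a value, and a separate
// join function combines two values. No body object is needed.
template <typename RangeFn, typename JoinFn>
double mini_reduce_functional(std::size_t begin, std::size_t end,
                              std::size_t grainsize, double identity,
                              RangeFn range_fn, JoinFn join_fn) {
  if (begin >= end) {
    return identity;
  }
  if (end - begin <= grainsize || end - begin == 1) {
    return range_fn(begin, end, identity);
  }
  const std::size_t middle = begin + (end - begin) / 2;
  return join_fn(mini_reduce_functional(begin, middle, grainsize, identity,
                                        range_fn, join_fn),
                 mini_reduce_functional(middle, end, grainsize, identity,
                                        range_fn, join_fn));
}
```

In the imperative shape the struct carries the inputs and the running sum, which mirrors `recursive_reducer` in the diff. In the functional shape the same work is expressed with two callables and a returned value.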