Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7a955d9
0;10;1c# This is a combination of 14 commits.
drezap Apr 23, 2026
a333398
change return to referenced argument
drezap Apr 26, 2026
794e2a5
intermediate commit playing with pointers
drezap Apr 27, 2026
bee0559
ok, numerical tests 1,..,10 pass _threading_
drezap Apr 27, 2026
99114d8
remove dead code
drezap Apr 27, 2026
2db1474
begin benchmarks
drezap Apr 27, 2026
0d66c6f
try no threading
drezap Apr 27, 2026
73c3c76
scale to 10mm obs and 2^17 threads
drezap Apr 27, 2026
c539a38
add only 10 test for multithreading
drezap Apr 28, 2026
442e814
add some unthreaded tests, for varying N and numerical value of exp(10)
drezap Apr 29, 2026
e9583fa
scale N tests
drezap Apr 29, 2026
c1f263f
unthreaded tests compile
drezap Apr 29, 2026
5b15382
remove print statement
drezap Apr 29, 2026
b331dd8
Merge commit '105bfcc395c1ab824dcb588324dd57724a1cf527' into HEAD
yashikno Apr 29, 2026
b1c8b69
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 29, 2026
84675ab
fix unit test names, rusty, sorry
drezap Apr 30, 2026
aec6db9
Merge branch 'feature/issue-3311-test-thread-tbb-exp' of github.com:d…
drezap Apr 30, 2026
94db162
fix return to satisfy compiler
drezap Apr 30, 2026
3fcf593
change exp to std::exp for numerical accuracy
drezap Apr 30, 2026
4b08e16
investigate drift in tests
drezap May 1, 2026
d143a04
ifdef0endif don't compile mix tests, they're only supporting complex
drezap May 1, 2026
6837a52
change investigate drift naming conventions
drezap May 1, 2026
7ca2f6d
perfect forwarding in parallelizing class
drezap May 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 151 additions & 1 deletion stan/math/prim/fun/exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,156 @@
#include <complex>
#include <limits>

#ifdef STAN_THREADS // threaded block
#include <stan/math/prim/core.hpp>
#include <tbb/parallel_for.h>
#include <tbb/blocked_range.h>
#include <stan/math/prim/core/init_threadpool_tbb.hpp>

namespace stan {
namespace math {

/**
* Return the natural (base e) exponentiation of the specified
* complex argument.
*
* @tparam V `Arithmetic` type
* @param x input
* @return natural exponentiation of specified number
*/
template <typename T, require_arithmetic_t<T>* = nullptr>
inline auto exp(T&& x) {
return std::exp(x);
}

/**
* Return the natural (base e) complex exponentiation of the specified
* complex argument.
*
* @tparam V `complex<Arithmetic>` type
* @param x complex number
* @return natural exponentiation of specified complex number
* @see documentation for `std::complex` for boundary condition and
* branch cut details
*/
template <typename T, require_complex_bt<std::is_arithmetic, T>* = nullptr>
inline auto exp(T&& x) {
return std::exp(x);
}

/**
* Structure to wrap `exp()` so that it can be
* vectorized.
*/
struct exp_fun {
/**
* Return the exponential of the specified scalar argument.
*
* @tparam T type of argument
* @param[in] x argument
* @return Exponential of argument.
*/
template <typename T>
static inline auto fun(T&& x) {
return exp(std::forward<T>(x));
}
};

// implement a class so we can parallelize a for loop of evaluating
// exp
template <typename Container>
class apply_exp {
Container const my_a;

public:
Container operator()(const tbb::blocked_range<std::size_t>& r) const {
Container a = my_a;
Container a_out = my_a;
for (std::size_t i = r.begin(); i != r.end(); ++i) {
a_out[i] = std::exp(a[i]);
}
return a;
}
apply_exp<Container>(Container&& a) : my_a(a) {}
};

/**
* Return the elementwise `exp()` of the specified argument,
* which may be a scalar or any Stan container of numeric scalars.
* The return type is the same as the argument type.
*
* @tparam Container type of container
* @param[in] x container
* @return Elementwise application of exponentiation to the argument.
*/
template <typename Container, require_ad_container_t<Container>* = nullptr>
inline auto exp(Container&& x) {
return apply_scalar_unary<exp_fun, Container>::apply(
std::forward<Container>(x));
}

/**x
* Version of `exp()` that accepts std::vectors, Eigen Matrix/Array objects
* or expressions, and containers of these.
*
* @tparam Container Type of x
* @param x Container
* @return Elementwise application of exponentiation to the argument.
*/
// experimental function
template <typename Container,
require_container_bt<std::is_arithmetic, Container>* = nullptr>
inline auto exp(Container&& x) {
std::size_t N = x.size();
tbb::parallel_for(tbb::blocked_range<size_t>(0, N),
typename apply_exp<Container>::apply_exp(x));
return x;
}

namespace internal {
/**
* Return the natural (base e) complex exponentiation of the specified
* complex argument.
*
* @tparam V value type (must be Stan autodiff type)
* @param z complex number
* @return natural exponentiation of specified complex number
* @see documentation for `std::complex` for boundary condition and
* branch cut details
*/
template <typename V>
inline std::complex<V> complex_exp(const std::complex<V>& z) {
if (is_inf(z.real()) && z.real() > 0) {
if (is_nan(z.imag()) || z.imag() == 0) {
// (+inf, nan), (+inf, 0)
return z;
} else if (is_inf(z.imag()) && z.imag() > 0) {
// (+inf, +inf)
return {z.real(), std::numeric_limits<double>::quiet_NaN()};
} else if (is_inf(z.imag()) && z.imag() < 0) {
// (+inf, -inf)
return {std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN()};
}
}
if (is_inf(z.real()) && z.real() < 0
&& (is_nan(z.imag()) || is_inf(z.imag()))) {
// (-inf, nan), (-inf, -inf), (-inf, inf)
return {0, 0};
}
if (is_nan(z.real()) && z.imag() == -0.0) {
// (nan, -0)
return z;
}
V exp_re = exp(z.real());
return {exp_re * cos(z.imag()), exp_re * sin(z.imag())};
}
} // namespace internal
} // namespace math
} // namespace stan

#else // unthreaded code

namespace stan {
namespace math {

Expand Down Expand Up @@ -129,5 +279,5 @@ inline std::complex<V> complex_exp(const std::complex<V>& z) {
} // namespace internal
} // namespace math
} // namespace stan

#endif
#endif
2 changes: 2 additions & 0 deletions test/unit/math/mix/fun/exp_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <test/unit/math/test_ad.hpp>

#if 0
TEST(mathMixMatFun, exp) {
auto f = [](const auto& x) {
using stan::math::exp;
Expand All @@ -18,3 +19,4 @@ TEST(mathMixMatFun, exp) {
stan::test::expect_ad_vector_matvar(f, stan::math::to_vector(com_args));
stan::test::expect_ad_vector_matvar(f, stan::math::to_vector(args));
}
#endif
Loading
Loading