Skip to content

Commit

Permalink
Variadic autodiff test framework checkpoint. About to make some major…
Browse files Browse the repository at this point in the history
… changes so I just wanted to save a copy of this (Issue #993)
  • Loading branch information
bbbales2 committed Aug 22, 2018
1 parent 8d2d46b commit 34a163c
Show file tree
Hide file tree
Showing 9 changed files with 318 additions and 69 deletions.
69 changes: 58 additions & 11 deletions stan/math/prim/arr/fun/promote_double_to_T.hpp
@@ -1,24 +1,71 @@
#ifndef STAN_MATH_PRIM_ARR_FUN_PROMOTE_DOUBLE_TO_T_HPP
#define STAN_MATH_PRIM_ARR_FUN_PROMOTE_DOUBLE_TO_T_HPP

#include <tuple>
#include <vector>
#include <iostream>

namespace stan {
namespace math {

template <typename T, typename R>
inline const std::vector<R>& promote_double_to_T(const std::vector<R>& x) {
return x;
template <typename T, typename... Pargs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output) {
return output;
}

template <typename T>
inline std::vector<T> promote_double_to_T(const std::vector<double>& x) {
std::vector<T> out;
out.reserve(x.size());
for (int i = 0; i < x.size(); ++i) {
out.push_back(x[i]);
}
return out;
template <typename T, typename R, typename... Pargs, typename... Targs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output,
const std::vector<R>& arg,
const Targs&... args);
template <typename T, typename... Pargs, typename... Targs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output,
const double& arg, const Targs&... args);
template <typename T, typename R, typename... Pargs, typename... Targs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output, const R& arg,
const Targs&... args);

template <typename T, typename... Pargs, typename... Targs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output,
const std::vector<double>& arg,
const Targs&... args) {
std::vector<T> output_arg;
output_arg.reserve(arg.size());
for (int i = 0; i < arg.size(); ++i)
output_arg.push_back(arg[i]);
return promote_double_to_T_impl_impl<T>(
std::tuple_cat(output, std::make_tuple(output_arg)), args...);
}

template <typename T, typename R, typename... Pargs, typename... Targs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output,
const std::vector<R>& arg,
const Targs&... args) {
return promote_double_to_T_impl_impl<T>(output, args...);
}

template <typename T, typename... Pargs, typename... Targs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output,
const double& arg, const Targs&... args) {
return promote_double_to_T_impl_impl<T>(
std::tuple_cat(output, std::make_tuple(T(arg))), args...);
}

template <typename T, typename R, typename... Pargs, typename... Targs>
auto promote_double_to_T_impl_impl(std::tuple<Pargs...> output, const R& arg,
const Targs&... args) {
return promote_double_to_T_impl_impl<T>(output, args...);
}

template <typename T, std::size_t... I, typename... Targs>
auto promote_double_to_T_impl(std::index_sequence<I...>,
const std::tuple<Targs...>& args) {
return promote_double_to_T_impl_impl<T>(std::tuple(), std::get<I>(args)...);
}

template <typename T, typename... Targs>
auto promote_double_to_T(const std::tuple<Targs...>& args) {
return promote_double_to_T_impl<T>(
std::make_index_sequence<sizeof...(Targs)>{}, args);
}

} // namespace math
Expand Down
55 changes: 32 additions & 23 deletions stan/math/prim/arr/fun/variable_adapter.hpp
Expand Up @@ -9,81 +9,85 @@
namespace stan {
namespace math {

template <typename... Targs>
template <typename T, typename... Targs>
class variable_adapter {
public:
std::tuple<Targs...> args_;

size_t size_;

protected:
template <typename T, typename... Pargs>
template <typename... Pargs>
size_t count_memory(size_t count, const std::vector<T>& x,
const Pargs&... args) {
return count_memory(count + x.size(), args...);
}

template <typename... Pargs>
size_t count_memory(size_t count, const std::vector<int>& x,
template <typename R, typename... Pargs>
size_t count_memory(size_t count, const std::vector<R>& x,
const Pargs&... args) {
return count_memory(count, args...);
}

template <typename T, typename... Pargs>
template <typename... Pargs>
size_t count_memory(size_t count, const T& x, const Pargs&... args) {
return count_memory(count + 1, args...);
}

template <typename... Pargs>
size_t count_memory(size_t count, const int& x, const Pargs&... args) {
template <typename R, typename... Pargs>
size_t count_memory(size_t count, const R& x, const Pargs&... args) {
return count_memory(count, args...);
}

size_t count_memory(size_t count) { return count; }

template <typename T>
T& get(int i, std::vector<T>& arg) {
return arg[i];
T& get(size_t i, std::vector<T>& arg) { return arg[i]; }

template <typename R>
T& get(size_t i, std::vector<R>& arg) {
throw std::runtime_error("This should not be reachable");
}

template <typename T>
T& get(int i, T& arg) {
return arg;
T& get(size_t i, T& arg) { return arg; }

template <typename R>
T& get(size_t i, R& arg) {
throw std::runtime_error("This should not be reachable");
}

template <typename T, typename... Pargs>
T& get(int i, std::vector<T>& arg, Pargs&... args) {
template <typename... Pargs>
T& get(size_t i, std::vector<T>& arg, Pargs&... args) {
if (i < arg.size())
return arg[i];
else
return get(i - arg.size(), args...);
}

template <typename... Pargs>
auto& get(int i, std::vector<int>& arg, Pargs&... args) {
template <typename R, typename... Pargs>
auto& get(size_t i, std::vector<R>& arg, Pargs&... args) {
return get(i, args...);
}

template <typename T, typename... Pargs>
T& get(int i, T& arg, Pargs&... args) {
template <typename... Pargs>
T& get(size_t i, T& arg, Pargs&... args) {
if (i == 0)
return arg;
else
return get(i - 1, args...);
}

template <typename... Pargs>
auto& get(int i, int& arg, Pargs&... args) {
template <typename R, typename... Pargs>
auto& get(size_t i, R& arg, Pargs&... args) {
return get(i, args...);
}

public:
variable_adapter(const Targs&... args)
: args_(std::make_tuple(args...)), size_(count_memory(0, args...)) {}

int size() const { return size_; }
size_t size() const { return size_; }

auto& operator()(int i) {
auto& operator()(size_t i) {
check_less("variable_adapter::operator()", "i", i, size_);

return std::apply(
Expand All @@ -92,6 +96,11 @@ class variable_adapter {
}
};

template <typename T, typename... Targs>
auto variable_adapter_factory(const Targs&... args) {
return variable_adapter<T, Targs...>(args...);
}

} // namespace math
} // namespace stan
#endif
1 change: 0 additions & 1 deletion stan/math/prim/scal.hpp
Expand Up @@ -159,7 +159,6 @@
#include <stan/math/prim/scal/fun/primitive_value.hpp>
#include <stan/math/prim/scal/fun/prob_constrain.hpp>
#include <stan/math/prim/scal/fun/prob_free.hpp>
#include <stan/math/prim/scal/fun/promote_double_to_T.hpp>
#include <stan/math/prim/scal/fun/promote_elements.hpp>
#include <stan/math/prim/scal/fun/promote_scalar.hpp>
#include <stan/math/prim/scal/fun/promote_scalar_type.hpp>
Expand Down
43 changes: 43 additions & 0 deletions stan/math/prim/scal/fun/call_all_argument_combos.hpp
@@ -0,0 +1,43 @@
#ifndef STAN_MATH_PRIM_SCAL_FUN_CALL_ALL_ARGUMENT_COMBOS_HPP
#define STAN_MATH_PRIM_SCAL_FUN_CALL_ALL_ARGUMENT_COMBOS_HPP

#include <tuple>

namespace stan {
namespace math {

template <typename F>
auto call_all_argument_combos(F f) {
return std::make_tuple(f());
}

template <typename F, typename... Ts_first_arg, std::size_t... I,
typename... T_tail>
auto call_all_argument_combos_impl(
F f, const std::tuple<Ts_first_arg...>& first_arg_tuple,
std::index_sequence<I...>, const T_tail&... tail);

template <typename F, typename... Ts_first_arg, typename... T_tail>
auto call_all_argument_combos(
F f, const std::tuple<Ts_first_arg...>& first_arg_tuple,
const T_tail&... tail) {
return call_all_argument_combos_impl(
f, first_arg_tuple, std::make_index_sequence<sizeof...(Ts_first_arg)>{},
tail...);
}

template <typename F, typename... Ts_first_arg, std::size_t... I,
typename... T_tail>
auto call_all_argument_combos_impl(
F f, const std::tuple<Ts_first_arg...>& first_arg_tuple,
std::index_sequence<I...>, const T_tail&... tail) {
return std::tuple_cat(call_all_argument_combos(
[&first_arg_tuple, &f](const auto&... inner_args) {
return f(std::get<I>(first_arg_tuple), inner_args...);
},
tail...)...);
}

} // namespace math
} // namespace stan
#endif
21 changes: 0 additions & 21 deletions stan/math/prim/scal/fun/promote_double_to_T.hpp

This file was deleted.

59 changes: 51 additions & 8 deletions test/unit/math/prim/arr/fun/promote_double_to_T_test.cpp
Expand Up @@ -3,22 +3,65 @@
#include <type_traits>
#include <vector>

TEST(MathFunctions, promote_double_to_T_int) {
int x;
EXPECT_TRUE((std::is_same<std::tuple<>,
decltype(stan::math::promote_double_to_T<double>(
std::make_tuple(x)))>::value));
EXPECT_FALSE((std::is_same<std::tuple<double>,
decltype(stan::math::promote_double_to_T<double>(
std::make_tuple(x)))>::value));
}

TEST(MathFunctions, promote_double_to_T_double) {
double x;
EXPECT_TRUE((std::is_same<std::tuple<float>,
decltype(stan::math::promote_double_to_T<float>(
std::make_tuple(x)))>::value));
EXPECT_FALSE((std::is_same<std::tuple<>,
decltype(stan::math::promote_double_to_T<float>(
std::make_tuple(x)))>::value));
}

TEST(MathFunctions, promote_double_to_T_mix) {
int xi;
double xd;
auto a = stan::math::promote_double_to_T<float>(std::make_tuple(xd, xi, xd));
EXPECT_TRUE((std::is_same<std::tuple<float, float>,
decltype(stan::math::promote_double_to_T<float>(
std::make_tuple(xd, xi, xd)))>::value));
EXPECT_TRUE((std::is_same<std::tuple<float>,
decltype(stan::math::promote_double_to_T<float>(
std::make_tuple(xd, xi)))>::value));
}

TEST(MathFunctions, promote_double_to_T_std_vector_int) {
std::vector<int> x(3);
EXPECT_TRUE((std::is_same<const std::vector<int>&,
EXPECT_TRUE((std::is_same<std::tuple<>,
decltype(stan::math::promote_double_to_T<double>(
x))>::value));
EXPECT_FALSE((std::is_same<std::vector<double>,
std::make_tuple(x)))>::value));
EXPECT_FALSE((std::is_same<std::tuple<std::vector<double> >,
decltype(stan::math::promote_double_to_T<double>(
x))>::value));
std::make_tuple(x)))>::value));
}

TEST(MathFunctions, promote_double_to_T_std_vector_double) {
std::vector<double> x(3);
EXPECT_TRUE((std::is_same<std::vector<float>,
EXPECT_TRUE((std::is_same<std::tuple<std::vector<float> >,
decltype(stan::math::promote_double_to_T<float>(
x))>::value));
EXPECT_FALSE((std::is_same<const std::vector<double>&,
std::make_tuple(x)))>::value));
EXPECT_FALSE((std::is_same<std::tuple<>,
decltype(stan::math::promote_double_to_T<float>(
x))>::value));
std::make_tuple(x)))>::value));
}

TEST(MathFunctions, promote_double_to_T_std_vector_mix) {
std::vector<int> xi(3);
std::vector<double> xd(3);
EXPECT_TRUE((std::is_same<std::tuple<std::vector<float>, std::vector<float> >,
decltype(stan::math::promote_double_to_T<float>(
std::make_tuple(xd, xi, xd)))>::value));
EXPECT_TRUE((std::is_same<std::tuple<std::vector<float> >,
decltype(stan::math::promote_double_to_T<float>(
std::make_tuple(xd, xi)))>::value));
}
10 changes: 5 additions & 5 deletions test/unit/math/prim/arr/fun/variable_adapter_test.cpp
Expand Up @@ -7,7 +7,7 @@
TEST(MathFunctions, adapt_double) {
double x = 1.0;

auto a = stan::math::variable_adapter<double>(x);
auto a = stan::math::variable_adapter_factory<double>(x);

EXPECT_EQ(1, a.size());
EXPECT_FLOAT_EQ(x, a(0));
Expand All @@ -19,15 +19,15 @@ TEST(MathFunctions, adapt_double) {
TEST(MathFunctions, adapt_int) {
int x = 1;

auto a = stan::math::variable_adapter(x);
auto a = stan::math::variable_adapter_factory<double>(x);

EXPECT_EQ(0, a.size());
}

TEST(MathFunctions, adapt_std_vector_double) {
std::vector<double> x = {{1.0, 2.0}};

auto a = stan::math::variable_adapter(x);
auto a = stan::math::variable_adapter_factory<double>(x);

EXPECT_EQ(x.size(), a.size());
for (size_t i = 0; i < x.size(); ++i)
Expand All @@ -37,7 +37,7 @@ TEST(MathFunctions, adapt_std_vector_double) {
TEST(MathFunctions, adapt_std_vector_int) {
std::vector<int> x = {{1, 2}};

auto a = stan::math::variable_adapter(x);
auto a = stan::math::variable_adapter_factory<double>(x);

EXPECT_EQ(0, a.size());
}
Expand All @@ -48,7 +48,7 @@ TEST(MathFunctions, adapt_all_types) {
std::vector<double> xdv = {{1.0, 2.0}};
std::vector<int> xiv = {{1, 2}};

auto a = stan::math::variable_adapter(xd, xi, xdv, xiv, xd);
auto a = stan::math::variable_adapter_factory<double>(xd, xi, xdv, xiv, xd);

EXPECT_EQ(2 + xdv.size(), a.size());
EXPECT_FLOAT_EQ(xd, a(0));
Expand Down

0 comments on commit 34a163c

Please sign in to comment.