diff --git a/stan/math/prim/fun.hpp b/stan/math/prim/fun.hpp index 59e087ec277..80ee52d4fe7 100644 --- a/stan/math/prim/fun.hpp +++ b/stan/math/prim/fun.hpp @@ -334,6 +334,10 @@ #include #include #include +#include +#include +#include +#include #include #include #include diff --git a/stan/math/prim/fun/simplex_constrain.hpp b/stan/math/prim/fun/simplex_constrain.hpp index 4a7342620a1..cd10e93d587 100644 --- a/stan/math/prim/fun/simplex_constrain.hpp +++ b/stan/math/prim/fun/simplex_constrain.hpp @@ -24,9 +24,9 @@ namespace math { * @param y Free vector input of dimensionality K - 1. * @return Simplex of dimensionality K. */ -template * = nullptr, +template * = nullptr, require_not_st_var* = nullptr> -inline auto simplex_constrain(const Vec& y) { +inline plain_type_t simplex_constrain(const Vec& y) { // cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian using std::log; using T = value_type_t; @@ -56,9 +56,10 @@ inline auto simplex_constrain(const Vec& y) { * @param lp Log probability reference to increment. * @return Simplex of dimensionality K. */ -template * = nullptr, +template * = nullptr, require_not_st_var* = nullptr> -inline auto simplex_constrain(const Vec& y, value_type_t& lp) { +inline plain_type_t simplex_constrain(const Vec& y, + value_type_t& lp) { using Eigen::Dynamic; using Eigen::Matrix; using std::log; @@ -98,7 +99,8 @@ inline auto simplex_constrain(const Vec& y, value_type_t& lp) { * @return simplex of dimensionality one greater than `y` */ template * = nullptr> -auto simplex_constrain(const Vec& y, return_type_t& lp) { +inline plain_type_t simplex_constrain(const Vec& y, + return_type_t& lp) { if (Jacobian) { return simplex_constrain(y, lp); } else { diff --git a/stan/math/prim/fun/simplex_free.hpp b/stan/math/prim/fun/simplex_free.hpp index ad984d1070a..39487a0cae0 100644 --- a/stan/math/prim/fun/simplex_free.hpp +++ b/stan/math/prim/fun/simplex_free.hpp @@ -26,14 +26,14 @@ namespace math { * the simplex. * @throw std::domain_error if x is not a valid simplex */ -template * = nullptr> -auto simplex_free(const Vec& x) { +template * = nullptr> +inline plain_type_t simplex_free(const Vec& x) { using std::log; using T = value_type_t; const auto& x_ref = to_ref(x); check_simplex("stan::math::simplex_free", "Simplex variable", x_ref); - int Km1 = x_ref.size() - 1; + Eigen::Index Km1 = x_ref.size() - 1; plain_type_t y(Km1); T stick_len = x_ref.coeff(Km1); for (Eigen::Index k = Km1; --k >= 0;) { diff --git a/stan/math/prim/fun/stochastic_column_constrain.hpp b/stan/math/prim/fun/stochastic_column_constrain.hpp new file mode 100644 index 00000000000..8ed66883579 --- /dev/null +++ b/stan/math/prim/fun/stochastic_column_constrain.hpp @@ -0,0 +1,114 @@ +#ifndef STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP +#define STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Return a column stochastic matrix. + * + * The transform is based on a centered stick-breaking process. + * + * @tparam Mat type of the Matrix + * @param y Free Matrix input of dimensionality (K - 1, M) + * @return Matrix with simplex columns of dimensionality (K, M) + */ +template * = nullptr, + require_not_st_var* = nullptr> +inline plain_type_t stochastic_column_constrain(const Mat& y) { + auto&& y_ref = to_ref(y); + const Eigen::Index M = y_ref.cols(); + plain_type_t ret(y_ref.rows() + 1, M); + for (Eigen::Index i = 0; i < M; ++i) { + ret.col(i) = simplex_constrain(y_ref.col(i)); + } + return ret; +} + +/** + * Return a column stochastic matrix + * and increment the specified log probability reference with + * the log absolute Jacobian determinant of the transform. + * + * The simplex transform is defined through a centered + * stick-breaking process. + * + * @tparam Mat type of the Matrix + * @param y Free Matrix input of dimensionality (K - 1, M) + * @param lp Log probability reference to increment. + * @return Matrix with stochastic columns of dimensionality (K, M) + */ +template * = nullptr, + require_not_st_var* = nullptr> +inline plain_type_t stochastic_column_constrain(const Mat& y, + value_type_t& lp) { + auto&& y_ref = to_ref(y); + const Eigen::Index M = y_ref.cols(); + plain_type_t ret(y_ref.rows() + 1, M); + for (Eigen::Index i = 0; i < M; ++i) { + ret.col(i) = simplex_constrain(y_ref.col(i), lp); + } + return ret; +} + +/** + * Return a column stochastic matrix. If the + * `Jacobian` parameter is `true`, the log density accumulator is incremented + * with the log absolute Jacobian determinant of the transform. All of the + * transforms are specified with their Jacobians in the *Stan Reference Manual* + * chapter Constraint Transforms. + * + * @tparam Jacobian if `true`, increment log density accumulator with log + * absolute Jacobian determinant of constraining transform + * @tparam Mat type of the Matrix + * @param y Free Matrix input of dimensionality (K - 1, M). + * @param[in, out] lp log density accumulator + * @return Matrix with simplex columns of dimensionality (K, M). + */ +template * = nullptr> +inline plain_type_t stochastic_column_constrain(const Mat& y, + return_type_t& lp) { + if (Jacobian) { + return stochastic_column_constrain(y, lp); + } else { + return stochastic_column_constrain(y); + } +} + +/** + * Return a vector of column stochastic matrices. If the + * `Jacobian` parameter is `true`, the log density accumulator is incremented + * with the log absolute Jacobian determinant of the transform. All of the + * transforms are specified with their Jacobians in the *Stan Reference Manual* + * chapter Constraint Transforms. + * + * @tparam Jacobian if `true`, increment log density accumulator with log + * absolute Jacobian determinant of constraining transform + * @tparam T A standard vector with inner type inheriting from + * `Eigen::DenseBase` or a `var_value` with inner type inheriting from + * `Eigen::DenseBase` with compile time dynamic rows and dynamic columns + * @param[in] y free vector + * @param[in, out] lp log density accumulator + * @return Standard vector containing matrices with simplex columns of + * dimensionality (K, M). + */ +template * = nullptr> +inline auto stochastic_column_constrain(const T& y, return_type_t& lp) { + return apply_vector_unary::apply(y, [&lp](auto&& v) { + return stochastic_column_constrain(v, lp); + }); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/fun/stochastic_column_free.hpp b/stan/math/prim/fun/stochastic_column_free.hpp new file mode 100644 index 00000000000..b7a69fced25 --- /dev/null +++ b/stan/math/prim/fun/stochastic_column_free.hpp @@ -0,0 +1,48 @@ +#ifndef STAN_MATH_PRIM_FUN_STOCHASTIC_COLUMN_FREE_HPP +#define STAN_MATH_PRIM_FUN_STOCHASTIC_COLUMN_FREE_HPP + +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Return an unconstrained matrix that when transformed produces + * the specified columnwise stochastic matrix. It applies to a stochastic + * matrix of dimensionality (N, K) and produces an unconstrained vector of + * dimensionality (N - 1, K). + * + * @tparam Mat type of the Matrix + * @param y Columnwise stochastic matrix input of dimensionality (N, K) + */ +template * = nullptr, + require_not_st_var* = nullptr> +inline plain_type_t stochastic_column_free(const Mat& y) { + auto&& y_ref = to_ref(y); + const Eigen::Index M = y_ref.cols(); + plain_type_t ret(y_ref.rows() - 1, M); + for (Eigen::Index i = 0; i < M; ++i) { + ret.col(i) = simplex_free(y_ref.col(i)); + } + return ret; +} + +/** + * Overload that untransforms each Eigen matrix in a standard vector. + * + * @tparam T A standard vector with inner type inheriting from + * `Eigen::DenseBase` with compile time dynamic rows and dynamic rows + * @param[in] y vector of columnwise stochastic matrix of size (N, K) + */ +template * = nullptr> +inline auto stochastic_column_free(const T& y) { + return apply_vector_unary::apply( + y, [](auto&& v) { return stochastic_column_free(v); }); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/fun/stochastic_row_constrain.hpp b/stan/math/prim/fun/stochastic_row_constrain.hpp new file mode 100644 index 00000000000..7ea9306a933 --- /dev/null +++ b/stan/math/prim/fun/stochastic_row_constrain.hpp @@ -0,0 +1,127 @@ +#ifndef STAN_MATH_PRIM_FUN_STOCHASTIC_ROW_CONSTRAIN_HPP +#define STAN_MATH_PRIM_FUN_STOCHASTIC_ROW_CONSTRAIN_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Return a row stochastic matrix. + * + * The transform is based on a centered stick-breaking process. + * + * @tparam Mat type of the Matrix + * @param y Free Matrix input of dimensionality (N, K - 1). + * @return Matrix with simplexes along the rows of dimensionality (N, K). + */ +template * = nullptr, + require_not_st_var* = nullptr> +inline plain_type_t stochastic_row_constrain(const Mat& y) { + auto&& y_ref = to_ref(y); + const Eigen::Index N = y_ref.rows(); + int Km1 = y_ref.cols(); + plain_type_t x(N, Km1 + 1); + using eigen_arr = Eigen::Array, -1, 1>; + eigen_arr stick_len = eigen_arr::Constant(N, 1.0); + for (Eigen::Index k = 0; k < Km1; ++k) { + auto z_k = inv_logit(y_ref.array().col(k) - log(Km1 - k)); + x.array().col(k) = stick_len * z_k; + stick_len -= x.array().col(k); + } + x.array().col(Km1) = stick_len; + return x; +} + +/** + * Return a row stochastic matrix. + * The simplex transform is defined through a centered + * stick-breaking process. + * + * @tparam Mat type of the matrix + * @param y Free matrix input of dimensionality (N, K - 1). + * @param lp Log probability reference to increment. + * @return Matrix with simplexes along the rows of dimensionality (N, K). + */ +template * = nullptr, + require_not_st_var* = nullptr> +inline plain_type_t stochastic_row_constrain(const Mat& y, + value_type_t& lp) { + auto&& y_ref = to_ref(y); + const Eigen::Index N = y_ref.rows(); + Eigen::Index Km1 = y_ref.cols(); + plain_type_t x(N, Km1 + 1); + Eigen::Array, -1, 1> stick_len + = Eigen::Array, -1, 1>::Constant(N, 1.0); + for (Eigen::Index k = 0; k < Km1; ++k) { + const auto eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k)); + auto adj_y_k = (y_ref.array().col(k) + eq_share).eval(); + auto z_k = inv_logit(adj_y_k); + x.array().col(k) = stick_len * z_k; + lp += -sum(log1p_exp(adj_y_k)) - sum(log1p_exp(-adj_y_k)) + + sum(log(stick_len)); + stick_len -= x.array().col(k); // equivalently *= (1 - z_k); + } + x.col(Km1).array() = stick_len; + return x; +} + +/** + * Return a row stochastic matrix. + * If the `Jacobian` parameter is `true`, the log density accumulator is + * incremented with the log absolute Jacobian determinant of the transform. All + * of the transforms are specified with their Jacobians in the *Stan Reference + * Manual* chapter Constraint Transforms. + * + * @tparam Jacobian if `true`, increment log density accumulator with log + * absolute Jacobian determinant of constraining transform + * @tparam Mat A type inheriting from `Eigen::DenseBase` or a `var_value` with + * inner type inheriting from `Eigen::DenseBase` with compile time dynamic rows + * and dynamic columns + * @param[in] y free matrix + * @param[in, out] lp log density accumulator + * @return Matrix with simplexes along the rows of dimensionality (N, K). + */ +template * = nullptr> +inline plain_type_t stochastic_row_constrain(const Mat& y, + return_type_t& lp) { + if (Jacobian) { + return stochastic_row_constrain(y, lp); + } else { + return stochastic_row_constrain(y); + } +} + +/** + * Return a row stochastic matrix. + * If the `Jacobian` parameter is `true`, the log density accumulator is + * incremented with the log absolute Jacobian determinant of the transform. All + * of the transforms are specified with their Jacobians in the *Stan Reference + * Manual* chapter Constraint Transforms. + * + * @tparam Jacobian if `true`, increment log density accumulator with log + * absolute Jacobian determinant of constraining transform + * @tparam T A standard vector with inner type inheriting from + * `Eigen::DenseBase` or a `var_value` with inner type inheriting from + * `Eigen::DenseBase` with compile time dynamic rows and dynamic columns + * @param[in] y free vector with matrices of size (N, K - 1) + * @param[in, out] lp log density accumulator + * @return vector of matrices with simplex rows of dimensionality (N, K) + */ +template * = nullptr> +inline auto stochastic_row_constrain(const T& y, return_type_t& lp) { + return apply_vector_unary::apply( + y, [&lp](auto&& v) { return stochastic_row_constrain(v, lp); }); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/fun/stochastic_row_free.hpp b/stan/math/prim/fun/stochastic_row_free.hpp new file mode 100644 index 00000000000..7fa255b9949 --- /dev/null +++ b/stan/math/prim/fun/stochastic_row_free.hpp @@ -0,0 +1,47 @@ +#ifndef STAN_MATH_PRIM_FUN_STOCHASTIC_ROW_FREE_HPP +#define STAN_MATH_PRIM_FUN_STOCHASTIC_ROW_FREE_HPP + +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Return an unconstrained matrix that when transformed produces + * the specified simplex matrix. It applies to a simplex of dimensionality + * (N, K) and produces an unconstrained vector of dimensionality (N, K - 1). + * + * @tparam Mat type of the Matrix + * @param y Rowwise simplex Matrix input of dimensionality (N, K) + */ +template * = nullptr, + require_not_st_var* = nullptr> +inline plain_type_t stochastic_row_free(const Mat& y) { + auto&& y_ref = to_ref(y); + const Eigen::Index N = y_ref.rows(); + plain_type_t ret(N, y_ref.cols() - 1); + for (Eigen::Index i = 0; i < N; ++i) { + ret.row(i) = simplex_free(y_ref.row(i)); + } + return ret; +} + +/** + * Overload that untransforms each Eigen matrix in a standard vector. + * + * @tparam T A standard vector with inner type inheriting from + * `Eigen::DenseBase` with compile time dynamic rows and dynamic rows + * @param[in] y vector of rowwise simplex matrices each of size (N, K) + */ +template * = nullptr> +inline auto stochastic_row_free(const T& y) { + return apply_vector_unary::apply( + y, [](auto&& v) { return stochastic_row_free(v); }); +} + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/meta.hpp b/stan/math/prim/meta.hpp index d70b2386a6a..68d89d34ba5 100644 --- a/stan/math/prim/meta.hpp +++ b/stan/math/prim/meta.hpp @@ -112,6 +112,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/meta/modify_eigen_options.hpp b/stan/math/prim/meta/modify_eigen_options.hpp new file mode 100644 index 00000000000..f68110d21ea --- /dev/null +++ b/stan/math/prim/meta/modify_eigen_options.hpp @@ -0,0 +1,47 @@ +#ifndef STAN_MATH_PRIM_META_MODIFY_EIGEN_OPTIONS_HPP +#define STAN_MATH_PRIM_META_MODIFY_EIGEN_OPTIONS_HPP + +#include +#include + +namespace stan { +namespace math { +namespace internal { +/** + * Change the options of an Eigen matrix or array. + * @tparam Mat type of the matrix or array + * @tparam NewOptions new options for the matrix or array + */ +template +struct change_eigen_options_impl {}; + +template +struct change_eigen_options_impl> { + using type + = Eigen::Matrix; +}; + +template +struct change_eigen_options_impl> { + using type + = Eigen::Array; +}; +} // namespace internal +/** + * Change the options of an Eigen matrix or array. + * @tparam Mat type of the matrix or array + * @tparam NewOptions new options for the matrix or array + */ +template +using change_eigen_options_t = typename internal::change_eigen_options_impl< + plain_type_t>, NewOptions>::type; + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/rev/core/var.hpp b/stan/math/rev/core/var.hpp index c6dd279422d..c56d4f09a42 100644 --- a/stan/math/rev/core/var.hpp +++ b/stan/math/rev/core/var.hpp @@ -366,15 +366,40 @@ class var_value> { var_value(S&& x) : vi_(new vari_type(std::forward(x), false)) {} // NOLINT /** - * Copy constructor for var_val. + * Copy constructor for var_val when the vari_type from `other` is directly + * assignable. * @tparam S type of the value in the `var_value` to assing * @param other the value to assign * @return this */ - template * = nullptr, + template ::vari_type>* = nullptr, require_all_plain_type_t* = nullptr> var_value(const var_value& other) : vi_(other.vi_) {} + /** + * Construct from a `var_value` with different inner `vari_type` + * @tparam S An eigen type that is not the same as `T`, but can be assigned to + * `vari_value`. + * @param other the value to assign + * @note This constructor is for types such as + * `vari_value>` and + * `vari_value>`. As pointers those are not + * assignable to one another, but their inner matrix types are. So the `var` + * has to make a new `vari` and assign the inner matrices to that new `vari`. + */ + template ::vari_type>* = nullptr, + require_constructible_t* = nullptr, + require_all_plain_type_t* = nullptr> + var_value(const var_value& other) : vi_(new vari_type(other.vi_->val_)) { + reverse_pass_callback([this_vi = this->vi_, other_vi = other.vi_]() { + other_vi->adj_ += this_vi->adj_; + }); + } + /** * Construct a `var_value` with a plain type * from another `var_value` containing an expression. diff --git a/stan/math/rev/fun.hpp b/stan/math/rev/fun.hpp index 0e80af80406..08f0f044a6d 100644 --- a/stan/math/rev/fun.hpp +++ b/stan/math/rev/fun.hpp @@ -163,6 +163,8 @@ #include #include #include +#include +#include #include #include #include diff --git a/stan/math/rev/fun/stochastic_column_constrain.hpp b/stan/math/rev/fun/stochastic_column_constrain.hpp new file mode 100644 index 00000000000..3838e0d0f77 --- /dev/null +++ b/stan/math/rev/fun/stochastic_column_constrain.hpp @@ -0,0 +1,135 @@ +#ifndef STAN_MATH_REV_FUN_STOCHASTIC_COLUMN_CONSTRAIN_HPP +#define STAN_MATH_REV_FUN_STOCHASTIC_COLUMN_CONSTRAIN_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Return a column stochastic matrix. + * The transform is based on a centered stick-breaking process. + * + * @tparam T Type of matrix to constrain + * @param y Free matrix input of dimensionality (K - 1, M) + * @return matrix of column simplexes of dimensionality (K, M) + */ +template * = nullptr> +inline plain_type_t stochastic_column_constrain(const T& y) { + using ret_type = plain_type_t; + const Eigen::Index N = y.rows(); + const Eigen::Index M = y.cols(); + using eigen_mat_rowmajor + = Eigen::Matrix; + arena_t x_val(N + 1, M); + if (unlikely(N == 0 || M == 0)) { + return ret_type(x_val); + } + arena_t> arena_y = y; + arena_t arena_z(N, M); + using arr_vec = Eigen::Array; + arr_vec stick_len = arr_vec::Constant(M, 1.0); + for (Eigen::Index k = 0; k < N; ++k) { + const double log_N_minus_k = std::log(N - k); + arena_z.row(k) + = inv_logit(arena_y.array().row(k).val_op() - log_N_minus_k).matrix(); + x_val.row(k) = stick_len.array() * arena_z.array().row(k); + stick_len -= x_val.array().row(k); + } + x_val.row(N) = stick_len; + arena_t arena_x = x_val; + reverse_pass_callback([arena_y, arena_x, arena_z]() mutable { + const Eigen::Index N = arena_y.rows(); + auto arena_x_arr = arena_x.array(); + auto arena_y_arr = arena_y.array(); + auto arena_z_arr = arena_z.array(); + auto stick_len_val = arena_x.array().row(N).val().eval(); + auto stick_len_adj = arena_x.array().row(N).adj().eval(); + for (Eigen::Index k = N; k-- > 0;) { + arena_x_arr.row(k).adj() -= stick_len_adj; + stick_len_val += arena_x_arr.row(k).val(); + stick_len_adj += arena_x_arr.row(k).adj() * arena_z_arr.row(k); + auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val; + arena_y_arr.row(k).adj() + += arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k)); + } + }); + return ret_type(arena_x); +} + +/** + * Return a column stochastic matrix + * and increment the specified log probability reference with + * the log absolute Jacobian determinant of the transform. + * + * The simplex transform is defined through a centered + * stick-breaking process. + * + * @tparam T type of the matrix to constrain + * @param y Free matrix input of dimensionality N, K. + * @param lp Log probability reference to increment. + * @return Matrix of stochastic columns of dimensionality (N + 1, K). + */ +template * = nullptr> +inline plain_type_t stochastic_column_constrain(const T& y, + scalar_type_t& lp) { + using ret_type = plain_type_t; + const Eigen::Index N = y.rows(); + const Eigen::Index M = y.cols(); + using eigen_mat_rowmajor + = Eigen::Matrix; + arena_t x_val(N + 1, M); + if (unlikely(N == 0 || M == 0)) { + return ret_type(x_val); + } + arena_t> arena_y = y; + arena_t arena_z(N, M); + using arr_vec = Eigen::Array; + arr_vec stick_len = arr_vec::Constant(M, 1.0); + arr_vec adj_y_k(N); + for (Eigen::Index k = 0; k < N; ++k) { + double log_N_minus_k = std::log(N - k); + adj_y_k = arena_y.array().row(k).val() - log_N_minus_k; + arena_z.array().row(k) = inv_logit(adj_y_k); + x_val.array().row(k) = stick_len * arena_z.array().row(k); + lp += sum(log(stick_len)) - sum(log1p_exp(-adj_y_k)) + - sum(log1p_exp(adj_y_k)); + stick_len -= x_val.array().row(k); + } + x_val.array().row(N) = stick_len; + arena_t arena_x = x_val; + reverse_pass_callback([arena_y, arena_x, arena_z, lp]() mutable { + const Eigen::Index N = arena_y.rows(); + auto arena_x_arr = arena_x.array(); + auto arena_y_arr = arena_y.array(); + auto arena_z_arr = arena_z.array(); + auto stick_len_val = arena_x.array().row(N).val().eval(); + auto stick_len_adj = arena_x.array().row(N).adj().eval(); + for (Eigen::Index k = N; k-- > 0;) { + const double log_N_minus_k = std::log(N - k); + arena_x_arr.row(k).adj() -= stick_len_adj; + stick_len_val += arena_x_arr.row(k).val(); + stick_len_adj += lp.adj() / stick_len_val + + arena_x_arr.row(k).adj() * arena_z_arr.row(k); + auto adj_y_k = arena_y_arr.row(k).val() - log_N_minus_k; + auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val; + arena_y_arr.row(k).adj() + += -(lp.adj() * inv_logit(adj_y_k)) + lp.adj() * inv_logit(-adj_y_k) + + arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k)); + } + }); + return ret_type(arena_x); +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/rev/fun/stochastic_row_constrain.hpp b/stan/math/rev/fun/stochastic_row_constrain.hpp new file mode 100644 index 00000000000..98f6781a1c3 --- /dev/null +++ b/stan/math/rev/fun/stochastic_row_constrain.hpp @@ -0,0 +1,127 @@ +#ifndef STAN_MATH_REV_FUN_STOCHASTIC_ROW_CONSTRAIN_HPP +#define STAN_MATH_REV_FUN_STOCHASTIC_ROW_CONSTRAIN_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Return a row stochastic matrix. + * + * @tparam T Type of matrix to constrain + * @param y Free vector input of dimensionality (N, K - 1) + * @return Matrix with Simplexes along the rows of dimensionality (N, K) + */ +template * = nullptr> +inline plain_type_t stochastic_row_constrain(const T& y) { + using ret_type = plain_type_t; + const Eigen::Index N = y.rows(); + const Eigen::Index M = y.cols(); + arena_t x_val(N, M + 1); + if (unlikely(N == 0 || M == 0)) { + return ret_type(x_val); + } + arena_t arena_y = y; + arena_t arena_z(N, M); + Eigen::Array stick_len = Eigen::Array::Ones(N); + for (Eigen::Index j = 0; j < M; ++j) { + double log_N_minus_k = std::log(M - j); + arena_z.col(j).array() + = inv_logit((arena_y.col(j).val_op().array() - log_N_minus_k).matrix()); + x_val.col(j).array() = stick_len * arena_z.col(j).array(); + stick_len -= x_val.col(j).array(); + } + x_val.col(M).array() = stick_len; + arena_t arena_x = x_val; + reverse_pass_callback([arena_y, arena_x, arena_z]() mutable { + const Eigen::Index M = arena_y.cols(); + auto arena_y_arr = arena_y.array(); + auto arena_x_arr = arena_x.array(); + auto arena_z_arr = arena_z.array(); + auto stick_len_val_arr = arena_x_arr.col(M).val_op().eval(); + auto stick_len_adj_arr = arena_x_arr.col(M).adj_op().eval(); + for (Eigen::Index k = M; k-- > 0;) { + arena_x_arr.col(k).adj() -= stick_len_adj_arr; + stick_len_val_arr += arena_x_arr.col(k).val_op(); + stick_len_adj_arr += arena_x_arr.col(k).adj_op() * arena_z_arr.col(k); + arena_y_arr.col(k).adj() += arena_x_arr.adj_op().col(k) + * stick_len_val_arr * arena_z_arr.col(k) + * (1.0 - arena_z_arr.col(k)); + } + }); + return ret_type(arena_x); +} + +/** + * Return a row stochastic matrix + * and increment the specified log probability reference with + * the log absolute Jacobian determinant of the transform. + * + * The simplex transform is defined through a centered + * stick-breaking process. + * + * @tparam T type of the matrix to constrain + * @param y Free matrix input of dimensionality (N, K). + * @param lp Log probability reference to increment. + * @return Matrix with simplexes along the rows of dimensionality (N, K + 1). + */ +template * = nullptr> +inline plain_type_t stochastic_row_constrain(const T& y, + scalar_type_t& lp) { + using ret_type = plain_type_t; + const Eigen::Index N = y.rows(); + const Eigen::Index M = y.cols(); + arena_t x_val(N, M + 1); + if (unlikely(N == 0 || M == 0)) { + return ret_type(x_val); + } + arena_t arena_y = y; + arena_t arena_z(N, M); + Eigen::Array stick_len = Eigen::Array::Ones(N); + for (Eigen::Index j = 0; j < M; ++j) { + double log_N_minus_k = std::log(M - j); + auto adj_y_k = arena_y.col(j).val_op().array() - log_N_minus_k; + arena_z.col(j).array() = inv_logit(adj_y_k); + x_val.col(j).array() = stick_len * arena_z.col(j).array(); + lp += sum(log(stick_len)) - sum(log1p_exp(-adj_y_k)) + - sum(log1p_exp(adj_y_k)); + stick_len -= x_val.col(j).array(); + } + x_val.col(M).array() = stick_len; + arena_t arena_x = x_val; + reverse_pass_callback([arena_y, arena_x, arena_z, lp]() mutable { + const Eigen::Index M = arena_y.cols(); + auto arena_y_arr = arena_y.array(); + auto arena_x_arr = arena_x.array(); + auto arena_z_arr = arena_z.array(); + auto stick_len_val = arena_x_arr.col(M).val_op().eval(); + auto stick_len_adj = arena_x_arr.col(M).adj_op().eval(); + for (Eigen::Index k = M; k-- > 0;) { + const double log_N_minus_k = std::log(M - k); + arena_x_arr.col(k).adj() -= stick_len_adj; + stick_len_val += arena_x_arr.col(k).val_op(); + stick_len_adj += lp.adj() / stick_len_val + + arena_x_arr.adj_op().col(k) * arena_z_arr.col(k); + auto adj_y_k = arena_y_arr.col(k).val_op() - log_N_minus_k; + arena_y_arr.col(k).adj() + += -(lp.adj() * inv_logit(adj_y_k)) + lp.adj() * inv_logit(-adj_y_k) + + arena_x_arr.col(k).adj_op() * stick_len_val * arena_z_arr.col(k) + * (1.0 - arena_z_arr.col(k)); + } + }); + return ret_type(arena_x); +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/rev/meta.hpp b/stan/math/rev/meta.hpp index e182c6ab825..892080b5343 100644 --- a/stan/math/rev/meta.hpp +++ b/stan/math/rev/meta.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/meta/modify_eigen_options.hpp b/stan/math/rev/meta/modify_eigen_options.hpp new file mode 100644 index 00000000000..30e21db4bb2 --- /dev/null +++ b/stan/math/rev/meta/modify_eigen_options.hpp @@ -0,0 +1,31 @@ +#ifndef STAN_MATH_REV_META_MODIFY_EIGEN_OPTIONS_HPP +#define STAN_MATH_REV_META_MODIFY_EIGEN_OPTIONS_HPP + +#include +#include + +namespace stan { +namespace math { +namespace internal { + +template +struct change_eigen_options_impl, NewOptions, + require_eigen_matrix_base_t> { + using type = var_value>; +}; + +template +struct change_eigen_options_impl, NewOptions, + require_eigen_array_t> { + using type = var_value>; +}; + +} // namespace internal +} // namespace math +} // namespace stan + +#endif diff --git a/test/unit/math/mix/fun/stochastic_column_constrain_test.cpp b/test/unit/math/mix/fun/stochastic_column_constrain_test.cpp new file mode 100644 index 00000000000..995db3d1e7a --- /dev/null +++ b/test/unit/math/mix/fun/stochastic_column_constrain_test.cpp @@ -0,0 +1,58 @@ +#include + +namespace stochastic_column_constrain_test { +template +T g1(const T& x) { + stan::scalar_type_t lp = 0; + return stan::math::stochastic_column_constrain(x, lp); +} +template +T g2(const T& x) { + stan::scalar_type_t lp = 0; + return stan::math::stochastic_column_constrain(x, lp); +} +template +typename stan::scalar_type::type g3(const T& x) { + stan::scalar_type_t lp = 0; + stan::math::stochastic_column_constrain(x, lp); + return lp; +} + +template +void expect_simplex_transform(const T& x) { + auto f1 = [](const auto& x) { return g1(x); }; + auto f2 = [](const auto& x) { return g2(x); }; + auto f3 = [](const auto& x) { return g3(x); }; + stan::test::expect_ad(f1, x); + stan::test::expect_ad_matvar(f1, x); + stan::test::expect_ad(f2, x); + stan::test::expect_ad_matvar(f2, x); + stan::test::expect_ad(f3, x); + stan::test::expect_ad_matvar(f3, x); +} +} // namespace stochastic_column_constrain_test + +TEST(MathMixMatFun, simplexColumnTransform) { + Eigen::MatrixXd v0(0, 0); + stochastic_column_constrain_test::expect_simplex_transform(v0); + + Eigen::MatrixXd v1(1, 3); + v1 << 1, 2, 3; + stochastic_column_constrain_test::expect_simplex_transform(v1); + + Eigen::MatrixXd v2(2, 3); + v2 << 3, -1, 3, -1, 3, -1; + stochastic_column_constrain_test::expect_simplex_transform(v2); + + Eigen::MatrixXd v3(3, 3); + v3 << 2, 3, -1, 2, 3, -1, 2, 3, -1; + stochastic_column_constrain_test::expect_simplex_transform(v3); + + Eigen::MatrixXd v4(4, 3); + v4 << 2, -1, 0, -1.1, 2, -1, 0, -1.1, 2, -1, 0, -1.1; + stochastic_column_constrain_test::expect_simplex_transform(v4); + + Eigen::MatrixXd v5(5, 3); + v5 << 1, -3, 2, 0, -1, 1, -3, 2, 0, -1, 1, -3, 2, 0, -1; + stochastic_column_constrain_test::expect_simplex_transform(v5); +} diff --git a/test/unit/math/mix/fun/stochastic_row_constrain_test.cpp b/test/unit/math/mix/fun/stochastic_row_constrain_test.cpp new file mode 100644 index 00000000000..d40921d3032 --- /dev/null +++ b/test/unit/math/mix/fun/stochastic_row_constrain_test.cpp @@ -0,0 +1,65 @@ +#include + +namespace stochastic_row_constrain_test { +template +T g1(const T& x) { + stan::scalar_type_t lp = 0; + return stan::math::stochastic_row_constrain(x, lp); +} +template +T g2(const T& x) { + stan::scalar_type_t lp = 0; + return stan::math::stochastic_row_constrain(x, lp); +} +template +typename stan::scalar_type::type g3(const T& x) { + stan::scalar_type_t lp = 0; + stan::math::stochastic_row_constrain(x, lp); + return lp; +} + +template +void expect_simplex_transform(const T& x) { + auto f1 = [](const auto& x) { return g1(x); }; + auto f2 = [](const auto& x) { return g2(x); }; + auto f3 = [](const auto& x) { return g3(x); }; + stan::test::expect_ad(f1, x); + stan::test::expect_ad_matvar(f1, x); + stan::test::expect_ad(f2, x); + stan::test::expect_ad_matvar(f2, x); + stan::test::expect_ad(f3, x); + stan::test::expect_ad_matvar(f3, x); +} +} // namespace stochastic_row_constrain_test + +TEST(MathMixMatFun, simplexRowTransform0) { + Eigen::MatrixXd v0(0, 0); + stochastic_row_constrain_test::expect_simplex_transform(v0); +} +TEST(MathMixMatFun, simplexRowTransform1) { + Eigen::MatrixXd v1(1, 3); + v1 << .01, .1, 1; + stochastic_row_constrain_test::expect_simplex_transform(v1); +} + +TEST(MathMixMatFun, simplexRowTransform2) { + Eigen::MatrixXd v2(2, 3); + v2 << 3, -1, 3, -1, 3, -1; + stochastic_row_constrain_test::expect_simplex_transform(v2); +} + +TEST(MathMixMatFun, simplexRowTransform3) { + Eigen::MatrixXd v3(3, 3); + v3 << 2, 3, -1, 2, 3, -1, 2, 3, -1; + stochastic_row_constrain_test::expect_simplex_transform(v3); +} +TEST(MathMixMatFun, simplexRowTransform4) { + Eigen::MatrixXd v4(4, 3); + v4 << 2, -1, 0, -1.1, 2, -1, 0, -1.1, 2, -1, 0, -1.1; + stochastic_row_constrain_test::expect_simplex_transform(v4); +} +TEST(MathMixMatFun, simplexRowTransform5) { + Eigen::MatrixXd v5(5, 3); + v5 << 1, -3, 2, 0, -1, 1, -3, 2, 0, -1, 1, -3, 2, 0, -1; + stochastic_row_constrain_test::expect_simplex_transform(v5); +} diff --git a/test/unit/math/prim/fun/stochastic_column_constrain_test.cpp b/test/unit/math/prim/fun/stochastic_column_constrain_test.cpp new file mode 100644 index 00000000000..c8a7d3cd00f --- /dev/null +++ b/test/unit/math/prim/fun/stochastic_column_constrain_test.cpp @@ -0,0 +1,47 @@ +#include +#include +#include + +TEST(prob_transform, stochastic_column_rt0) { + using Eigen::Dynamic; + using Eigen::Matrix; + Matrix x(4, 4); + for (Eigen::Index i = 0; i < x.size(); ++i) { + x(i) = static_cast(i); + } + double lp = 0; + Matrix x_test + = stan::math::stochastic_column_constrain(x, lp); + EXPECT_EQ(lp, 0.0); + Matrix x_res(5, 4); + double lp_orig = 0.0; + for (Eigen::Index i = 0; i < x.cols(); ++i) { + x_res.col(i) = stan::math::simplex_constrain(x.col(i), lp_orig); + } + EXPECT_EQ(lp_orig, 0.0); + EXPECT_MATRIX_EQ(x_test, x_res); + Matrix x_lp_test + = stan::math::stochastic_column_constrain(x, lp); + for (Eigen::Index i = 0; i < x.cols(); ++i) { + x_res.col(i) = stan::math::simplex_constrain(x.col(i), lp_orig); + } + EXPECT_MATRIX_EQ(x_lp_test, x_res); +} + +TEST(prob_transform, stochastic_column_constrain_and_free) { + using Eigen::Dynamic; + using Eigen::Matrix; + Matrix x(4, 4); + for (Eigen::Index i = 0; i < x.size(); ++i) { + x(i) = static_cast(i); + } + double lp = 0; + Matrix x_test = stan::math::stochastic_column_free( + stan::math::stochastic_column_constrain(x, lp)); + EXPECT_MATRIX_NEAR(x, x_test, 1e-9); + + Matrix x_lp_test + = stan::math::stochastic_column_free( + stan::math::stochastic_column_constrain(x, lp)); + EXPECT_MATRIX_NEAR(x, x_lp_test, 1e-9); +} diff --git a/test/unit/math/prim/fun/stochastic_row_constrain_test.cpp b/test/unit/math/prim/fun/stochastic_row_constrain_test.cpp new file mode 100644 index 00000000000..9e409d5f821 --- /dev/null +++ b/test/unit/math/prim/fun/stochastic_row_constrain_test.cpp @@ -0,0 +1,46 @@ +#include +#include +#include + +TEST(prob_transform, row_stochastic_rt0) { + using Eigen::Dynamic; + using Eigen::Matrix; + Matrix x(4, 4); + for (Eigen::Index i = 0; i < x.size(); ++i) { + x(i) = static_cast(i); + } + double lp = 0; + Matrix x_test + = stan::math::stochastic_row_constrain(x, lp); + EXPECT_EQ(lp, 0.0); + Matrix x_res(4, 5); + double lp_orig = 0.0; + for (Eigen::Index i = 0; i < x.cols(); ++i) { + x_res.row(i) = stan::math::simplex_constrain(x.row(i), lp_orig); + } + EXPECT_EQ(lp_orig, 0.0); + EXPECT_MATRIX_EQ(x_test, x_res); + Matrix x_lp_test + = stan::math::stochastic_row_constrain(x, lp); + for (Eigen::Index i = 0; i < x.cols(); ++i) { + x_res.row(i) = stan::math::simplex_constrain(x.row(i), lp_orig); + } + EXPECT_MATRIX_EQ(x_lp_test, x_res); +} + +TEST(prob_transform, row_stochastic_constrain_free) { + using Eigen::Dynamic; + using Eigen::Matrix; + Matrix x(4, 4); + for (Eigen::Index i = 0; i < x.size(); ++i) { + x(i) = static_cast(i); + } + double lp = 0; + Matrix x_test = stan::math::stochastic_row_free( + stan::math::stochastic_row_constrain(x, lp)); + EXPECT_MATRIX_NEAR(x, x_test, 1e-9); + + Matrix x_lp_test = stan::math::stochastic_row_free( + stan::math::stochastic_row_constrain(x, lp)); + EXPECT_MATRIX_NEAR(x, x_lp_test, 1e-9); +}