Skip to content

Commit

Permalink
Merge pull request #2992 from stan-dev/feature/simplex_row_and_col
Browse files Browse the repository at this point in the history
Add simplex row and column matrix constraints
  • Loading branch information
andrjohns committed Mar 23, 2024
2 parents 1d0ce81 + df5e1d1 commit 8f0731a
Show file tree
Hide file tree
Showing 19 changed files with 937 additions and 10 deletions.
4 changes: 4 additions & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@
#include <stan/math/prim/fun/squared_distance.hpp>
#include <stan/math/prim/fun/stan_print.hpp>
#include <stan/math/prim/fun/step.hpp>
#include <stan/math/prim/fun/stochastic_column_constrain.hpp>
#include <stan/math/prim/fun/stochastic_column_free.hpp>
#include <stan/math/prim/fun/stochastic_row_constrain.hpp>
#include <stan/math/prim/fun/stochastic_row_free.hpp>
#include <stan/math/prim/fun/sub_col.hpp>
#include <stan/math/prim/fun/sub_row.hpp>
#include <stan/math/prim/fun/subtract.hpp>
Expand Down
12 changes: 7 additions & 5 deletions stan/math/prim/fun/simplex_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ namespace math {
* @param y Free vector input of dimensionality K - 1.
* @return Simplex of dimensionality K.
*/
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline auto simplex_constrain(const Vec& y) {
inline plain_type_t<Vec> simplex_constrain(const Vec& y) {
// cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
using std::log;
using T = value_type_t<Vec>;
Expand Down Expand Up @@ -56,9 +56,10 @@ inline auto simplex_constrain(const Vec& y) {
* @param lp Log probability reference to increment.
* @return Simplex of dimensionality K.
*/
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline auto simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {
inline plain_type_t<Vec> simplex_constrain(const Vec& y,
value_type_t<Vec>& lp) {
using Eigen::Dynamic;
using Eigen::Matrix;
using std::log;
Expand Down Expand Up @@ -98,7 +99,8 @@ inline auto simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {
* @return simplex of dimensionality one greater than `y`
*/
template <bool Jacobian, typename Vec, require_not_std_vector_t<Vec>* = nullptr>
auto simplex_constrain(const Vec& y, return_type_t<Vec>& lp) {
inline plain_type_t<Vec> simplex_constrain(const Vec& y,
return_type_t<Vec>& lp) {
if (Jacobian) {
return simplex_constrain(y, lp);
} else {
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/simplex_free.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ namespace math {
* the simplex.
* @throw std::domain_error if x is not a valid simplex
*/
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr>
auto simplex_free(const Vec& x) {
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr>
inline plain_type_t<Vec> simplex_free(const Vec& x) {
using std::log;
using T = value_type_t<Vec>;

const auto& x_ref = to_ref(x);
check_simplex("stan::math::simplex_free", "Simplex variable", x_ref);
int Km1 = x_ref.size() - 1;
Eigen::Index Km1 = x_ref.size() - 1;
plain_type_t<Vec> y(Km1);
T stick_len = x_ref.coeff(Km1);
for (Eigen::Index k = Km1; --k >= 0;) {
Expand Down
114 changes: 114 additions & 0 deletions stan/math/prim/fun/stochastic_column_constrain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#ifndef STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP
#define STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return a column stochastic matrix.
*
* The transform is based on a centered stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M)
* @return Matrix with simplex columns of dimensionality (K, M)
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> stochastic_column_constrain(const Mat& y) {
auto&& y_ref = to_ref(y);
const Eigen::Index M = y_ref.cols();
plain_type_t<Mat> ret(y_ref.rows() + 1, M);
for (Eigen::Index i = 0; i < M; ++i) {
ret.col(i) = simplex_constrain(y_ref.col(i));
}
return ret;
}

/**
* Return a column stochastic matrix
* and increment the specified log probability reference with
* the log absolute Jacobian determinant of the transform.
*
* The simplex transform is defined through a centered
* stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M)
* @param lp Log probability reference to increment.
* @return Matrix with stochastic columns of dimensionality (K, M)
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> stochastic_column_constrain(const Mat& y,
value_type_t<Mat>& lp) {
auto&& y_ref = to_ref(y);
const Eigen::Index M = y_ref.cols();
plain_type_t<Mat> ret(y_ref.rows() + 1, M);
for (Eigen::Index i = 0; i < M; ++i) {
ret.col(i) = simplex_constrain(y_ref.col(i), lp);
}
return ret;
}

/**
* Return a column stochastic matrix. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M).
* @param[in, out] lp log density accumulator
* @return Matrix with simplex columns of dimensionality (K, M).
*/
template <bool Jacobian, typename Mat, require_not_std_vector_t<Mat>* = nullptr>
inline plain_type_t<Mat> stochastic_column_constrain(const Mat& y,
return_type_t<Mat>& lp) {
if (Jacobian) {
return stochastic_column_constrain(y, lp);
} else {
return stochastic_column_constrain(y);
}
}

/**
* Return a vector of column stochastic matrices. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam T A standard vector with inner type inheriting from
* `Eigen::DenseBase` or a `var_value` with inner type inheriting from
* `Eigen::DenseBase` with compile time dynamic rows and dynamic columns
* @param[in] y free vector
* @param[in, out] lp log density accumulator
* @return Standard vector containing matrices with simplex columns of
* dimensionality (K, M).
*/
template <bool Jacobian, typename T, require_std_vector_t<T>* = nullptr>
inline auto stochastic_column_constrain(const T& y, return_type_t<T>& lp) {
return apply_vector_unary<T>::apply(y, [&lp](auto&& v) {
return stochastic_column_constrain<Jacobian>(v, lp);
});
}

} // namespace math
} // namespace stan

#endif
48 changes: 48 additions & 0 deletions stan/math/prim/fun/stochastic_column_free.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef STAN_MATH_PRIM_FUN_STOCHASTIC_COLUMN_FREE_HPP
#define STAN_MATH_PRIM_FUN_STOCHASTIC_COLUMN_FREE_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/simplex_free.hpp>

namespace stan {
namespace math {

/**
* Return an unconstrained matrix that when transformed produces
* the specified columnwise stochastic matrix. It applies to a stochastic
* matrix of dimensionality (N, K) and produces an unconstrained vector of
* dimensionality (N - 1, K).
*
* @tparam Mat type of the Matrix
* @param y Columnwise stochastic matrix input of dimensionality (N, K)
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> stochastic_column_free(const Mat& y) {
auto&& y_ref = to_ref(y);
const Eigen::Index M = y_ref.cols();
plain_type_t<Mat> ret(y_ref.rows() - 1, M);
for (Eigen::Index i = 0; i < M; ++i) {
ret.col(i) = simplex_free(y_ref.col(i));
}
return ret;
}

/**
* Overload that untransforms each Eigen matrix in a standard vector.
*
* @tparam T A standard vector with inner type inheriting from
* `Eigen::DenseBase` with compile time dynamic rows and dynamic rows
* @param[in] y vector of columnwise stochastic matrix of size (N, K)
*/
template <typename T, require_std_vector_t<T>* = nullptr>
inline auto stochastic_column_free(const T& y) {
return apply_vector_unary<T>::apply(
y, [](auto&& v) { return stochastic_column_free(v); });
}

} // namespace math
} // namespace stan

#endif
127 changes: 127 additions & 0 deletions stan/math/prim/fun/stochastic_row_constrain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#ifndef STAN_MATH_PRIM_FUN_STOCHASTIC_ROW_CONSTRAIN_HPP
#define STAN_MATH_PRIM_FUN_STOCHASTIC_ROW_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return a row stochastic matrix.
*
* The transform is based on a centered stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (N, K - 1).
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y) {
auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
int Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
using eigen_arr = Eigen::Array<scalar_type_t<Mat>, -1, 1>;
eigen_arr stick_len = eigen_arr::Constant(N, 1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
auto z_k = inv_logit(y_ref.array().col(k) - log(Km1 - k));
x.array().col(k) = stick_len * z_k;
stick_len -= x.array().col(k);
}
x.array().col(Km1) = stick_len;
return x;
}

/**
* Return a row stochastic matrix.
* The simplex transform is defined through a centered
* stick-breaking process.
*
* @tparam Mat type of the matrix
* @param y Free matrix input of dimensionality (N, K - 1).
* @param lp Log probability reference to increment.
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y,
value_type_t<Mat>& lp) {
auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
Eigen::Index Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
Eigen::Array<scalar_type_t<Mat>, -1, 1> stick_len
= Eigen::Array<scalar_type_t<Mat>, -1, 1>::Constant(N, 1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
const auto eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
auto adj_y_k = (y_ref.array().col(k) + eq_share).eval();
auto z_k = inv_logit(adj_y_k);
x.array().col(k) = stick_len * z_k;
lp += -sum(log1p_exp(adj_y_k)) - sum(log1p_exp(-adj_y_k))
+ sum(log(stick_len));
stick_len -= x.array().col(k); // equivalently *= (1 - z_k);
}
x.col(Km1).array() = stick_len;
return x;
}

/**
* Return a row stochastic matrix.
* If the `Jacobian` parameter is `true`, the log density accumulator is
* incremented with the log absolute Jacobian determinant of the transform. All
* of the transforms are specified with their Jacobians in the *Stan Reference
* Manual* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam Mat A type inheriting from `Eigen::DenseBase` or a `var_value` with
* inner type inheriting from `Eigen::DenseBase` with compile time dynamic rows
* and dynamic columns
* @param[in] y free matrix
* @param[in, out] lp log density accumulator
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <bool Jacobian, typename Mat, require_not_std_vector_t<Mat>* = nullptr>
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y,
return_type_t<Mat>& lp) {
if (Jacobian) {
return stochastic_row_constrain(y, lp);
} else {
return stochastic_row_constrain(y);
}
}

/**
* Return a row stochastic matrix.
* If the `Jacobian` parameter is `true`, the log density accumulator is
* incremented with the log absolute Jacobian determinant of the transform. All
* of the transforms are specified with their Jacobians in the *Stan Reference
* Manual* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam T A standard vector with inner type inheriting from
* `Eigen::DenseBase` or a `var_value` with inner type inheriting from
* `Eigen::DenseBase` with compile time dynamic rows and dynamic columns
* @param[in] y free vector with matrices of size (N, K - 1)
* @param[in, out] lp log density accumulator
* @return vector of matrices with simplex rows of dimensionality (N, K)
*/
template <bool Jacobian, typename T, require_std_vector_t<T>* = nullptr>
inline auto stochastic_row_constrain(const T& y, return_type_t<T>& lp) {
return apply_vector_unary<T>::apply(
y, [&lp](auto&& v) { return stochastic_row_constrain<Jacobian>(v, lp); });
}

} // namespace math
} // namespace stan

#endif
Loading

0 comments on commit 8f0731a

Please sign in to comment.