Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simplex row and column matrix constraints #2992

Merged
merged 25 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f086818
adds row and column simplex matrix constraints
SteveBronder Dec 20, 2023
cff1d61
Fix rowwise simplex constrain to be vectorized
SteveBronder Dec 21, 2023
7312332
fix docs
SteveBronder Dec 22, 2023
334b0c6
remove apple debug stuff
SteveBronder Dec 22, 2023
c79d8a1
remove apple debug stuff
SteveBronder Dec 22, 2023
8b75231
clangformat tests
SteveBronder Dec 22, 2023
724712c
Merge commit 'ae90c3edca5d3b5e5234fdd6cbd1146f1666761d' into HEAD
yashikno Dec 22, 2023
3517906
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 22, 2023
35a67b5
add newline
SteveBronder Dec 22, 2023
9442c29
simplex col/row test names
SteveBronder Dec 22, 2023
a3db197
address review comments
SteveBronder Jan 3, 2024
06d08f3
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jan 3, 2024
17aaddb
restart jenkins
SteveBronder Jan 4, 2024
a8f8062
Merge remote-tracking branch 'origin/develop' into feature/simplex_ro…
SteveBronder Feb 15, 2024
cad36a4
update docs
SteveBronder Feb 15, 2024
4575e4b
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 15, 2024
69ce83b
Merge remote-tracking branch 'origin/develop' into feature/simplex_ro…
SteveBronder Feb 27, 2024
845d4e5
update for review comments
SteveBronder Feb 27, 2024
7b0c6be
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 27, 2024
c2e938d
Adds simplex matrix free functions along with tests
SteveBronder Feb 27, 2024
e42fe2f
Merge commit 'b6a227ce58c41b5c2a956e793f86b3e59b43eaf1' into HEAD
yashikno Feb 27, 2024
dbe4416
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 27, 2024
d819287
update naming to be row and column stochastic matrices
SteveBronder Feb 29, 2024
a736fac
Merge commit 'c5e8f0844433af4d64ea8ac2086bdacd5e213477' into HEAD
yashikno Feb 29, 2024
df5e1d1
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,10 @@
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/signbit.hpp>
#include <stan/math/prim/fun/simplex_column_constrain.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <stan/math/prim/fun/simplex_free.hpp>
#include <stan/math/prim/fun/simplex_row_constrain.hpp>
#include <stan/math/prim/fun/sin.hpp>
#include <stan/math/prim/fun/singular_values.hpp>
#include <stan/math/prim/fun/sinh.hpp>
Expand Down
117 changes: 117 additions & 0 deletions stan/math/prim/fun/simplex_column_constrain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#ifndef STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP
#define STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return a matrix with columns as simplex vectors.
* A simplex is a vector containing values greater than or equal
* to 0 that sum to 1. A vector with (K-1) unconstrained values
* will produce a simplex of size K.
*
* The transform is based on a centered stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M)
* @return Matrix with simplex columns of dimensionality (K, M)
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> simplex_column_constrain(const Mat& y) {
// cut & paste simplex_column_constrain(Eigen::Matrix, T) w/o Jacobian
auto&& y_ref = to_ref(y);
const Eigen::Index M = y_ref.cols();
plain_type_t<Mat> ret(y_ref.rows() + 1, M);
for (Eigen::Index i = 0; i < M; ++i) {
ret.col(i) = simplex_constrain(y_ref.col(i));
}
return ret;
}

/**
* Return a matrix with columns as simplex vectors
* and increment the specified log probability reference with
* the log absolute Jacobian determinant of the transform.
*
* The simplex transform is defined through a centered
* stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M)
* @param lp Log probability reference to increment.
* @return Matrix with simplex columns of dimensionality (K, M)
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> simplex_column_constrain(const Mat& y,
value_type_t<Mat>& lp) {
auto&& y_ref = to_ref(y);
const Eigen::Index M = y_ref.cols();
plain_type_t<Mat> ret(y_ref.rows() + 1, M);
for (Eigen::Index i = 0; i < M; ++i) {
ret.col(i) = simplex_constrain(y_ref.col(i), lp);
}
return ret;
}

/**
* Return a matrix with columns as simplex vectors. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M).
* @param[in, out] lp log density accumulator
* @return Matrix with simplex columns of dimensionality (K, M).
*/
template <bool Jacobian, typename Mat, require_not_std_vector_t<Mat>* = nullptr>
inline plain_type_t<Mat> simplex_column_constrain(const Mat& y,
return_type_t<Mat>& lp) {
if (Jacobian) {
return simplex_column_constrain(y, lp);
} else {
return simplex_column_constrain(y);
}
}

/**
* Return a standard vector of matrices with columns as simplex vectors. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam T A standard vector with inner type inheriting from
* `Eigen::DenseBase` or a `var_value` with inner type inheriting from
* `Eigen::DenseBase` with compile time dynamic rows and dynamic columns
* @param[in] y free vector
* @param[in, out] lp log density accumulator
* @return Standard vector containing matrices with simplex columns of
* dimensionality (K, M).
*/
template <bool Jacobian, typename T, require_std_vector_t<T>* = nullptr>
inline auto simplex_column_constrain(const T& y, return_type_t<T>& lp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_column_constrain(const T& y, return_type_t<T>& lp) {
inline plain_type_t<T> simplex_column_constrain(const T& y, return_type_t<T>& lp) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one we want to keep auto because it can return a vector with any amount of vectors of matrices within it.

return apply_vector_unary<T>::apply(
y, [&lp](auto&& v) { return simplex_column_constrain<Jacobian>(v, lp); });
}

} // namespace math
} // namespace stan

#endif
4 changes: 2 additions & 2 deletions stan/math/prim/fun/simplex_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace math {
* @param y Free vector input of dimensionality K - 1.
* @return Simplex of dimensionality K.
*/
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline auto simplex_constrain(const Vec& y) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_constrain(const Vec& y) {
inline plain_type_t<Vec> simplex_constrain(const Vec& y) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed all of these to plain_type_t

// cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
Expand Down Expand Up @@ -56,7 +56,7 @@ inline auto simplex_constrain(const Vec& y) {
* @param lp Log probability reference to increment.
* @return Simplex of dimensionality K.
*/
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline auto simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {
inline plain_type_t<Vec> simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {

using Eigen::Dynamic;
Expand Down
133 changes: 133 additions & 0 deletions stan/math/prim/fun/simplex_row_constrain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#ifndef STAN_MATH_PRIM_FUN_SIMPLEX_ROW_CONSTRAIN_HPP
#define STAN_MATH_PRIM_FUN_SIMPLEX_ROW_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return the simplex corresponding to the specified free vector.
* A simplex is a vector containing values greater than or equal
* to 0 that sum to 1. A vector with (K-1) unconstrained values
* will produce a simplex of size K.
*
* The transform is based on a centered stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (N, K - 1).
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> simplex_row_constrain(const Mat& y) {
auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
int Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
using eigen_arr = Eigen::Array<scalar_type_t<Mat>, -1, 1>;
syclik marked this conversation as resolved.
Show resolved Hide resolved
eigen_arr stick_len = eigen_arr::Constant(N, 1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
auto z_k = inv_logit(y_ref.array().col(k) - log(Km1 - k));
x.array().col(k) = stick_len * z_k;
stick_len -= x.array().col(k);
}
x.array().col(Km1) = stick_len;
return x;
}

/**
* Return a matrix with simplex rows corresponding to the specified free matrix
* and increment the specified log probability reference with
* the log absolute Jacobian determinant of the transform.
*
* The simplex transform is defined through a centered
* stick-breaking process.
*
* @tparam Mat type of the matrix
* @param y Free matrix input of dimensionality (N, K - 1).
* @param lp Log probability reference to increment.
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline plain_type_t<Mat> simplex_row_constrain(const Mat& y,
value_type_t<Mat>& lp) {
auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
Eigen::Index Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
Eigen::Array<scalar_type_t<Mat>, -1, 1> stick_len
= Eigen::Array<scalar_type_t<Mat>, -1, 1>::Constant(N, 1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
const auto eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
auto adj_y_k = (y_ref.array().col(k) + eq_share).eval();
auto z_k = inv_logit(adj_y_k);
x.array().col(k) = stick_len * z_k;
lp += -sum(log1p_exp(adj_y_k)) - sum(log1p_exp(-adj_y_k))
+ sum(log(stick_len));
stick_len -= x.array().col(k); // equivalently *= (1 - z_k);
}
x.col(Km1).array() = stick_len;
return x;
}

/**
* Return a matrix with simplex rows corresponding to the specified free matrix.
* If the `Jacobian` parameter is `true`, the log density accumulator is
* incremented with the log absolute Jacobian determinant of the transform. All
* of the transforms are specified with their Jacobians in the *Stan Reference
* Manual* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam Mat A type inheriting from `Eigen::DenseBase` or a `var_value` with
* inner type inheriting from `Eigen::DenseBase` with compile time dynamic rows
* and dynamic columns
* @param[in] y free matrix
* @param[in, out] lp log density accumulator
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <bool Jacobian, typename Mat, require_not_std_vector_t<Mat>* = nullptr>
inline plain_type_t<Mat> simplex_row_constrain(const Mat& y,
return_type_t<Mat>& lp) {
if (Jacobian) {
return simplex_row_constrain(y, lp);
} else {
return simplex_row_constrain(y);
}
}

/**
* Return the simplex corresponding to the specified free vector. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam T A standard vector with inner type inheriting from
* `Eigen::DenseBase` or a `var_value` with inner type inheriting from
* `Eigen::DenseBase` with compile time dynamic rows and dynamic columns
* @param[in] y free vector with matrices of size (N, K - 1)
* @param[in, out] lp log density accumulator
* @return vector of matrices with simplex rows of dimensionality (N, K)
*/
template <bool Jacobian, typename T, require_std_vector_t<T>* = nullptr>
inline auto simplex_row_constrain(const T& y, return_type_t<T>& lp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_row_constrain(const T& y, return_type_t<T>& lp) {
inline plain_type_t<T> simplex_row_constrain(const T& y, return_type_t<T>& lp) {

Copy link
Collaborator Author

@SteveBronder SteveBronder Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above for std vectors of std vectors of matrices. idt plain_type_t here gives any more information

return apply_vector_unary<T>::apply(
y, [&lp](auto&& v) { return simplex_row_constrain<Jacobian>(v, lp); });
}

} // namespace math
} // namespace stan

#endif
29 changes: 27 additions & 2 deletions stan/math/rev/core/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,40 @@ class var_value<T, internal::require_matrix_var_value<T>> {
var_value(S&& x) : vi_(new vari_type(std::forward<S>(x), false)) {} // NOLINT

/**
* Copy constructor for var_val.
* Copy constructor for var_val when the vari_type from `other` is directly
* assignable.
* @tparam S type of the value in the `var_value` to assing
* @param other the value to assign
* @return this
*/
template <typename S, require_assignable_t<value_type, S>* = nullptr,
template <typename S,
syclik marked this conversation as resolved.
Show resolved Hide resolved
require_assignable_t<vari_type,
typename var_value<S>::vari_type>* = nullptr,
require_all_plain_type_t<T, S>* = nullptr>
var_value(const var_value<S>& other) : vi_(other.vi_) {}

/**
* Construct from a `var_value` with different inner `vari_type`
* @tparam S An eigen type that is not the same as `T`, but can be assigned to
* `vari_value<T>`.
* @param other the value to assign
* @note This constructor is for types such as
* `vari_value<Matrix<double, -1, 1>>` and
* `vari_value<Matrix<double, 1, -1>>`. As pointers those are not
* assignable to one another, but their inner matrix types are. So the `var`
* has to make a new `vari` and assign the inner matrices to that new `vari`.
*/
template <typename S,
require_not_assignable_t<
vari_type, typename var_value<S>::vari_type>* = nullptr,
require_constructible_t<vari_type, S>* = nullptr,
require_all_plain_type_t<T, S>* = nullptr>
var_value(const var_value<S>& other) : vi_(new vari_type(other.vi_->val_)) {
reverse_pass_callback([this_vi = this->vi_, other_vi = other.vi_]() {
other_vi->adj_ += this_vi->adj_;
});
}

/**
* Construct a `var_value` with a plain type
* from another `var_value` containing an expression.
Expand Down
2 changes: 2 additions & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@
#include <stan/math/rev/fun/rows_dot_product.hpp>
#include <stan/math/rev/fun/rows_dot_self.hpp>
#include <stan/math/rev/fun/sd.hpp>
#include <stan/math/rev/fun/simplex_column_constrain.hpp>
#include <stan/math/rev/fun/simplex_constrain.hpp>
#include <stan/math/rev/fun/simplex_row_constrain.hpp>
#include <stan/math/rev/fun/sin.hpp>
#include <stan/math/rev/fun/singular_values.hpp>
#include <stan/math/rev/fun/svd.hpp>
Expand Down
Loading