Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simplex row and column matrix constraints #2992

Merged
merged 25 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f086818
adds row and column simplex matrix constraints
SteveBronder Dec 20, 2023
cff1d61
Fix rowwise simplex constrain to be vectorized
SteveBronder Dec 21, 2023
7312332
fix docs
SteveBronder Dec 22, 2023
334b0c6
remove apple debug stuff
SteveBronder Dec 22, 2023
c79d8a1
remove apple debug stuff
SteveBronder Dec 22, 2023
8b75231
clangformat tests
SteveBronder Dec 22, 2023
724712c
Merge commit 'ae90c3edca5d3b5e5234fdd6cbd1146f1666761d' into HEAD
yashikno Dec 22, 2023
3517906
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 22, 2023
35a67b5
add newline
SteveBronder Dec 22, 2023
9442c29
simplex col/row test names
SteveBronder Dec 22, 2023
a3db197
address review comments
SteveBronder Jan 3, 2024
06d08f3
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jan 3, 2024
17aaddb
restart jenkins
SteveBronder Jan 4, 2024
a8f8062
Merge remote-tracking branch 'origin/develop' into feature/simplex_ro…
SteveBronder Feb 15, 2024
cad36a4
update docs
SteveBronder Feb 15, 2024
4575e4b
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 15, 2024
69ce83b
Merge remote-tracking branch 'origin/develop' into feature/simplex_ro…
SteveBronder Feb 27, 2024
845d4e5
update for review comments
SteveBronder Feb 27, 2024
7b0c6be
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 27, 2024
c2e938d
Adds simplex matrix free functions along with tests
SteveBronder Feb 27, 2024
e42fe2f
Merge commit 'b6a227ce58c41b5c2a956e793f86b3e59b43eaf1' into HEAD
yashikno Feb 27, 2024
dbe4416
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 27, 2024
d819287
update naming to be row and column stochastic matrices
SteveBronder Feb 29, 2024
a736fac
Merge commit 'c5e8f0844433af4d64ea8ac2086bdacd5e213477' into HEAD
yashikno Feb 29, 2024
df5e1d1
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,10 @@
#include <stan/math/prim/fun/select.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/signbit.hpp>
#include <stan/math/prim/fun/simplex_column_constrain.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <stan/math/prim/fun/simplex_free.hpp>
#include <stan/math/prim/fun/simplex_row_constrain.hpp>
#include <stan/math/prim/fun/sin.hpp>
#include <stan/math/prim/fun/singular_values.hpp>
#include <stan/math/prim/fun/sinh.hpp>
Expand Down
115 changes: 115 additions & 0 deletions stan/math/prim/fun/simplex_column_constrain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#ifndef STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP
#define STAN_MATH_PRIM_FUN_SIMPLEX_COLUMN_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return a matrix with columns as simplex vectors.
* A simplex is a vector containing values greater than or equal
* to 0 that sum to 1. A vector with (K-1) unconstrained values
* will produce a simplex of size K.
*
* The transform is based on a centered stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M).
* @return Matrix with simplex columns of dimensionality (K, M).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline auto simplex_column_constrain(const Mat& y) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make the return type more explicit? (Instead of auto)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we can make it plain_type_t<Mat> and that should be fine

// cut & paste simplex_column_constrain(Eigen::Matrix, T) w/o Jacobian
auto&& y_ref = to_ref(y);
const Eigen::Index M = y_ref.cols();
plain_type_t<Mat> ret(y_ref.rows() + 1, M);
for (Eigen::Index i = 0; i < M; ++i) {
ret.col(i) = simplex_constrain(y_ref.col(i));
}
return ret;
}

/**
* Return a matrix with columns as simplex vectors
* and increment the specified log probability reference with
* the log absolute Jacobian determinant of the transform.
*
* The simplex transform is defined through a centered
* stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M).
* @param lp Log probability reference to increment.
* @return Matrix with simplex columns of dimensionality (K, M).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline auto simplex_column_constrain(const Mat& y, value_type_t<Mat>& lp) {
auto&& y_ref = to_ref(y);
const Eigen::Index M = y_ref.cols();
plain_type_t<Mat> ret(y_ref.rows() + 1, M);
for (Eigen::Index i = 0; i < M; ++i) {
ret.col(i) = simplex_constrain(y_ref.col(i), lp);
}
return ret;
}

/**
* Return a matrix with columns as simplex vectors. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (K - 1, M).
* @param[in, out] lp log density accumulator
* @return Matrix with simplex columns of dimensionality (K, M).
*/
template <bool Jacobian, typename Mat, require_not_std_vector_t<Mat>* = nullptr>
auto simplex_column_constrain(const Mat& y, return_type_t<Mat>& lp) {
if (Jacobian) {
return simplex_column_constrain(y, lp);
} else {
return simplex_column_constrain(y);
}
}

/**
* Return a standard vector of matrices with columns as simplex vectors. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam T A standard vector with inner type inheriting from
* `Eigen::DenseBase` or a `var_value` with inner type inheriting from
* `Eigen::DenseBase` with compile time dynamic rows and dynamic columns
* @param[in] y free vector
* @param[in, out] lp log density accumulator
* @return Standard vector containing matrices with simplex columns of
* dimensionality (K, M).
*/
template <bool Jacobian, typename T, require_std_vector_t<T>* = nullptr>
inline auto simplex_column_constrain(const T& y, return_type_t<T>& lp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_column_constrain(const T& y, return_type_t<T>& lp) {
inline plain_type_t<T> simplex_column_constrain(const T& y, return_type_t<T>& lp) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one we want to keep auto because it can return a vector with any amount of vectors of matrices within it.

return apply_vector_unary<T>::apply(
y, [&lp](auto&& v) { return simplex_column_constrain<Jacobian>(v, lp); });
}

} // namespace math
} // namespace stan

#endif
4 changes: 2 additions & 2 deletions stan/math/prim/fun/simplex_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace math {
* @param y Free vector input of dimensionality K - 1.
* @return Simplex of dimensionality K.
*/
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline auto simplex_constrain(const Vec& y) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_constrain(const Vec& y) {
inline plain_type_t<Vec> simplex_constrain(const Vec& y) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed all of these to plain_type_t

// cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
Expand Down Expand Up @@ -56,7 +56,7 @@ inline auto simplex_constrain(const Vec& y) {
* @param lp Log probability reference to increment.
* @return Simplex of dimensionality K.
*/
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline auto simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {
inline plain_type_t<Vec> simplex_constrain(const Vec& y, value_type_t<Vec>& lp) {

using Eigen::Dynamic;
Expand Down
132 changes: 132 additions & 0 deletions stan/math/prim/fun/simplex_row_constrain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#ifndef STAN_MATH_PRIM_FUN_SIMPLEX_ROW_CONSTRAIN_HPP
#define STAN_MATH_PRIM_FUN_SIMPLEX_ROW_CONSTRAIN_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/logit.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
* Return the simplex corresponding to the specified free vector.
* A simplex is a vector containing values greater than or equal
* to 0 that sum to 1. A vector with (K-1) unconstrained values
* will produce a simplex of size K.
*
* The transform is based on a centered stick-breaking process.
*
* @tparam Mat type of the Matrix
* @param y Free Matrix input of dimensionality (N, K - 1).
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline auto simplex_row_constrain(const Mat& y) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here for being more explicit with types instead of auto.

auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
int Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
using eigen_arr = Eigen::Array<scalar_type_t<Mat>, -1, 1>;
syclik marked this conversation as resolved.
Show resolved Hide resolved
eigen_arr stick_len = eigen_arr::Constant(N, 1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
auto z_k = inv_logit(y_ref.array().col(k) - log(Km1 - k));
x.array().col(k) = stick_len * z_k;
stick_len -= x.array().col(k);
}
x.array().col(Km1) = stick_len;
return x;
}

/**
* Return a matrix with simplex rows corresponding to the specified free matrix
* and increment the specified log probability reference with
* the log absolute Jacobian determinant of the transform.
*
* The simplex transform is defined through a centered
* stick-breaking process.
*
* @tparam Mat type of the matrix
* @param y Free matrix input of dimensionality (N, K - 1).
* @param lp Log probability reference to increment.
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
require_not_st_var<Mat>* = nullptr>
inline auto simplex_row_constrain(const Mat& y, value_type_t<Mat>& lp) {
auto&& y_ref = to_ref(y);
const Eigen::Index N = y_ref.rows();
Eigen::Index Km1 = y_ref.cols();
plain_type_t<Mat> x(N, Km1 + 1);
Eigen::Array<scalar_type_t<Mat>, -1, 1> stick_len
= Eigen::Array<scalar_type_t<Mat>, -1, 1>::Constant(N, 1.0);
for (Eigen::Index k = 0; k < Km1; ++k) {
auto eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
auto adj_y_k = (y_ref.array().col(k) + eq_share).eval();
auto z_k = inv_logit(adj_y_k);
x.array().col(k) = stick_len * z_k;
lp += sum(log(stick_len));
lp -= sum(log1p_exp(-adj_y_k));
lp -= sum(log1p_exp(adj_y_k));
stick_len -= x.array().col(k); // equivalently *= (1 - z_k);
}
x.col(Km1).array() = stick_len;
return x;
}

/**
* Return a matrix with simplex rows corresponding to the specified free matrix.
* If the `Jacobian` parameter is `true`, the log density accumulator is
* incremented with the log absolute Jacobian determinant of the transform. All
* of the transforms are specified with their Jacobians in the *Stan Reference
* Manual* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam Mat A type inheriting from `Eigen::DenseBase` or a `var_value` with
* inner type inheriting from `Eigen::DenseBase` with compile time dynamic rows
* and dynamic columns
* @param[in] y free matrix
* @param[in, out] lp log density accumulator
* @return Matrix with simplexes along the rows of dimensionality (N, K).
*/
template <bool Jacobian, typename Mat, require_not_std_vector_t<Mat>* = nullptr>
inline auto simplex_row_constrain(const Mat& y, return_type_t<Mat>& lp) {
if (Jacobian) {
return simplex_row_constrain(y, lp);
} else {
return simplex_row_constrain(y);
}
}

/**
* Return the simplex corresponding to the specified free vector. If the
* `Jacobian` parameter is `true`, the log density accumulator is incremented
* with the log absolute Jacobian determinant of the transform. All of the
* transforms are specified with their Jacobians in the *Stan Reference Manual*
* chapter Constraint Transforms.
*
* @tparam Jacobian if `true`, increment log density accumulator with log
* absolute Jacobian determinant of constraining transform
* @tparam T A standard vector with inner type inheriting from
* `Eigen::DenseBase` or a `var_value` with inner type inheriting from
* `Eigen::DenseBase` with compile time dynamic rows and dynamic columns
* @param[in] y free vector with matrices of size (N, K - 1)
* @param[in, out] lp log density accumulator
* @return vector of matrices with simplex rows of dimensionality (N, K)
*/
template <bool Jacobian, typename T, require_std_vector_t<T>* = nullptr>
inline auto simplex_row_constrain(const T& y, return_type_t<T>& lp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inline auto simplex_row_constrain(const T& y, return_type_t<T>& lp) {
inline plain_type_t<T> simplex_row_constrain(const T& y, return_type_t<T>& lp) {

Copy link
Collaborator Author

@SteveBronder SteveBronder Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above for std vectors of std vectors of matrices. idt plain_type_t here gives any more information

return apply_vector_unary<T>::apply(
y, [&lp](auto&& v) { return simplex_row_constrain<Jacobian>(v, lp); });
}

} // namespace math
} // namespace stan

#endif
15 changes: 14 additions & 1 deletion stan/math/rev/core/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,23 @@ class var_value<T, internal::require_matrix_var_value<T>> {
* @param other the value to assign
* @return this
*/
template <typename S, require_assignable_t<value_type, S>* = nullptr,
template <typename S,
syclik marked this conversation as resolved.
Show resolved Hide resolved
require_assignable_t<vari_type,
typename var_value<S>::vari_type>* = nullptr,
require_all_plain_type_t<T, S>* = nullptr>
var_value(const var_value<S>& other) : vi_(other.vi_) {}

template <typename S,
require_not_assignable_t<
vari_type, typename var_value<S>::vari_type>* = nullptr,
require_constructible_t<vari_type, S>* = nullptr,
require_all_plain_type_t<T, S>* = nullptr>
var_value(const var_value<S>& other) : vi_(new vari_type(other.vi_->val_)) {
reverse_pass_callback([this_vi = this->vi_, other_vi = other.vi_]() {
other_vi->adj_ += this_vi->adj_;
});
}

/**
* Construct a `var_value` with a plain type
* from another `var_value` containing an expression.
Expand Down
2 changes: 2 additions & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@
#include <stan/math/rev/fun/rows_dot_product.hpp>
#include <stan/math/rev/fun/rows_dot_self.hpp>
#include <stan/math/rev/fun/sd.hpp>
#include <stan/math/rev/fun/simplex_column_constrain.hpp>
#include <stan/math/rev/fun/simplex_constrain.hpp>
#include <stan/math/rev/fun/simplex_row_constrain.hpp>
#include <stan/math/rev/fun/sin.hpp>
#include <stan/math/rev/fun/singular_values.hpp>
#include <stan/math/rev/fun/svd.hpp>
Expand Down
Loading