Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add derivatives for mvn #2980

Merged
merged 27 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b6d0cf0
add derivatives for mvn
spinkney Dec 1, 2023
bdc7e52
Merge commit 'f627912fecfbbb57bbf80a39e1987a499e05f589' into HEAD
yashikno Dec 1, 2023
208658a
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 1, 2023
1505d55
fix test and remove some includes not used
spinkney Dec 1, 2023
50335c4
Merge branch 'multi-normal-derivatives-2' of https://github.com/stan-…
spinkney Dec 1, 2023
69ccbcd
add value_of as per reviewer
spinkney Dec 1, 2023
b733c7f
update both
spinkney Dec 1, 2023
8afeba0
remove unnecessary declaration
spinkney Dec 1, 2023
9d1c773
fix multiplication
spinkney Dec 1, 2023
204cce8
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 1, 2023
cc398d0
Update multi_normal_lpdf.hpp
spinkney Dec 1, 2023
75d8a9e
Merge commit '754e94e31d992721829da631a33fe34d4af6b0d8' into HEAD
yashikno Dec 1, 2023
a6ec92c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 1, 2023
e8a0b79
cleanup
spinkney Dec 2, 2023
ba13cb7
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 2, 2023
057c18b
fix header
spinkney Dec 2, 2023
a2360de
Merge branch 'multi-normal-derivatives-2' of https://github.com/stan-…
spinkney Dec 2, 2023
27584b0
more cleanup
spinkney Dec 3, 2023
53ba6d8
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 3, 2023
5530904
fix CI test
spinkney Dec 4, 2023
636a94e
use vector instead of matrix
spinkney Dec 4, 2023
bc48cbb
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 4, 2023
f21fe28
Update multi_normal_lpdf.hpp
spinkney Dec 5, 2023
53018f5
final update
spinkney Dec 6, 2023
bac7e60
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Dec 6, 2023
479c04b
Merge branch 'develop' into multi-normal-derivatives-2
spinkney Dec 15, 2023
52b36e6
Merge branch 'develop' into multi-normal-derivatives-2
spinkney Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 190 additions & 41 deletions stan/math/prim/prob/multi_normal_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,39 @@
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/dot_product.hpp>
#include <stan/math/prim/fun/eval.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log_determinant_ldlt.hpp>
#include <stan/math/prim/fun/max_size_mvt.hpp>
#include <stan/math/prim/fun/mdivide_left_ldlt.hpp>
#include <stan/math/prim/fun/size_mvt.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/trace_inv_quad_form_ldlt.hpp>
#include <stan/math/prim/fun/transpose.hpp>
#include <stan/math/prim/fun/vector_seq_view.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>

namespace stan {
namespace math {

template <bool propto, typename T_y, typename T_loc, typename T_covar>
template <bool propto, typename T_y, typename T_loc, typename T_covar,
require_any_not_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T_y, T_loc, T_covar>* = nullptr>
return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
const T_loc& mu,
const T_covar& Sigma) {
using T_covar_elem = typename scalar_type<T_covar>::type;
using lp_type = return_type_t<T_y, T_loc, T_covar>;
using Eigen::Dynamic;
using T_return = return_type_t<T_y, T_loc, T_covar>;
using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
using matrix_partials_t
= Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
using vector_partials_t = Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
using T_y_ref = ref_type_t<T_y>;
using T_mu_ref = ref_type_t<T_loc>;
using T_Sigma_ref = ref_type_t<T_covar>;

static const char* function = "multi_normal_lpdf";
check_positive(function, "Covariance matrix rows", Sigma.rows());

Expand All @@ -32,32 +48,36 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
return 0.0;
}

lp_type lp(0.0);
vector_seq_view<T_y> y_vec(y);
vector_seq_view<T_loc> mu_vec(mu);
size_t size_vec = max_size_mvt(y, mu);
T_y_ref y_ref = y;
T_mu_ref mu_ref = mu;
T_Sigma_ref Sigma_ref = Sigma;
vector_seq_view<T_y_ref> y_vec(y_ref);
vector_seq_view<T_mu_ref> mu_vec(mu_ref);
const size_t size_vec = max_size_mvt(y, mu);
const int K = Sigma.rows();

int size_y = y_vec[0].size();
int size_mu = mu_vec[0].size();
if (size_vec > 1) {
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
check_size_match(function,
"Size of one of the vectors of "
"the random variable",
y_vec[i].size(),
"Size of the first vector of the "
"random variable",
size_y);
}
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
check_size_match(function,
"Size of one of the vectors of "
"the location variable",
mu_vec[i].size(),
"Size of the first vector of the "
"location variable",
size_mu);
}

// check size consistency of all random variables y
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
check_size_match(function,
"Size of one of the vectors of "
"the random variable",
y_vec[i].size(),
"Size of the first vector of the "
"random variable",
size_y);
}
// check size consistency of all means mu
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
check_size_match(function,
"Size of one of the vectors of "
"the location variable",
mu_vec[i].size(),
"Size of the first vector of the "
"location variable",
size_mu);
}

check_size_match(function, "Size of random variable", size_y,
Expand All @@ -71,35 +91,164 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
check_finite(function, "Location parameter", mu_vec[i]);
check_not_nan(function, "Random variable", y_vec[i]);
}
const auto& Sigma_ref = to_ref(Sigma);
check_symmetric(function, "Covariance matrix", Sigma_ref);

auto ldlt_Sigma = make_ldlt_factor(Sigma_ref);
auto ldlt_Sigma = make_ldlt_factor(value_of(Sigma_ref));

check_ldlt_factor(function, "LDLT_Factor of covariance parameter",
ldlt_Sigma);

if (size_y == 0) {
return lp;
if (unlikely(size_y == 0)) {
return T_return(0);
}

auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref);

T_partials_return logp(0);

if (include_summand<propto>::value) {
lp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec;
logp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec;
}

if (include_summand<propto, T_covar_elem>::value) {
lp -= 0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec;
if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
vector_partials_t half(size_vec);
vector_partials_t y_val_minus_mu_val(size_vec);

T_partials_return sum_lp_vec(0.0);
for (size_t i = 0; i < size_vec; i++) {
const auto& y_val = as_value_column_vector_or_scalar(y_vec[i]);
const auto& mu_val = as_value_column_vector_or_scalar(mu_vec[i]);
y_val_minus_mu_val = eval(y_val - mu_val);
half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val);

sum_lp_vec += dot_product(y_val_minus_mu_val, half);

if (!is_constant_all<T_y>::value) {
partials_vec<0>(ops_partials)[i] += -half;
}
if (!is_constant_all<T_loc>::value) {
partials_vec<1>(ops_partials)[i] += half;
}
if (!is_constant<T_covar_elem>::value) {
partials_vec<2>(ops_partials)[i] += 0.5 * half * half.transpose();
}
}

logp += -0.5 * sum_lp_vec;

// If the covariance is not autodiff, we can avoid computing a matrix
// inverse
if (is_constant<T_covar_elem>::value) {
if (include_summand<propto>::value) {
logp += -0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec;
}
} else {
matrix_partials_t inv_Sigma
= mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K, K));

logp += -0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec;

partials<2>(ops_partials) += -0.5 * size_vec * inv_Sigma;
}
}

return ops_partials.build(logp);
}

template <bool propto, typename T_y, typename T_loc, typename T_covar,
require_all_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T_y, T_loc, T_covar>* = nullptr>
return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
const T_loc& mu,
const T_covar& Sigma) {
using T_covar_elem = typename scalar_type<T_covar>::type;
using T_return = return_type_t<T_y, T_loc, T_covar>;
using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
using matrix_partials_t
= Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
using vector_partials_t = Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
using T_y_ref = ref_type_t<T_y>;
using T_mu_ref = ref_type_t<T_loc>;
using T_Sigma_ref = ref_type_t<T_covar>;

static const char* function = "multi_normal_lpdf";
check_positive(function, "Covariance matrix rows", Sigma.rows());

T_y_ref y_ref = y;
T_mu_ref mu_ref = mu;
T_Sigma_ref Sigma_ref = Sigma;

decltype(auto) y_val = as_value_column_vector_or_scalar(y_ref);
decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_ref);

const int size_y = y_ref.size();
const int size_mu = mu_ref.size();
const unsigned int K = Sigma.rows();

check_finite(function, "Location parameter", mu_val);
check_not_nan(function, "Random variable", y_val);

check_size_match(function, "Size of random variable", size_y,
"size of location parameter", size_mu);
check_size_match(function, "Size of random variable", size_y,
"rows of covariance parameter", Sigma.rows());
check_size_match(function, "Size of random variable", size_y,
"columns of covariance parameter", Sigma.cols());

check_symmetric(function, "Covariance matrix", Sigma_ref);

auto ldlt_Sigma = make_ldlt_factor(value_of(Sigma_ref));
check_ldlt_factor(function, "LDLT_Factor of covariance parameter",
ldlt_Sigma);

if (unlikely(size_y == 0)) {
return T_return(0);
}

auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref);

T_partials_return logp(0);

if (include_summand<propto>::value) {
logp += NEG_LOG_SQRT_TWO_PI * size_y;
}

if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
lp_type sum_lp_vec(0.0);
for (size_t i = 0; i < size_vec; i++) {
const auto& y_col = as_column_vector_or_scalar(y_vec[i]);
const auto& mu_col = as_column_vector_or_scalar(mu_vec[i]);
sum_lp_vec += trace_inv_quad_form_ldlt(ldlt_Sigma, y_col - mu_col);
vector_partials_t half(size_y);
vector_partials_t y_val_minus_mu_val = eval(y_val - mu_val);

// If the covariance is not autodiff, we can avoid computing a matrix
// inverse
if (is_constant<T_covar_elem>::value) {
half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val);

if (include_summand<propto>::value) {
logp += -0.5 * log_determinant_ldlt(ldlt_Sigma);
}
} else {
matrix_partials_t inv_Sigma
= mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K, K));

half.noalias() = inv_Sigma * y_val_minus_mu_val;

logp += -0.5 * log_determinant_ldlt(ldlt_Sigma);

edge<2>(ops_partials).partials_
+= 0.5 * (half * half.transpose() - inv_Sigma);
}

logp += -0.5 * dot_product(y_val_minus_mu_val, half);

if (!is_constant_all<T_y>::value) {
partials<0>(ops_partials) += -half;
}
if (!is_constant_all<T_loc>::value) {
partials<1>(ops_partials) += half;
}
lp -= 0.5 * sum_lp_vec;
}
return lp;

return ops_partials.build(logp);
}

template <typename T_y, typename T_loc, typename T_covar>
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/rev/prob/multi_normal2_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ TEST_F(agrad_distributions_multi_normal_multi_row, ProptoSigma) {
stan::math::recover_memory();
}

TEST(ProbDistributionsMultiNormal, MultiNormalVar) {
TEST(ProbDistributionsMultiNormal, MultiNormalVar2) {
using Eigen::Dynamic;
using Eigen::Matrix;
using stan::math::var;
Expand Down
59 changes: 59 additions & 0 deletions test/unit/math/rev/prob/multi_normal_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include <stan/math/rev.hpp>
#include <test/unit/math/rev/util.hpp>
#include <gtest/gtest.h>

TEST(ProbDistributionsMultiNormal, MultiNormalVar) {
using Eigen::Dynamic;
using Eigen::Matrix;
using stan::math::var;
using std::vector;
Matrix<var, Dynamic, 1> y(3, 1);
y << 2.0, -2.0, 11.0;
Matrix<var, Dynamic, 1> mu(3, 1);
mu << 1.0, -1.0, 3.0;
Matrix<var, Dynamic, Dynamic> Sigma(3, 3);
Sigma << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0;
EXPECT_FLOAT_EQ(-11.73908, stan::math::multi_normal_lpdf(y, mu, Sigma).val());
}

TEST(ProbDistributionsMultiNormal, check_varis_on_stack) {
using Eigen::Dynamic;
using Eigen::Matrix;
using stan::math::to_var;
using std::vector;
Matrix<double, Dynamic, 1> y(3, 1);
y << 2.0, -2.0, 11.0;
Matrix<double, Dynamic, 1> mu(3, 1);
mu << 1.0, -1.0, 3.0;
Matrix<double, Dynamic, Dynamic> Sigma(3, 3);
Sigma << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0;
test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
to_var(y), to_var(mu), to_var(Sigma)));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<true>(to_var(y), to_var(mu), Sigma));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<true>(to_var(y), mu, to_var(Sigma)));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<true>(to_var(y), mu, Sigma));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<true>(y, to_var(mu), to_var(Sigma)));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<true>(y, to_var(mu), Sigma));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<true>(y, mu, to_var(Sigma)));

test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
to_var(y), to_var(mu), to_var(Sigma)));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<false>(to_var(y), to_var(mu), Sigma));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<false>(to_var(y), mu, to_var(Sigma)));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<false>(to_var(y), mu, Sigma));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<false>(y, to_var(mu), to_var(Sigma)));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<false>(y, to_var(mu), Sigma));
test::check_varis_on_stack(
stan::math::multi_normal_lpdf<false>(y, mu, to_var(Sigma)));
}