diff --git a/stan/math/prim/prob/multi_normal_lpdf.hpp b/stan/math/prim/prob/multi_normal_lpdf.hpp index abad728f116..0a68015e0dc 100644 --- a/stan/math/prim/prob/multi_normal_lpdf.hpp +++ b/stan/math/prim/prob/multi_normal_lpdf.hpp @@ -5,24 +5,40 @@ #include #include #include +#include +#include +#include #include #include +#include #include +#include #include -#include +#include #include +#include namespace stan { namespace math { -template +template * = nullptr, + require_all_not_nonscalar_prim_or_rev_kernel_expression_t< + T_y, T_loc, T_covar>* = nullptr> return_type_t multi_normal_lpdf(const T_y& y, const T_loc& mu, const T_covar& Sigma) { using T_covar_elem = typename scalar_type::type; - using lp_type = return_type_t; - using Eigen::Dynamic; - static constexpr const char* function = "multi_normal_lpdf"; + using T_return = return_type_t; + using T_partials_return = partials_return_t; + using matrix_partials_t + = Eigen::Matrix; + using vector_partials_t = Eigen::Matrix; + using T_y_ref = ref_type_t; + using T_mu_ref = ref_type_t; + using T_Sigma_ref = ref_type_t; + + static const char* function = "multi_normal_lpdf"; check_positive(function, "Covariance matrix rows", Sigma.rows()); check_consistent_sizes_mvt(function, "y", y, "mu", mu); @@ -32,32 +48,36 @@ return_type_t multi_normal_lpdf(const T_y& y, return 0.0; } - lp_type lp(0.0); - vector_seq_view y_vec(y); - vector_seq_view mu_vec(mu); - size_t size_vec = max_size_mvt(y, mu); + T_y_ref y_ref = y; + T_mu_ref mu_ref = mu; + T_Sigma_ref Sigma_ref = Sigma; + vector_seq_view y_vec(y_ref); + vector_seq_view mu_vec(mu_ref); + const size_t size_vec = max_size_mvt(y, mu); + const int K = Sigma.rows(); int size_y = y_vec[0].size(); int size_mu = mu_vec[0].size(); - if (size_vec > 1) { - for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) { - check_size_match(function, - "Size of one of the vectors of " - "the random variable", - y_vec[i].size(), - "Size of the first vector of the " - "random variable", - size_y); - } - for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) { - check_size_match(function, - "Size of one of the vectors of " - "the location variable", - mu_vec[i].size(), - "Size of the first vector of the " - "location variable", - size_mu); - } + + // check size consistency of all random variables y + for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) { + check_size_match(function, + "Size of one of the vectors of " + "the random variable", + y_vec[i].size(), + "Size of the first vector of the " + "random variable", + size_y); + } + // check size consistency of all means mu + for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) { + check_size_match(function, + "Size of one of the vectors of " + "the location variable", + mu_vec[i].size(), + "Size of the first vector of the " + "location variable", + size_mu); } check_size_match(function, "Size of random variable", size_y, @@ -71,35 +91,164 @@ return_type_t multi_normal_lpdf(const T_y& y, check_finite(function, "Location parameter", mu_vec[i]); check_not_nan(function, "Random variable", y_vec[i]); } - const auto& Sigma_ref = to_ref(Sigma); check_symmetric(function, "Covariance matrix", Sigma_ref); - auto ldlt_Sigma = make_ldlt_factor(Sigma_ref); + auto ldlt_Sigma = make_ldlt_factor(value_of(Sigma_ref)); + check_ldlt_factor(function, "LDLT_Factor of covariance parameter", ldlt_Sigma); - if (size_y == 0) { - return lp; + if (unlikely(size_y == 0)) { + return T_return(0); } + auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref); + + T_partials_return logp(0); + if (include_summand::value) { - lp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec; + logp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec; + } + + if (include_summand::value) { + vector_partials_t half(size_vec); + vector_partials_t y_val_minus_mu_val(size_vec); + + T_partials_return sum_lp_vec(0.0); + for (size_t i = 0; i < size_vec; i++) { + const auto& y_val = as_value_column_vector_or_scalar(y_vec[i]); + const auto& mu_val = as_value_column_vector_or_scalar(mu_vec[i]); + y_val_minus_mu_val = eval(y_val - mu_val); + half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val); + + sum_lp_vec += dot_product(y_val_minus_mu_val, half); + + if (!is_constant_all::value) { + partials_vec<0>(ops_partials)[i] += -half; + } + if (!is_constant_all::value) { + partials_vec<1>(ops_partials)[i] += half; + } + if (!is_constant::value) { + partials_vec<2>(ops_partials)[i] += 0.5 * half * half.transpose(); + } + } + + logp += -0.5 * sum_lp_vec; + + // If the covariance is not autodiff, we can avoid computing a matrix + // inverse + if (is_constant::value) { + if (include_summand::value) { + logp += -0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec; + } + } else { + matrix_partials_t inv_Sigma + = mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K, K)); + + logp += -0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec; + + partials<2>(ops_partials) += -0.5 * size_vec * inv_Sigma; + } + } + + return ops_partials.build(logp); +} + +template * = nullptr, + require_all_not_nonscalar_prim_or_rev_kernel_expression_t< + T_y, T_loc, T_covar>* = nullptr> +return_type_t multi_normal_lpdf(const T_y& y, + const T_loc& mu, + const T_covar& Sigma) { + using T_covar_elem = typename scalar_type::type; + using T_return = return_type_t; + using T_partials_return = partials_return_t; + using matrix_partials_t + = Eigen::Matrix; + using vector_partials_t = Eigen::Matrix; + using T_y_ref = ref_type_t; + using T_mu_ref = ref_type_t; + using T_Sigma_ref = ref_type_t; + + static const char* function = "multi_normal_lpdf"; + check_positive(function, "Covariance matrix rows", Sigma.rows()); + + T_y_ref y_ref = y; + T_mu_ref mu_ref = mu; + T_Sigma_ref Sigma_ref = Sigma; + + decltype(auto) y_val = as_value_column_vector_or_scalar(y_ref); + decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_ref); + + const int size_y = y_ref.size(); + const int size_mu = mu_ref.size(); + const unsigned int K = Sigma.rows(); + + check_finite(function, "Location parameter", mu_val); + check_not_nan(function, "Random variable", y_val); + + check_size_match(function, "Size of random variable", size_y, + "size of location parameter", size_mu); + check_size_match(function, "Size of random variable", size_y, + "rows of covariance parameter", Sigma.rows()); + check_size_match(function, "Size of random variable", size_y, + "columns of covariance parameter", Sigma.cols()); + + check_symmetric(function, "Covariance matrix", Sigma_ref); + + auto ldlt_Sigma = make_ldlt_factor(value_of(Sigma_ref)); + check_ldlt_factor(function, "LDLT_Factor of covariance parameter", + ldlt_Sigma); + + if (unlikely(size_y == 0)) { + return T_return(0); } - if (include_summand::value) { - lp -= 0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec; + auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref); + + T_partials_return logp(0); + + if (include_summand::value) { + logp += NEG_LOG_SQRT_TWO_PI * size_y; } if (include_summand::value) { - lp_type sum_lp_vec(0.0); - for (size_t i = 0; i < size_vec; i++) { - const auto& y_col = as_column_vector_or_scalar(y_vec[i]); - const auto& mu_col = as_column_vector_or_scalar(mu_vec[i]); - sum_lp_vec += trace_inv_quad_form_ldlt(ldlt_Sigma, y_col - mu_col); + vector_partials_t half(size_y); + vector_partials_t y_val_minus_mu_val = eval(y_val - mu_val); + + // If the covariance is not autodiff, we can avoid computing a matrix + // inverse + if (is_constant::value) { + half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val); + + if (include_summand::value) { + logp += -0.5 * log_determinant_ldlt(ldlt_Sigma); + } + } else { + matrix_partials_t inv_Sigma + = mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K, K)); + + half.noalias() = inv_Sigma * y_val_minus_mu_val; + + logp += -0.5 * log_determinant_ldlt(ldlt_Sigma); + + edge<2>(ops_partials).partials_ + += 0.5 * (half * half.transpose() - inv_Sigma); + } + + logp += -0.5 * dot_product(y_val_minus_mu_val, half); + + if (!is_constant_all::value) { + partials<0>(ops_partials) += -half; + } + if (!is_constant_all::value) { + partials<1>(ops_partials) += half; } - lp -= 0.5 * sum_lp_vec; } - return lp; + + return ops_partials.build(logp); } template diff --git a/test/unit/math/rev/prob/multi_normal2_test.cpp b/test/unit/math/rev/prob/multi_normal2_test.cpp index 45f6355762a..16af3fa8ac6 100644 --- a/test/unit/math/rev/prob/multi_normal2_test.cpp +++ b/test/unit/math/rev/prob/multi_normal2_test.cpp @@ -120,7 +120,7 @@ TEST_F(agrad_distributions_multi_normal_multi_row, ProptoSigma) { stan::math::recover_memory(); } -TEST(ProbDistributionsMultiNormal, MultiNormalVar) { +TEST(ProbDistributionsMultiNormal, MultiNormalVar2) { using Eigen::Dynamic; using Eigen::Matrix; using stan::math::var; diff --git a/test/unit/math/rev/prob/multi_normal_test.cpp b/test/unit/math/rev/prob/multi_normal_test.cpp new file mode 100644 index 00000000000..3757c6aa33a --- /dev/null +++ b/test/unit/math/rev/prob/multi_normal_test.cpp @@ -0,0 +1,59 @@ +#include +#include +#include + +TEST(ProbDistributionsMultiNormal, MultiNormalVar) { + using Eigen::Dynamic; + using Eigen::Matrix; + using stan::math::var; + using std::vector; + Matrix y(3, 1); + y << 2.0, -2.0, 11.0; + Matrix mu(3, 1); + mu << 1.0, -1.0, 3.0; + Matrix Sigma(3, 3); + Sigma << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0; + EXPECT_FLOAT_EQ(-11.73908, stan::math::multi_normal_lpdf(y, mu, Sigma).val()); +} + +TEST(ProbDistributionsMultiNormal, check_varis_on_stack) { + using Eigen::Dynamic; + using Eigen::Matrix; + using stan::math::to_var; + using std::vector; + Matrix y(3, 1); + y << 2.0, -2.0, 11.0; + Matrix mu(3, 1); + mu << 1.0, -1.0, 3.0; + Matrix Sigma(3, 3); + Sigma << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0; + test::check_varis_on_stack(stan::math::multi_normal_lpdf( + to_var(y), to_var(mu), to_var(Sigma))); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(to_var(y), to_var(mu), Sigma)); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(to_var(y), mu, to_var(Sigma))); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(to_var(y), mu, Sigma)); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(y, to_var(mu), to_var(Sigma))); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(y, to_var(mu), Sigma)); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(y, mu, to_var(Sigma))); + + test::check_varis_on_stack(stan::math::multi_normal_lpdf( + to_var(y), to_var(mu), to_var(Sigma))); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(to_var(y), to_var(mu), Sigma)); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(to_var(y), mu, to_var(Sigma))); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(to_var(y), mu, Sigma)); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(y, to_var(mu), to_var(Sigma))); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(y, to_var(mu), Sigma)); + test::check_varis_on_stack( + stan::math::multi_normal_lpdf(y, mu, to_var(Sigma))); +}