Skip to content

Commit

Permalink
add derivatives for mvn
Browse files Browse the repository at this point in the history
  • Loading branch information
spinkney committed Dec 1, 2023
1 parent e43fc08 commit b6d0cf0
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 23 deletions.
208 changes: 185 additions & 23 deletions stan/math/prim/prob/multi_normal_lpdf.hpp
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
#ifndef STAN_MATH_PRIM_PROB_MULTI_NORMAL_LPDF_HPP
#define STAN_MATH_PRIM_PROB_MULTI_NORMAL_LPDF_HPP

#include <ostream>

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/dot_product.hpp>
#include <stan/math/prim/fun/identity_matrix.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log_determinant_ldlt.hpp>
#include <stan/math/prim/fun/max_size_mvt.hpp>
#include <stan/math/prim/fun/mdivide_left_ldlt.hpp>
#include <stan/math/prim/fun/mdivide_right.hpp>
#include <stan/math/prim/fun/mdivide_left_tri.hpp>
#include <stan/math/prim/fun/mdivide_right_tri.hpp>
#include <stan/math/prim/fun/size_mvt.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/trace_inv_quad_form_ldlt.hpp>
#include <stan/math/prim/fun/transpose.hpp>
#include <stan/math/prim/fun/vector_seq_view.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>

namespace stan {
namespace math {

template <bool propto, typename T_y, typename T_loc, typename T_covar>
template <bool propto, typename T_y, typename T_loc, typename T_covar,
require_any_not_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T_y, T_loc, T_covar>* = nullptr>
return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
const T_loc& mu,
const T_covar& Sigma) {
using T_covar_elem = typename scalar_type<T_covar>::type;
using lp_type = return_type_t<T_y, T_loc, T_covar>;
using Eigen::Dynamic;
using T_return = return_type_t<T_y, T_loc, T_covar>;
using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
using matrix_partials_t
= Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
using vector_partials_t = Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
using T_y_ref = ref_type_t<T_y>;
using T_mu_ref = ref_type_t<T_loc>;
using T_Sigma_ref = ref_type_t<T_covar>;

static const char* function = "multi_normal_lpdf";
check_positive(function, "Covariance matrix rows", Sigma.rows());

Expand All @@ -32,15 +54,19 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
return 0.0;
}

lp_type lp(0.0);
vector_seq_view<T_y> y_vec(y);
vector_seq_view<T_loc> mu_vec(mu);
size_t size_vec = max_size_mvt(y, mu);
T_y_ref y_ref = y;
T_mu_ref mu_ref = mu;
T_Sigma_ref Sigma_ref = Sigma;
vector_seq_view<T_y_ref> y_vec(y_ref);
vector_seq_view<T_mu_ref> mu_vec(mu_ref);
const size_t size_vec = max_size_mvt(y, mu);
const unsigned int K = Sigma.rows();

int size_y = y_vec[0].size();
int size_mu = mu_vec[0].size();
if (size_vec > 1) {
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {

// check size consistency of all random variables y
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
check_size_match(function,
"Size of one of the vectors of "
"the random variable",
Expand All @@ -49,6 +75,7 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
"random variable",
size_y);
}
// check size consistency of all means mu
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
check_size_match(function,
"Size of one of the vectors of "
Expand All @@ -58,7 +85,6 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
"location variable",
size_mu);
}
}

check_size_match(function, "Size of random variable", size_y,
"size of location parameter", size_mu);
Expand All @@ -71,35 +97,171 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
check_finite(function, "Location parameter", mu_vec[i]);
check_not_nan(function, "Random variable", y_vec[i]);
}
const auto& Sigma_ref = to_ref(Sigma);
check_symmetric(function, "Covariance matrix", Sigma_ref);

auto ldlt_Sigma = make_ldlt_factor(Sigma_ref);
check_ldlt_factor(function, "LDLT_Factor of covariance parameter",
ldlt_Sigma);

if (size_y == 0) {
return lp;
if (unlikely(size_y == 0)) {
return T_return(0);
}

auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref);

T_partials_return logp(0);

if (include_summand<propto>::value) {
lp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec;
logp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec;
}

if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>
y_val_minus_mu_val(size_y, size_vec);
T_return sum_lp_vec(0.0);
for (size_t i = 0; i < size_vec; i++) {
decltype(auto) y_val = as_value_column_vector_or_scalar(y_vec[i]);
decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_vec[i]);
y_val_minus_mu_val.col(i) = y_val - mu_val;
}

matrix_partials_t half;
vector_partials_t D_inv = 1.0 / value_of(ldlt_Sigma.ldlt().vectorD().array());

// If the covariance is not autodiff, we can avoid computing a matrix
// inverse
if (is_constant<T_covar_elem>::value) {
half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val);

if (include_summand<propto>::value) {
logp += 0.5 * sum(log(D_inv)) * size_vec;
}
} else {
matrix_partials_t inv_Sigma
= mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K,K));

half = (inv_Sigma * y_val_minus_mu_val);

logp += 0.5 * sum(log(D_inv)) * size_vec;

partials<2>(ops_partials) += -0.5 * size_vec * inv_Sigma;

for (size_t i = 0; i < size_vec; i++) {
partials_vec<2>(ops_partials)[i] += 0.5 * half.col(i) * half.row(i);
}
}

logp += -0.5 * y_val_minus_mu_val.cwiseProduct(half).sum();

for (size_t i = 0; i < size_vec; i++) {
if (!is_constant_all<T_y>::value) {
partials_vec<0>(ops_partials)[i] -= half.col(i);
}
if (!is_constant_all<T_loc>::value) {
partials_vec<1>(ops_partials)[i] += half.col(i);
}
}
}

return ops_partials.build(logp);
}
/**
 * Log density of a multivariate normal distribution for a single variate
 * `y`, a location vector `mu` and a covariance matrix `Sigma`.
 *
 * This overload participates when both `y` and `mu` are vectors of Stan
 * scalars (see the `require_all_vector_vt` constraint). The value and the
 * partial derivatives with respect to `y`, `mu` and `Sigma` are accumulated
 * by hand in a partials propagator instead of relying on autodiff of the
 * intermediate linear algebra.
 *
 * With `half = Sigma^{-1} (y - mu)` the accumulated partials are
 *   d/dy     = -half
 *   d/dmu    = +half
 *   d/dSigma = 0.5 * (half * half^T - Sigma^{-1})
 *
 * @tparam propto if true, summands that do not depend on autodiff arguments
 *   are dropped (decided through `include_summand`)
 * @tparam T_y type of the variate vector
 * @tparam T_loc type of the location vector
 * @tparam T_covar type of the covariance matrix
 * @param y variate
 * @param mu location; validated with `check_finite`
 * @param Sigma covariance matrix; validated with `check_positive` (rows),
 *   `check_symmetric` and `check_ldlt_factor`
 * @return the log density, carrying the partials listed above
 */
template <bool propto, typename T_y, typename T_loc, typename T_covar,
          require_all_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
          require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
              T_y, T_loc, T_covar>* = nullptr>
return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
                                                     const T_loc& mu,
                                                     const T_covar& Sigma) {
  using T_covar_elem = typename scalar_type<T_covar>::type;
  using T_return = return_type_t<T_y, T_loc, T_covar>;
  using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
  using matrix_partials_t
      = Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
  using vector_partials_t = Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
  using T_y_ref = ref_type_t<T_y>;
  using T_mu_ref = ref_type_t<T_loc>;
  using T_Sigma_ref = ref_type_t<T_covar>;

  static const char* function = "multi_normal_lpdf";
  check_positive(function, "Covariance matrix rows", Sigma.rows());

  T_y_ref y_ref = y;
  T_mu_ref mu_ref = mu;
  T_Sigma_ref Sigma_ref = Sigma;

  // Value-only views of the inputs, laid out as column vectors.
  decltype(auto) y_val = as_value_column_vector_or_scalar(y_ref);
  decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_ref);

  const int size_y = y_ref.size();
  const int size_mu = mu_ref.size();
  const unsigned int K = Sigma.rows();

  check_finite(function, "Location parameter", mu_val);
  check_not_nan(function, "Random variable", y_val);

  check_size_match(function, "Size of random variable", size_y,
                   "size of location parameter", size_mu);
  check_size_match(function, "Size of random variable", size_y,
                   "rows of covariance parameter", Sigma.rows());
  check_size_match(function, "Size of random variable", size_y,
                   "columns of covariance parameter", Sigma.cols());

  check_symmetric(function, "Covariance matrix", Sigma_ref);

  auto ldlt_Sigma = make_ldlt_factor(Sigma_ref);
  check_ldlt_factor(function, "LDLT_Factor of covariance parameter",
                    ldlt_Sigma);

  if (unlikely(size_y == 0)) {
    return T_return(0);
  }

  auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref);

  T_partials_return logp(0);

  // Normalizing constant: -0.5 * size_y * log(2 * pi).
  if (include_summand<propto>::value) {
    logp += NEG_LOG_SQRT_TWO_PI * size_y;
  }

  if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
    vector_partials_t half;
    vector_partials_t y_val_minus_mu_val = y_val - mu_val;
    // Reciprocal of the LDLT diagonal: sum(log(D_inv)) == -log|Sigma|.
    vector_partials_t D_inv
        = 1.0 / value_of(ldlt_Sigma.ldlt().vectorD().array());

    // If the covariance is not autodiff, we can avoid computing a matrix
    // inverse
    if (is_constant<T_covar_elem>::value) {
      half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val);

      // With a constant Sigma the log-determinant is itself a constant and
      // is only kept when constants are requested.
      if (include_summand<propto>::value) {
        logp += 0.5 * sum(log(D_inv));
      }
    } else {
      // The explicit inverse is needed for the Sigma partials below.
      matrix_partials_t inv_Sigma
          = mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K, K));
      half = inv_Sigma * y_val_minus_mu_val;

      logp += 0.5 * sum(log(D_inv));
      partials<2>(ops_partials)
          += 0.5 * (half * half.transpose() - inv_Sigma);
    }

    // Quadratic form: -0.5 * (y - mu)^T Sigma^{-1} (y - mu).
    logp += -0.5 * dot_product(y_val_minus_mu_val, half);

    // NOTE(review): `half` is a column vector; confirm the partials of a
    // row-vector `y`/`mu` accept it without an explicit transpose.
    if (!is_constant_all<T_y>::value) {
      partials<0>(ops_partials) -= half;
    }
    if (!is_constant_all<T_loc>::value) {
      partials<1>(ops_partials) += half;
    }
  }

  return ops_partials.build(logp);
}

template <typename T_y, typename T_loc, typename T_covar>
Expand Down
60 changes: 60 additions & 0 deletions test/unit/math/rev/prob/multi_normal_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include <stan/math/rev.hpp>
#include <test/unit/math/rev/util.hpp>
#include <gtest/gtest.h>

// Evaluates multi_normal_lpdf with every argument autodiff-enabled and
// checks both the log density and the adjoints of y and mu after a reverse
// sweep.
TEST(ProbDistributionsMultiNormal, MultiNormalVar) {
  using Eigen::Dynamic;
  using Eigen::Matrix;
  using stan::math::var;
  Matrix<var, Dynamic, 1> y(3, 1);
  y << 2.0, -2.0, 11.0;
  Matrix<var, Dynamic, 1> mu(3, 1);
  mu << 1.0, -1.0, 3.0;
  Matrix<var, Dynamic, Dynamic> Sigma(3, 3);
  Sigma << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0;

  var lp = stan::math::multi_normal_lpdf(y, mu, Sigma);
  EXPECT_FLOAT_EQ(-11.73908, lp.val());

  // Sigma is block diagonal: the leading 2x2 block has determinant 27 and
  // inverse [[4, 3], [3, 9]] / 27, the trailing entry is 5. With
  // y - mu = (1, -1, 8) this gives Sigma^{-1} (y - mu) = (1/27, -6/27, 8/5).
  // The gradient of the log density is -Sigma^{-1} (y - mu) with respect to
  // y and +Sigma^{-1} (y - mu) with respect to mu.
  const double sigma_inv_diff[3] = {1.0 / 27.0, -6.0 / 27.0, 8.0 / 5.0};

  lp.grad();
  for (int i = 0; i < 3; ++i) {
    EXPECT_FLOAT_EQ(-sigma_inv_diff[i], y(i).adj());
    EXPECT_FLOAT_EQ(sigma_inv_diff[i], mu(i).adj());
  }

  // Release the autodiff arena used by this test.
  stan::math::recover_memory();
}

// Runs multi_normal_lpdf over every combination of double / var arguments
// that contains at least one var, for both values of the propto flag, and
// hands each result to test::check_varis_on_stack.
TEST(ProbDistributionsMultiNormal, check_varis_on_stack) {
  using stan::math::to_var;

  Eigen::VectorXd y_dbl(3);
  y_dbl << 2.0, -2.0, 11.0;
  Eigen::VectorXd mu_dbl(3);
  mu_dbl << 1.0, -1.0, 3.0;
  Eigen::MatrixXd Sigma_dbl(3, 3);
  Sigma_dbl << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0;

  // propto = true
  {
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
        to_var(y_dbl), to_var(mu_dbl), to_var(Sigma_dbl)));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
        to_var(y_dbl), to_var(mu_dbl), Sigma_dbl));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
        to_var(y_dbl), mu_dbl, to_var(Sigma_dbl)));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
        to_var(y_dbl), mu_dbl, Sigma_dbl));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
        y_dbl, to_var(mu_dbl), to_var(Sigma_dbl)));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
        y_dbl, to_var(mu_dbl), Sigma_dbl));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
        y_dbl, mu_dbl, to_var(Sigma_dbl)));
  }

  // propto = false
  {
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
        to_var(y_dbl), to_var(mu_dbl), to_var(Sigma_dbl)));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
        to_var(y_dbl), to_var(mu_dbl), Sigma_dbl));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
        to_var(y_dbl), mu_dbl, to_var(Sigma_dbl)));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
        to_var(y_dbl), mu_dbl, Sigma_dbl));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
        y_dbl, to_var(mu_dbl), to_var(Sigma_dbl)));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
        y_dbl, to_var(mu_dbl), Sigma_dbl));
    test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
        y_dbl, mu_dbl, to_var(Sigma_dbl)));
  }
}

0 comments on commit b6d0cf0

Please sign in to comment.