-
-
Notifications
You must be signed in to change notification settings - Fork 365
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2575 from roualdes/feature/issue-2569-analysis-ap…
…i-ess Feature/2569 Analysis API: compute effective sample size
- Loading branch information
Showing
4 changed files
with
314 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#ifndef STAN_ANALYZE_MCMC_COMPUTE_EFFECTIVE_SAMPLE_SIZE_HPP | ||
#define STAN_ANALYZE_MCMC_COMPUTE_EFFECTIVE_SAMPLE_SIZE_HPP | ||
|
||
#include <stan/math/prim/mat.hpp> | ||
#include <algorithm> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace analyze { | ||
|
||
/** | ||
* Returns the effective sample size for the specified parameter | ||
* across all kept samples. | ||
* | ||
* See more details in Stan reference manual section "Effective | ||
* Sample Size". http://mc-stan.org/users/documentation | ||
* | ||
* Current implementation assumes chains are all of equal size and | ||
* draws are stored in contiguous blocks of memory. | ||
* | ||
* @param std::vector stores pointers to arrays of chains | ||
* @param std::vector stores sizes of chains | ||
* @return effective sample size for the specified parameter | ||
*/ | ||
double compute_effective_sample_size(std::vector<const double*> draws, | ||
std::vector<size_t> sizes) { | ||
int num_chains = sizes.size(); | ||
|
||
// need to generalize to each jagged draws per chain | ||
size_t num_draws = sizes[0]; | ||
for (int chain = 1; chain < num_chains; ++chain) { | ||
num_draws = std::min(num_draws, sizes[chain]); | ||
} | ||
|
||
Eigen::Matrix<Eigen::VectorXd, Eigen::Dynamic, 1> acov(num_chains); | ||
Eigen::VectorXd chain_mean(num_chains); | ||
Eigen::VectorXd chain_var(num_chains); | ||
for (int chain = 0; chain < num_chains; ++chain) { | ||
Eigen::Map<const Eigen::Matrix<double, Eigen::Dynamic, 1>> | ||
draw(draws[chain], sizes[chain]); | ||
math::autocovariance<double>(draw, acov(chain)); | ||
chain_mean(chain) = draw.mean(); | ||
chain_var(chain) = acov(chain)(0) * num_draws / (num_draws - 1); | ||
} | ||
|
||
double mean_var = chain_var.mean(); | ||
double var_plus = mean_var * (num_draws - 1) / num_draws; | ||
if (num_chains > 1) | ||
var_plus += math::variance(chain_mean); | ||
Eigen::VectorXd rho_hat_s(num_draws); | ||
rho_hat_s.setZero(); | ||
Eigen::VectorXd acov_s(num_chains); | ||
for (int chain = 0; chain < num_chains; ++chain) | ||
acov_s(chain) = acov(chain)(1); | ||
double rho_hat_even = 1; | ||
double rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; | ||
rho_hat_s(1) = rho_hat_odd; | ||
// Geyer's initial positive sequence | ||
int max_s = 1; | ||
for (size_t s = 1; | ||
(s < (num_draws - 2) && (rho_hat_even + rho_hat_odd) >= 0); | ||
s += 2) { | ||
for (int chain = 0; chain < num_chains; ++chain) | ||
acov_s(chain) = acov(chain)(s + 1); | ||
rho_hat_even = 1 - (mean_var - acov_s.mean()) / var_plus; | ||
for (int chain = 0; chain < num_chains; ++chain) | ||
acov_s(chain) = acov(chain)(s + 2); | ||
rho_hat_odd = 1 - (mean_var - acov_s.mean()) / var_plus; | ||
if ((rho_hat_even + rho_hat_odd) >= 0) { | ||
rho_hat_s(s + 1) = rho_hat_even; | ||
rho_hat_s(s + 2) = rho_hat_odd; | ||
} | ||
max_s = s + 2; | ||
} | ||
// Geyer's initial monotone sequence | ||
for (int s = 3; s <= max_s - 2; s += 2) { | ||
if (rho_hat_s(s + 1) + rho_hat_s(s + 2) > | ||
rho_hat_s(s - 1) + rho_hat_s(s)) { | ||
rho_hat_s(s + 1) = (rho_hat_s(s - 1) + rho_hat_s(s)) / 2; | ||
rho_hat_s(s + 2) = rho_hat_s(s + 1); | ||
} | ||
} | ||
|
||
return num_chains * num_draws / (1 + 2 * rho_hat_s.sum()); | ||
} | ||
|
||
double compute_effective_sample_size(std::vector<const double*> draws, | ||
size_t size) { | ||
int num_chains = draws.size(); | ||
std::vector<size_t> sizes(num_chains, size); | ||
return compute_effective_sample_size(draws, sizes); | ||
} | ||
} | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
141 changes: 141 additions & 0 deletions
141
src/test/unit/analyze/mcmc/compute_effective_sample_size_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
#include <stan/mcmc/chains.hpp> | ||
#include <stan/io/stan_csv_reader.hpp> | ||
#include <gtest/gtest.h> | ||
#include <fstream> | ||
#include <sstream> | ||
|
||
class ComputeEss : public testing::Test { | ||
public: | ||
void SetUp() { | ||
blocker1_stream.open("src/test/unit/mcmc/test_csv_files/blocker.1.csv"); | ||
blocker2_stream.open("src/test/unit/mcmc/test_csv_files/blocker.2.csv"); | ||
} | ||
|
||
void TearDown() { | ||
blocker1_stream.close(); | ||
blocker2_stream.close(); | ||
} | ||
std::ifstream blocker1_stream, blocker2_stream; | ||
}; | ||
|
||
TEST_F(ComputeEss,compute_effective_sample_size) { | ||
std::stringstream out; | ||
stan::io::stan_csv blocker1 = stan::io::stan_csv_reader::parse(blocker1_stream, &out); | ||
stan::io::stan_csv blocker2 = stan::io::stan_csv_reader::parse(blocker2_stream, &out); | ||
EXPECT_EQ("", out.str()); | ||
|
||
stan::mcmc::chains<> chains(blocker1); | ||
chains.add(blocker2); | ||
|
||
Eigen::VectorXd n_eff(48); | ||
n_eff << 466.099,136.953,1170.390,541.256, | ||
518.051,589.244,764.813,688.294, | ||
323.777,502.892,353.823,588.142, | ||
654.336,480.914,176.978,182.649, | ||
642.389,470.949,561.947,581.187, | ||
446.389,397.641,338.511,678.772, | ||
1442.250,837.956,869.865,951.124, | ||
619.336,875.805,233.260,786.568, | ||
910.144,231.582,907.666,747.347, | ||
720.660,195.195,944.547,767.271, | ||
723.665,1077.030,470.903,954.924, | ||
497.338,583.539,697.204,98.421; | ||
|
||
// replicates calls to stan::anlayze::compute_effective_sample_size | ||
// for any interface *without* access to chains class | ||
Eigen::Matrix<Eigen::VectorXd, Eigen::Dynamic, 1> | ||
samples(chains.num_chains()); | ||
std::vector<const double*> draws(chains.num_chains()); | ||
std::vector<size_t> sizes(chains.num_chains()); | ||
for (int index = 4; index < chains.num_params(); index++) { | ||
for (int chain = 0; chain < chains.num_chains(); ++chain) { | ||
samples(chain) = chains.samples(chain, index); | ||
draws[chain] = &samples(chain)(0); | ||
sizes[chain] = samples(chain).size(); | ||
} | ||
ASSERT_NEAR(n_eff(index - 4), | ||
stan::analyze::compute_effective_sample_size(draws, sizes), 1.0) | ||
<< "n_effective for index: " << index << ", parameter: " | ||
<< chains.param_name(index); | ||
} | ||
} | ||
|
||
TEST_F(ComputeEss,compute_effective_sample_size_convenience) { | ||
std::stringstream out; | ||
stan::io::stan_csv blocker1 = stan::io::stan_csv_reader::parse(blocker1_stream, &out); | ||
stan::io::stan_csv blocker2 = stan::io::stan_csv_reader::parse(blocker2_stream, &out); | ||
EXPECT_EQ("", out.str()); | ||
|
||
stan::mcmc::chains<> chains(blocker1); | ||
chains.add(blocker2); | ||
|
||
Eigen::VectorXd n_eff(48); | ||
n_eff << 466.099,136.953,1170.390,541.256, | ||
518.051,589.244,764.813,688.294, | ||
323.777,502.892,353.823,588.142, | ||
654.336,480.914,176.978,182.649, | ||
642.389,470.949,561.947,581.187, | ||
446.389,397.641,338.511,678.772, | ||
1442.250,837.956,869.865,951.124, | ||
619.336,875.805,233.260,786.568, | ||
910.144,231.582,907.666,747.347, | ||
720.660,195.195,944.547,767.271, | ||
723.665,1077.030,470.903,954.924, | ||
497.338,583.539,697.204,98.421; | ||
|
||
Eigen::Matrix<Eigen::VectorXd, Eigen::Dynamic, 1> | ||
samples(chains.num_chains()); | ||
std::vector<const double*> draws(chains.num_chains()); | ||
std::vector<size_t> sizes(chains.num_chains()); | ||
for (int index = 4; index < chains.num_params(); index++) { | ||
for (int chain = 0; chain < chains.num_chains(); ++chain) { | ||
samples(chain) = chains.samples(chain, index); | ||
draws[chain] = &samples(chain)(0); | ||
|
||
} | ||
size_t size = samples(0).size(); | ||
ASSERT_NEAR(n_eff(index - 4), | ||
stan::analyze::compute_effective_sample_size(draws, size), 1.0) | ||
<< "n_effective for index: " << index << ", parameter: " | ||
<< chains.param_name(index); | ||
} | ||
} | ||
|
||
TEST_F(ComputeEss,chains_compute_effective_sample_size) { | ||
std::stringstream out; | ||
stan::io::stan_csv blocker1 = stan::io::stan_csv_reader::parse(blocker1_stream, &out); | ||
stan::io::stan_csv blocker2 = stan::io::stan_csv_reader::parse(blocker2_stream, &out); | ||
EXPECT_EQ("", out.str()); | ||
|
||
stan::mcmc::chains<> chains(blocker1); | ||
chains.add(blocker2); | ||
|
||
Eigen::VectorXd n_eff(48); | ||
n_eff << 466.099,136.953,1170.390,541.256, | ||
518.051,589.244,764.813,688.294, | ||
323.777,502.892,353.823,588.142, | ||
654.336,480.914,176.978,182.649, | ||
642.389,470.949,561.947,581.187, | ||
446.389,397.641,338.511,678.772, | ||
1442.250,837.956,869.865,951.124, | ||
619.336,875.805,233.260,786.568, | ||
910.144,231.582,907.666,747.347, | ||
720.660,195.195,944.547,767.271, | ||
723.665,1077.030,470.903,954.924, | ||
497.338,583.539,697.204,98.421; | ||
|
||
// replicates calls to stan::anlayze::compute_effective_sample_size | ||
// for any interface with access to chains class | ||
for (int index = 4; index < chains.num_params(); index++) { | ||
ASSERT_NEAR(n_eff(index - 4), | ||
chains.effective_sample_size(index), 1.0) | ||
<< "n_effective for index: " << index << ", parameter: " | ||
<< chains.param_name(index); | ||
} | ||
|
||
for (int index = 0; index < chains.num_params(); index++) { | ||
std::string name = chains.param_name(index); | ||
ASSERT_EQ(chains.effective_sample_size(index), | ||
chains.effective_sample_size(name)); | ||
} | ||
} |
Oops, something went wrong.