Skip to content

Commit

Permalink
Merge pull request #1108 from andrjohns/log_prob-constrained
Browse files Browse the repository at this point in the history
Update log_prob method to take constrained_params in init format
  • Loading branch information
rok-cesnovar committed Nov 7, 2022
2 parents 0d28770 + af8343b commit d0a98ac
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 78 deletions.
5 changes: 5 additions & 0 deletions src/cmdstan/arguments/arg_log_prob.hpp
Expand Up @@ -8,6 +8,11 @@

namespace cmdstan {

/**
* Argument used for calling the log_prob method to calculate the log density
* and its gradients with respect to a user-provided set of parameter values
* on the constrained and/or unconstrained scale
*/
class arg_log_prob : public categorical_argument {
public:
arg_log_prob() {
Expand Down
7 changes: 4 additions & 3 deletions src/cmdstan/arguments/arg_log_prob_constrained_params.hpp
Expand Up @@ -6,9 +6,10 @@
namespace cmdstan {
/**
* Argument for providing a file of parameters on the constrained scale
* for use with the 'log_prob' method. The file can be in CSV or JSON format
* and should contain a variable 'params_r' with either a vector or list/array
* of vectors of constrained parameter values.
* for use with the `log_prob` argument. The file can be in JSON or R Dump
* format, using the same structure as the 'init' argument. Like the 'init'
* argument, if the file has a '.json' extension it is treated as a JSON file,
* otherwise it is treated as an RDump file.
*/
class arg_log_prob_constrained_params : public string_argument {
public:
Expand Down
5 changes: 4 additions & 1 deletion src/cmdstan/arguments/arg_log_prob_output_file.hpp
Expand Up @@ -4,7 +4,10 @@
#include <cmdstan/arguments/singleton_argument.hpp>

namespace cmdstan {

/**
* Sub-argument to provide a custom path and filename for the output from
* the log_prob method
*/
class arg_log_prob_output_file : public string_argument {
public:
arg_log_prob_output_file() : string_argument() {
Expand Down
4 changes: 3 additions & 1 deletion src/cmdstan/arguments/arg_log_prob_unconstrained_params.hpp
Expand Up @@ -8,7 +8,9 @@ namespace cmdstan {
* Argument for providing a file of parameters on the unconstrained scale
* for use with the 'log_prob' method. The file can be in CSV or JSON format
* and should contain a variable 'params_r' with either a vector or list/array
* of vectors of unconstrained parameter values.
* of vectors of unconstrained parameter values. Like the 'init' argument, if
* the file has a '.json' extension it is treated as a JSON file, otherwise it
* is treated as an RDump file.
*/
class arg_log_prob_unconstrained_params : public string_argument {
public:
Expand Down
98 changes: 43 additions & 55 deletions src/cmdstan/command.hpp
Expand Up @@ -485,7 +485,7 @@ int command(int argc, const char *argv[]) {
return std::string("_" + std::to_string(i + id));
}
};
for (int i = 0; i < num_chains; i++) {
for (int i = 0; i < num_chains; ++i) {
auto output_filename = output_name + name_iterator(i) + output_ending;
auto unique_fstream
= std::make_unique<std::fstream>(output_filename, std::fstream::out);
Expand All @@ -504,7 +504,7 @@ int command(int argc, const char *argv[]) {
diagnostic_writers.emplace_back(nullptr, "# ");
}
}
for (int i = 0; i < num_chains; i++) {
for (int i = 0; i < num_chains; ++i) {
write_stan(sample_writers[i]);
write_model(sample_writers[i], model.model_name());
write_datetime(sample_writers[i]);
Expand Down Expand Up @@ -620,8 +620,13 @@ int command(int argc, const char *argv[]) {
std::shared_ptr<stan::io::var_context> upars_context
= get_var_context(u_fname);

u_params_r = (*upars_context).vals_r("params_r");
dims_u_params_r = (*upars_context).dims_r("params_r");
u_params_r = upars_context->vals_r("params_r");
if (u_params_r.size() == 0) {
msg << "Unconstrained parameters file has no variable 'params_r' with "
"unconstrained parameter values!";
throw std::invalid_argument(msg.str());
}
dims_u_params_r = upars_context->dims_r("params_r");

// Detect whether multiple sets of parameter values have been passed
// and set the sizes accordingly
Expand All @@ -635,59 +640,42 @@ int command(int argc, const char *argv[]) {
std::vector<std::string> param_names;
std::vector<std::vector<size_t>> param_dimss;
stan::services::get_model_parameters(model, param_names, param_dimss);

if (u_params_size > 0 && u_params_size != param_names.size()) {
size_t num_upars = model.num_params_r();
if (u_params_size > 0 && u_params_size != num_upars) {
msg << "Incorrect number of unconstrained parameters provided! "
"Model has "
<< param_names.size() << " parameters but " << u_params_size
<< " were found.";
throw std::invalid_argument(msg.str());
}

size_t c_params_vec_size = 0;
size_t c_params_size = 0;
std::vector<double> c_params_r;
std::vector<size_t> dims_c_params_r;
if (!(cpars_file->is_default())) {
std::string c_fname(cpars_file->value());
std::ifstream c_stream(c_fname.c_str());

std::shared_ptr<stan::io::var_context> cpars_context
= get_var_context(c_fname);

c_params_r = (*cpars_context).vals_r("params_r");
dims_c_params_r = (*cpars_context).dims_r("params_r");

c_params_vec_size = dims_c_params_r.size() == 2 ? dims_c_params_r[0] : 1;
c_params_size = dims_c_params_r.size() == 2 ? dims_c_params_r[1]
: dims_c_params_r[0];
}

if (c_params_size > 0 && c_params_size != param_names.size()) {
msg << "Incorrect number of constrained parameters provided! "
"Model has "
<< param_names.size() << " parameters but " << c_params_size
<< " were found.";
<< num_upars << " parameters but " << u_params_size << " were found.";
throw std::invalid_argument(msg.str());
}

bool has_cpars = !cpars_file->is_default();
// Store in single nested array to allow single loop for calc and print
size_t num_par_sets = c_params_vec_size + u_params_vec_size;
size_t num_par_sets = u_params_vec_size + has_cpars;
std::vector<std::vector<double>> params_r_ind(num_par_sets);

// Use Map with inner stride to operate on all values from parameter set
using StrideT = Eigen::Stride<1, Eigen::Dynamic>;
std::vector<int> dummy_params_i;
for (size_t i = 0; i < c_params_vec_size; i++) {
Eigen::Map<Eigen::VectorXd, 0, StrideT> map_r(
c_params_r.data() + i, c_params_size, StrideT(1, c_params_vec_size));
if (!cpars_file->is_default()) {
std::string cpars_filename(cpars_file->value());

stan::io::array_var_context context(param_names, map_r, param_dimss);
model.transform_inits(context, dummy_params_i, params_r_ind[i], &msg);
std::shared_ptr<stan::io::var_context> cpars_context
= get_var_context(cpars_filename);
std::vector<std::string> input_cpar_names;
cpars_context->names_r(input_cpar_names);

for (const std::string &m_param_name : param_names) {
if (!cpars_context->contains_r(m_param_name)) {
msg << "Constrained value(s) for parameter " << m_param_name
<< " not found!";
throw std::invalid_argument(msg.str());
}
}
model.transform_inits((*cpars_context), dummy_params_i, params_r_ind[0],
&msg);
}

for (size_t i = c_params_vec_size; i < num_par_sets; i++) {
size_t iter = i - c_params_vec_size;
// Use Map with inner stride to operate on all values from parameter set
using StrideT = Eigen::Stride<1, Eigen::Dynamic>;
for (size_t i = has_cpars; i < num_par_sets; ++i) {
size_t iter = i - has_cpars;
Eigen::Map<Eigen::VectorXd, 0, StrideT> map_r(
u_params_r.data() + iter, u_params_size,
StrideT(1, u_params_vec_size));
Expand All @@ -702,28 +690,28 @@ int command(int argc, const char *argv[]) {

std::vector<std::string> p_names;
model.constrained_param_names(p_names, false, false);
for (size_t i = 1; i < p_names.size(); i++) {
for (size_t i = 0; i < (p_names.size() - 1); ++i) {
output_stream << "g_" << p_names[i] << ",";
}
output_stream << "g_" << p_names.back() << "\n";
// Output last element separately so that no trailing comma is printed
output_stream << p_names.back() << "\n";
try {
double lp;
std::vector<double> gradients;
for (size_t i = 0; i < num_par_sets; i++) {
for (auto &&param_set : params_r_ind) {
if (jacobian_adjust) {
lp = stan::model::log_prob_grad<false, true>(
model, params_r_ind[i], dummy_params_i, gradients);
lp = stan::model::log_prob_grad<true, true>(
model, param_set, dummy_params_i, gradients);
} else {
lp = stan::model::log_prob_grad<false, false>(
model, params_r_ind[i], dummy_params_i, gradients);
lp = stan::model::log_prob_grad<true, false>(
model, param_set, dummy_params_i, gradients);
}

output_stream << lp << ",";

std::copy(gradients.begin(), gradients.end() - 1,
std::ostream_iterator<double>(output_stream, ","));
output_stream << gradients.back();
output_stream << "\n";
output_stream << gradients.back() << "\n";
}
output_stream.close();
return stan::services::error_codes::error_codes::OK;
Expand Down
18 changes: 9 additions & 9 deletions src/test/interface/log_prob_test.cpp
Expand Up @@ -12,7 +12,7 @@ using cmdstan::test::run_command_output;
class CmdStan : public testing::Test {
public:
void SetUp() {
bern_extra_model = {"src", "test", "test-models", "bern_extra_model"};
bern_log_prob_model = {"src", "test", "test-models", "bern_log_prob_model"};
bern_data = {"src", "test", "test-models", "bern.data.json"};
bern_unconstrained_params_rdump
= {"src", "test", "test-models", "bern_unconstrained_params.R"};
Expand All @@ -28,7 +28,7 @@ class CmdStan : public testing::Test {
= {"src", "test", "test-models", "bern_constrained_params_short.json"};
dev_null_path = {"/dev", "null"};
}
std::vector<std::string> bern_extra_model;
std::vector<std::string> bern_log_prob_model;
std::vector<std::string> bern_data;
std::vector<std::string> bern_unconstrained_params_rdump;
std::vector<std::string> bern_constrained_params_rdump;
Expand All @@ -41,7 +41,7 @@ class CmdStan : public testing::Test {

TEST_F(CmdStan, log_prob_good_rdump) {
std::stringstream ss;
ss << convert_model_path(bern_extra_model)
ss << convert_model_path(bern_log_prob_model)
<< " data file=" << convert_model_path(bern_data)
<< " output file=" << convert_model_path(dev_null_path)
<< " method=log_prob unconstrained_params="
Expand All @@ -55,7 +55,7 @@ TEST_F(CmdStan, log_prob_good_rdump) {

TEST_F(CmdStan, log_prob_good_json) {
std::stringstream ss;
ss << convert_model_path(bern_extra_model)
ss << convert_model_path(bern_log_prob_model)
<< " data file=" << convert_model_path(bern_data)
<< " output file=" << convert_model_path(dev_null_path)
<< " method=log_prob unconstrained_params="
Expand All @@ -69,7 +69,7 @@ TEST_F(CmdStan, log_prob_good_json) {

TEST_F(CmdStan, log_prob_good_rdump_json) {
std::stringstream ss;
ss << convert_model_path(bern_extra_model)
ss << convert_model_path(bern_log_prob_model)
<< " data file=" << convert_model_path(bern_data)
<< " output file=" << convert_model_path(dev_null_path)
<< " method=log_prob unconstrained_params="
Expand All @@ -83,7 +83,7 @@ TEST_F(CmdStan, log_prob_good_rdump_json) {

TEST_F(CmdStan, log_prob_no_params) {
std::stringstream ss;
ss << convert_model_path(bern_extra_model)
ss << convert_model_path(bern_log_prob_model)
<< " data file=" << convert_model_path(bern_data)
<< " output file=" << convert_model_path(dev_null_path)
<< " method=log_prob";
Expand All @@ -94,7 +94,7 @@ TEST_F(CmdStan, log_prob_no_params) {

TEST_F(CmdStan, log_prob_no_data) {
std::stringstream ss;
ss << convert_model_path(bern_extra_model)
ss << convert_model_path(bern_log_prob_model)
<< " output file=" << convert_model_path(dev_null_path)
<< " method=log_prob unconstrained_params="
<< convert_model_path(bern_unconstrained_params_rdump)
Expand All @@ -107,7 +107,7 @@ TEST_F(CmdStan, log_prob_no_data) {

TEST_F(CmdStan, log_prob_constrained_length) {
std::stringstream ss;
ss << convert_model_path(bern_extra_model)
ss << convert_model_path(bern_log_prob_model)
<< " data file=" << convert_model_path(bern_data)
<< " output file=" << convert_model_path(dev_null_path)
<< " method=log_prob constrained_params="
Expand All @@ -119,7 +119,7 @@ TEST_F(CmdStan, log_prob_constrained_length) {

TEST_F(CmdStan, log_prob_unconstrained_length) {
std::stringstream ss;
ss << convert_model_path(bern_extra_model)
ss << convert_model_path(bern_log_prob_model)
<< " data file=" << convert_model_path(bern_data)
<< " output file=" << convert_model_path(dev_null_path)
<< " method=log_prob unconstrained_params="
Expand Down
17 changes: 14 additions & 3 deletions src/test/test-models/bern_constrained_params.R
@@ -1,3 +1,14 @@
params_r <-
structure(c(0.5, 0.1, -1.5, -3.5, 2.6, 1.6),
.Dim = c(2, 3))
theta <- 0.8
mu_vector <-
c(-1.03192029533272, -0.479415242304312)
mu_row_vector <-
c(-1.01866970062961, 0.67699454858025)
mu_array_1d <-
c(2.15402112425165, 1.11815245534682)
mu_matrix <-
structure(c(0.176100961348628, -1.47536938928953, 2.17309623676142, -0.263993484783208),
.Dim = c(2, 2))
mu_matrix_array <-
structure(c(0.704803511993464, 0.596396994136899, -1.04596396482332, -0.454458757498793, -0.711702526766766, -0.80021133760908, -1.71856146447761,
-0.384149678939469),
.Dim = c(2, 2, 2))
21 changes: 20 additions & 1 deletion src/test/test-models/bern_constrained_params.json
@@ -1 +1,20 @@
{"params_r":[[0.5, -1.5, 2.6], [0.1, -3.5, 1.6]]}
{
"theta": 0.8,
"mu_vector": [-1.03192029533272, -0.479415242304312],
"mu_row_vector": [-1.01866970062961, 0.67699454858025],
"mu_array_1d": [2.15402112425165, 1.11815245534682],
"mu_matrix": [
[0.176100961348628, 2.17309623676142],
[-1.47536938928953, -0.263993484783208]
],
"mu_matrix_array": [
[
[0.704803511993464, -0.711702526766766],
[-1.04596396482332, -1.71856146447761]
],
[
[0.596396994136899, -0.80021133760908],
[-0.454458757498793, -0.384149678939469]
]
]
}
5 changes: 4 additions & 1 deletion src/test/test-models/bern_constrained_params_short.json
@@ -1 +1,4 @@
{"params_r":[[0.5, -1.5, 2.6], [0.1, -3.5]]}
{
"theta": 0.1,
"mu1" : 2.0
}
28 changes: 28 additions & 0 deletions src/test/test-models/bern_log_prob_model.stan
@@ -0,0 +1,28 @@
data {
int<lower=0> N;
int<lower=0,upper=1> y[N];
}
parameters {
real<lower=0,upper=1> theta;
vector[2] mu_vector;
row_vector[2] mu_row_vector;
array[2] real mu_array_1d;
matrix[2, 2] mu_matrix;
array[2] matrix[2, 2] mu_matrix_array;
}
model {
theta ~ beta(1,1);
y ~ bernoulli(theta);
mu_vector ~ std_normal();
mu_row_vector ~ std_normal();
to_vector(mu_array_1d) ~ std_normal();
to_vector(mu_matrix) ~ std_normal();
to_vector(mu_matrix_array[1]) ~ std_normal();
to_vector(mu_matrix_array[2]) ~ std_normal();
}
generated quantities {
real theta_copy = theta;
int y_rep[N];
for (n in 1:N)
y_rep[n] = bernoulli_rng(theta);
}
7 changes: 4 additions & 3 deletions src/test/test-models/bern_unconstrained_params.R
@@ -1,3 +1,4 @@
params_r <-
structure(c(0.5, 3.1, -1.5, -3.5, 2.6, 1.6),
.Dim = c(2, 3))
params_r <-
c(1.38629436111989, -1.03192029533272, -0.479415242304312, -1.01866970062961, 0.67699454858025, 2.15402112425165, 1.11815245534682,
0.176100961348628, -1.47536938928953, 2.17309623676142, -0.263993484783208, 0.704803511993464, -1.04596396482332, -0.711702526766766, -1.71856146447761,
0.596396994136899, -0.454458757498793, -0.80021133760908, -0.384149678939469)
4 changes: 3 additions & 1 deletion src/test/test-models/bern_unconstrained_params.json
@@ -1 +1,3 @@
{"params_r":[[0.5, -1.5, 2.6], [3.1, -3.5, 1.6]]}
{
"params_r": [1.38629436111989, -1.03192029533272, -0.479415242304312, -1.01866970062961, 0.67699454858025, 2.15402112425165, 1.11815245534682, 0.176100961348628, -1.47536938928953, 2.17309623676142, -0.263993484783208, 0.704803511993464, -1.04596396482332, -0.711702526766766, -1.71856146447761, 0.596396994136899, -0.454458757498793, -0.80021133760908, -0.384149678939469]
}

0 comments on commit d0a98ac

Please sign in to comment.