Skip to content

Commit

Permalink
Merge pull request #3230 from stan-dev/feature/3181-json-hmc-tuning-p…
Browse files Browse the repository at this point in the history
…arams

Feature/3181 json hmc tuning params
  • Loading branch information
mitzimorris committed Sep 21, 2023
2 parents 98a6b55 + c9ed01d commit e3089d2
Show file tree
Hide file tree
Showing 21 changed files with 1,077 additions and 183 deletions.
13 changes: 12 additions & 1 deletion src/stan/mcmc/hmc/base_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/callbacks/logger.hpp>
#include <stan/callbacks/writer.hpp>
#include <stan/callbacks/structured_writer.hpp>
#include <stan/mcmc/base_mcmc.hpp>
#include <stan/mcmc/hmc/hamiltonians/ps_point.hpp>
#include <boost/random/uniform_01.hpp>
Expand Down Expand Up @@ -63,6 +64,16 @@ class base_hmc : public base_mcmc {
write_sampler_metric(writer);
}

/**
* write stepsize and elements of mass matrix as a JSON object
*/
void write_sampler_state_struct(callbacks::structured_writer& struct_writer) {
struct_writer.begin_record();
struct_writer.write("stepsize", get_nominal_stepsize());
struct_writer.write("inv_metric", z_.inv_e_metric_);
struct_writer.end_record();
}

void get_sampler_diagnostic_names(std::vector<std::string>& model_names,
std::vector<std::string>& names) {
z_.get_param_names(model_names, names);
Expand Down Expand Up @@ -183,7 +194,7 @@ class base_hmc : public base_mcmc {

protected:
typename Hamiltonian<Model, BaseRNG>::PointType z_;
Integrator<Hamiltonian<Model, BaseRNG> > integrator_;
Integrator<Hamiltonian<Model, BaseRNG>> integrator_;
Hamiltonian<Model, BaseRNG> hamiltonian_;

BaseRNG& rand_int_;
Expand Down
24 changes: 19 additions & 5 deletions src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef STAN_MCMC_HMC_HAMILTONIANS_UNIT_E_POINT_HPP
#define STAN_MCMC_HMC_HAMILTONIANS_UNIT_E_POINT_HPP

#include <stan/callbacks/writer.hpp>
#include <stan/mcmc/hmc/hamiltonians/ps_point.hpp>

namespace stan {
Expand All @@ -11,12 +12,25 @@ namespace mcmc {
*/
class unit_e_point : public ps_point {
public:
explicit unit_e_point(int n) : ps_point(n) {}
};
/**
* Vector of diagonal elements of inverse mass matrix.
*/
Eigen::VectorXd inv_e_metric_;

/**
* Construct a diag point in n-dimensional phase space
* with vector of ones for diagonal elements of inverse mass matrix.
*
* @param n number of dimensions
*/
explicit unit_e_point(int n) : ps_point(n), inv_e_metric_(n) {
inv_e_metric_.setOnes();
}

inline void write_metric(stan::callbacks::writer& writer) {
writer("No free parameters for unit metric");
}
inline void write_metric(stan::callbacks::writer& writer) {
writer("No free parameters for unit metric");
}
};

} // namespace mcmc
} // namespace stan
Expand Down
6 changes: 3 additions & 3 deletions src/stan/services/sample/fixed_param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
#include <stan/math/prim.hpp>
#include <stan/mcmc/fixed_param_sampler.hpp>
#include <stan/services/error_codes.hpp>
#include <stan/services/util/mcmc_writer.hpp>
#include <stan/services/util/generate_transitions.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/generate_transitions.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/mcmc_writer.hpp>
#include <chrono>
#include <vector>

Expand Down Expand Up @@ -63,7 +63,7 @@ int fixed_param(Model& model, const stan::io::var_context& init,
}

stan::mcmc::fixed_param_sampler sampler;
util::mcmc_writer writer(sample_writer, diagnostic_writer, logger);
services::util::mcmc_writer writer(sample_writer, diagnostic_writer, logger);
Eigen::VectorXd cont_params(cont_vector.size());
for (size_t i = 0; i < cont_vector.size(); i++)
cont_params[i] = cont_vector[i];
Expand Down

0 comments on commit e3089d2

Please sign in to comment.