Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Waiting on Stan PR] Adds psis_resample and calculate_lp as options for pathfinder #1234

Merged
merged 5 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,7 @@ output*.csv
*.o
# gdb
.gdb_history

#vscode
.vscode/*
.vscode/**
16 changes: 16 additions & 0 deletions src/cmdstan/arguments/arg_pathfinder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ class arg_pathfinder : public arg_lbfgs {
_subarguments.push_back(new arg_single_bool(
"save_single_paths", "Output single-path pathfinder draws as CSV",
false));
_subarguments.push_back(new arg_single_bool(
"psis_resample",
"If true, perform psis resampling on samples returned"
" from individual pathfinders. If false, returns num_paths * num_draws"
" samples",
true));
_subarguments.push_back(new arg_single_bool(
"calculate_lp",
"If true, individual pathfinders lp calculations are calculated and"
" returned with the output. If false, each pathfinder will only "
" calculate the lp values needed for the elbo calculation."
" If false, psis resampling cannot be performed and"
" the algorithm returns num_paths * num_draws samples."
" The output will still contain any lp values used when"
" calculating ELBO scores within LBFGS iterations.",
true));
_subarguments.push_back(new arg_single_int_pos(
"max_lbfgs_iters", "Maximum number of LBFGS iterations", 1000));
_subarguments.push_back(new arg_single_int_pos(
Expand Down
9 changes: 6 additions & 3 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,10 @@ int command(int argc, const char *argv[]) {
int num_draws = get_arg_val<int_argument>(*pathfinder_arg, "num_draws");
int num_psis_draws
= get_arg_val<int_argument>(*pathfinder_arg, "num_psis_draws");

bool psis_resample
= get_arg_val<bool_argument>(*pathfinder_arg, "psis_resample");
bool calculate_lp
= get_arg_val<bool_argument>(*pathfinder_arg, "calculate_lp");
if (num_psis_draws > num_draws * num_chains) {
logger.warn(
"Warning: Number of PSIS draws is larger than the total number of "
Expand All @@ -332,7 +335,7 @@ int command(int argc, const char *argv[]) {
history_size, init_alpha, tol_obj, tol_rel_obj, tol_grad,
tol_rel_grad, tol_param, max_lbfgs_iters, num_elbo_draws, num_draws,
save_single_paths, refresh, interrupt, logger, init_writer,
sample_writers[0], diagnostic_json_writers[0]);
sample_writers[0], diagnostic_json_writers[0], calculate_lp);
} else {
auto output_filenames = make_filenames(output_file, "", ".csv", 1, id);
auto ofs = std::make_unique<std::ofstream>(output_filenames[0]);
Expand All @@ -348,7 +351,7 @@ int command(int argc, const char *argv[]) {
max_lbfgs_iters, num_elbo_draws, num_draws, num_psis_draws,
num_chains, save_single_paths, refresh, interrupt, logger,
init_writers, sample_writers, diagnostic_json_writers,
pathfinder_writer, dummy_json_writer);
pathfinder_writer, dummy_json_writer, calculate_lp, psis_resample);
}
// ---- pathfinder end ---- //
} else if (user_method->arg("generate_quantities")) {
Expand Down
7 changes: 6 additions & 1 deletion src/cmdstan/command_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ inline constexpr auto get_arg(List &&arg_list, const char *arg1,
*/
template <typename caster, typename Arg>
inline constexpr auto get_arg_val(Arg &&argument, const char *arg_name) {
return dynamic_cast<std::decay_t<caster> *>(argument.arg(arg_name))->value();
auto *arg = argument.arg(arg_name);
if (arg) {
return dynamic_cast<std::decay_t<caster> *>(arg)->value();
} else {
throw std::invalid_argument(std::string("Unable to find: ") + arg_name);
}
}

/**
Expand Down