Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1699 from stan-dev/feature/advi-tuning
ADVI stepsize sequence parameter (eta) adaptation
- Loading branch information
Showing
29 changed files
with
1,562 additions
and
166 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_HPP | ||
#define STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_HPP | ||
|
||
#include <stan/services/arguments/categorical_argument.hpp> | ||
#include <stan/services/arguments/arg_variational_adapt_engaged.hpp> | ||
#include <stan/services/arguments/arg_variational_adapt_iter.hpp> | ||
|
||
namespace stan { | ||
namespace services { | ||
|
||
class arg_variational_adapt: public categorical_argument { | ||
public: | ||
arg_variational_adapt() { | ||
_name = "adapt"; | ||
_description = "Eta Adaptation for Variational Inference"; | ||
|
||
_subarguments.push_back(new arg_variational_adapt_engaged()); | ||
_subarguments.push_back(new arg_variational_adapt_iter()); | ||
} | ||
}; | ||
|
||
} // services | ||
} // stan | ||
#endif | ||
|
26 changes: 26 additions & 0 deletions
26
src/stan/services/arguments/arg_variational_adapt_engaged.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#ifndef STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ENGAGED_HPP | ||
#define STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ENGAGED_HPP | ||
|
||
#include <stan/services/arguments/singleton_argument.hpp> | ||
|
||
namespace stan { | ||
namespace services { | ||
|
||
class arg_variational_adapt_engaged: public bool_argument { | ||
public: | ||
arg_variational_adapt_engaged(): bool_argument() { | ||
_name = "engaged"; | ||
_description = "Adaptation engaged?"; | ||
_validity = "[0, 1]"; | ||
_default = "1"; | ||
_default_value = true; | ||
_constrained = false; | ||
_good_value = 1; | ||
_value = _default_value; | ||
} | ||
}; | ||
|
||
} // services | ||
} // stan | ||
#endif | ||
|
30 changes: 30 additions & 0 deletions
30
src/stan/services/arguments/arg_variational_adapt_iter.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#ifndef STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ITER_HPP | ||
#define STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ITER_HPP | ||
|
||
#include <stan/services/arguments/singleton_argument.hpp> | ||
|
||
namespace stan { | ||
namespace services { | ||
|
||
class arg_variational_adapt_iter: public int_argument { | ||
public: | ||
arg_variational_adapt_iter(): int_argument() { | ||
_name = "iter"; | ||
_description = "Number of iterations for eta adaptation"; | ||
_validity = "0 < iter"; | ||
_default = "50"; | ||
_default_value = 50; | ||
_constrained = true; | ||
_good_value = 2.0; | ||
_bad_value = -1.0; | ||
_value = _default_value; | ||
} | ||
|
||
bool is_valid(int value) { | ||
return value > 0; | ||
} | ||
}; | ||
|
||
} // services | ||
} // stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#ifndef STAN_SERVICES_VARIATIONAL_PRINT_PROGRESS_HPP | ||
#define STAN_SERVICES_VARIATIONAL_PRINT_PROGRESS_HPP | ||
|
||
#include <stan/math/prim/scal/err/check_positive.hpp> | ||
#include <stan/math/prim/scal/err/check_nonnegative.hpp> | ||
#include <stan/services/io/do_print.hpp> | ||
#include <cmath> | ||
#include <iomanip> | ||
#include <iostream> | ||
#include <string> | ||
|
||
namespace stan { | ||
namespace services { | ||
namespace variational { | ||
|
||
/** | ||
* Helper function for printing progress for variational inference | ||
* | ||
* @param m total number of iterations | ||
* @param start starting iteration | ||
* @param finish final iteration | ||
* @param refresh how frequently we want to print an update | ||
* @param tune boolean indicates tuning vs. variational inference | ||
* @param prefix prefix string | ||
* @param suffix suffix string | ||
* @param o output stream | ||
*/ | ||
void print_progress(int m, | ||
int start, | ||
int finish, | ||
int refresh, | ||
bool tune, | ||
const std::string& prefix, | ||
const std::string& suffix, | ||
std::ostream& o) { | ||
static const char* function = | ||
"stan::services::variational::print_progress"; | ||
|
||
stan::math::check_positive(function, | ||
"Total number of iterations", | ||
m); | ||
stan::math::check_nonnegative(function, | ||
"Starting iteration", | ||
start); | ||
stan::math::check_positive(function, | ||
"Final iteration", | ||
finish); | ||
stan::math::check_positive(function, | ||
"Refresh rate", | ||
refresh); | ||
|
||
int it_print_width = std::ceil(std::log10(static_cast<double>(finish))); | ||
if (io::do_print(m - 1, (start + m == finish), refresh)) { | ||
o << prefix; | ||
o << "Iteration: "; | ||
o << std::setw(it_print_width) << m + start | ||
<< " / " << finish; | ||
o << " [" << std::setw(3) | ||
<< (100 * (start + m)) / finish | ||
<< "%] "; | ||
o << (tune ? " (Adaptation)" : " (Variational Inference)"); | ||
o << suffix; | ||
o << std::endl; | ||
} | ||
} | ||
|
||
} | ||
} | ||
} | ||
|
||
#endif |
Oops, something went wrong.