Skip to content

Commit

Permalink
Merge pull request #1699 from stan-dev/feature/advi-tuning
Browse files Browse the repository at this point in the history
ADVI stepsize sequence parameter (eta) adaptation
  • Loading branch information
syclik committed Nov 23, 2015
2 parents 5515d17 + c30bb4c commit 99289c8
Show file tree
Hide file tree
Showing 29 changed files with 1,562 additions and 166 deletions.
3 changes: 3 additions & 0 deletions make/tests
Expand Up @@ -158,6 +158,9 @@ src/test/unit/variational/gradient_warn_test.cpp: src/test/test-models/good/vari
src/test/unit/variational/hier_logistic_test.cpp: src/test/test-models/good/variational/hier_logistic.hpp
src/test/unit/variational/hier_logistic_cp_test.cpp: src/test/test-models/good/variational/hier_logistic_cp.hpp
src/test/unit/variational/advi_messages_test.cpp:src/test/test-models/good/variational/univariate_no_constraint.hpp
src/test/unit/variational/eta_adapt_fail_test.cpp:src/test/test-models/good/variational/eta_should_fail.hpp
src/test/unit/variational/eta_adapt_big_test.cpp:src/test/test-models/good/variational/eta_should_be_big.hpp
src/test/unit/variational/eta_adapt_small_test.cpp:src/test/test-models/good/variational/eta_should_be_small.hpp

##
# Compile models depends on every model within
Expand Down
3 changes: 3 additions & 0 deletions src/stan/io/stan_csv_reader.hpp
Expand Up @@ -194,6 +194,9 @@ namespace stan {
}
ss.seekg(std::ios_base::beg);

if (lines < 4)
return false;

char comment; // Buffer for comment indicator, #

// Skip first two lines
Expand Down
4 changes: 3 additions & 1 deletion src/stan/services/arguments/arg_variational.hpp
Expand Up @@ -7,12 +7,12 @@
#include <stan/services/arguments/arg_variational_iter.hpp>
#include <stan/services/arguments/arg_variational_num_samples.hpp>
#include <stan/services/arguments/arg_variational_eta.hpp>
#include <stan/services/arguments/arg_variational_adapt.hpp>
#include <stan/services/arguments/arg_tolerance.hpp>
#include <stan/services/arguments/arg_variational_eval_elbo.hpp>
#include <stan/services/arguments/arg_variational_output_samples.hpp>

namespace stan {

namespace services {

class arg_variational: public categorical_argument {
Expand All @@ -31,6 +31,7 @@ namespace stan {
"of ELBO (objective function)",
100));
_subarguments.push_back(new arg_variational_eta());
_subarguments.push_back(new arg_variational_adapt());
_subarguments.push_back(new arg_tolerance("tol_rel_obj",
"Convergence tolerance on the relative norm of the objective", 1e-2));
_subarguments.push_back(new arg_variational_eval_elbo("eval_elbo",
Expand All @@ -41,6 +42,7 @@ namespace stan {
1000));
}
};

} // services
} // stan

Expand Down
25 changes: 25 additions & 0 deletions src/stan/services/arguments/arg_variational_adapt.hpp
@@ -0,0 +1,25 @@
#ifndef STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_HPP
#define STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_HPP

#include <stan/services/arguments/categorical_argument.hpp>
#include <stan/services/arguments/arg_variational_adapt_engaged.hpp>
#include <stan/services/arguments/arg_variational_adapt_iter.hpp>

namespace stan {
namespace services {

class arg_variational_adapt: public categorical_argument {
public:
arg_variational_adapt() {
_name = "adapt";
_description = "Eta Adaptation for Variational Inference";

_subarguments.push_back(new arg_variational_adapt_engaged());
_subarguments.push_back(new arg_variational_adapt_iter());
}
};

} // services
} // stan
#endif

26 changes: 26 additions & 0 deletions src/stan/services/arguments/arg_variational_adapt_engaged.hpp
@@ -0,0 +1,26 @@
#ifndef STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ENGAGED_HPP
#define STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ENGAGED_HPP

#include <stan/services/arguments/singleton_argument.hpp>

namespace stan {
namespace services {

class arg_variational_adapt_engaged: public bool_argument {
public:
arg_variational_adapt_engaged(): bool_argument() {
_name = "engaged";
_description = "Adaptation engaged?";
_validity = "[0, 1]";
_default = "1";
_default_value = true;
_constrained = false;
_good_value = 1;
_value = _default_value;
}
};

} // services
} // stan
#endif

30 changes: 30 additions & 0 deletions src/stan/services/arguments/arg_variational_adapt_iter.hpp
@@ -0,0 +1,30 @@
#ifndef STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ITER_HPP
#define STAN_SERVICES_ARGUMENTS_ARG_VARIATIONAL_ADAPT_ITER_HPP

#include <stan/services/arguments/singleton_argument.hpp>

namespace stan {
namespace services {

class arg_variational_adapt_iter: public int_argument {
public:
arg_variational_adapt_iter(): int_argument() {
_name = "iter";
_description = "Number of iterations for eta adaptation";
_validity = "0 < iter";
_default = "50";
_default_value = 50;
_constrained = true;
_good_value = 2.0;
_bad_value = -1.0;
_value = _default_value;
}

bool is_valid(int value) {
return value > 0;
}
};

} // services
} // stan
#endif
8 changes: 4 additions & 4 deletions src/stan/services/arguments/arg_variational_eta.hpp
Expand Up @@ -12,15 +12,15 @@ namespace stan {
arg_variational_eta(): real_argument() {
_name = "eta";
_description = "Stepsize scaling parameter for variational inference";
_validity = "0 < eta <= 1.0";
_default = "0.1";
_default_value = 0.1;
_validity = "0 < eta";
_default = "1.0";
_default_value = 1.0;
_constrained = true;
_good_value = 1.0;
_bad_value = -1.0;
_value = _default_value;
}
bool is_valid(double value) { return value > 0 && value <= 1.0; }
bool is_valid(double value) { return value > 0; }
};
} // services
} // stan
Expand Down
71 changes: 71 additions & 0 deletions src/stan/services/variational/print_progress.hpp
@@ -0,0 +1,71 @@
#ifndef STAN_SERVICES_VARIATIONAL_PRINT_PROGRESS_HPP
#define STAN_SERVICES_VARIATIONAL_PRINT_PROGRESS_HPP

#include <stan/math/prim/scal/err/check_positive.hpp>
#include <stan/math/prim/scal/err/check_nonnegative.hpp>
#include <stan/services/io/do_print.hpp>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <string>

namespace stan {
namespace services {
namespace variational {

/**
* Helper function for printing progress for variational inference
*
* @param m total number of iterations
* @param start starting iteration
* @param finish final iteration
* @param refresh how frequently we want to print an update
* @param tune boolean indicates tuning vs. variational inference
* @param prefix prefix string
* @param suffix suffix string
* @param o output stream
*/
void print_progress(int m,
int start,
int finish,
int refresh,
bool tune,
const std::string& prefix,
const std::string& suffix,
std::ostream& o) {
static const char* function =
"stan::services::variational::print_progress";

stan::math::check_positive(function,
"Total number of iterations",
m);
stan::math::check_nonnegative(function,
"Starting iteration",
start);
stan::math::check_positive(function,
"Final iteration",
finish);
stan::math::check_positive(function,
"Refresh rate",
refresh);

int it_print_width = std::ceil(std::log10(static_cast<double>(finish)));
if (io::do_print(m - 1, (start + m == finish), refresh)) {
o << prefix;
o << "Iteration: ";
o << std::setw(it_print_width) << m + start
<< " / " << finish;
o << " [" << std::setw(3)
<< (100 * (start + m)) / finish
<< "%] ";
o << (tune ? " (Adaptation)" : " (Variational Inference)");
o << suffix;
o << std::endl;
}
}

}
}
}

#endif

0 comments on commit 99289c8

Please sign in to comment.