Skip to content

Commit

Permalink
Merge pull request #807 from null-a/on-step-callback
Browse files Browse the repository at this point in the history
Add onStep callback to Optimize.
  • Loading branch information
stuhlmueller committed Apr 3, 2017
2 parents 9ac99ad + 2c3193e commit 2b8568e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
9 changes: 9 additions & 0 deletions docs/optimization/optimize.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ Optimize

Default: ``ELBO``

.. describe:: onStep

Specifies a function that will be called after each step. The
function will be passed the index of the current step and the
value of the objective as arguments. For example::

var callback = function(index, value) { /* ... */ };
Optimize({model: model, steps: 100, onStep: callback});

.. describe:: verbose

Default: ``true``
Expand Down
12 changes: 11 additions & 1 deletion src/inference/optimize.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var nodeUtil = require('util');

module.exports = function(env) {

var applyd = require('../headerUtils')(env).applyd;
var estimators = {
ELBO: require('./elbo')(env),
EUBO: require('./eubo')(env),
Expand Down Expand Up @@ -48,6 +49,7 @@ module.exports = function(env) {
showGradNorm: false,
checkGradients: true,
verbose: true,
onStep: function(s, k, a) { return k(s); },
onFinish: function(s, k, a) { return k(s); },

logProgress: false,
Expand Down Expand Up @@ -109,6 +111,12 @@ module.exports = function(env) {
checkpointParams = _.throttle(saveParams, options.checkpointParamsThrottle, { trailing: false });
}

var onStep = function(i, objective, cont) {
return applyd(s, function(s, val) {
return cont();
}, a, options.onStep, [i, objective], 'callback');
};

// Main loop.
return util.cpsLoop(
options.steps,
Expand Down Expand Up @@ -150,7 +158,9 @@ module.exports = function(env) {
optimizer(gradObj, paramsObj, i);

// Send updated params to store
return params.set(paramsObj, next);
return params.set(paramsObj, function() {
return onStep(i, objective, next);
});

}, { incremental: true });

Expand Down

0 comments on commit 2b8568e

Please sign in to comment.