Skip to content

Commit

Permalink
Merge pull request #715 from null-a/guide-thunks
Browse files Browse the repository at this point in the history
Guide Thunks
  • Loading branch information
stuhlmueller committed Dec 20, 2016
2 parents 0db8ab1 + 132d7bd commit 5e3e103
Show file tree
Hide file tree
Showing 17 changed files with 268 additions and 88 deletions.
11 changes: 8 additions & 3 deletions docs/guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@ A number of :ref:`inference <inference>` strategies make use of an
auxiliary distribution which we call a *guide distribution*. They are
specified like so::

sample(dist, {guide: guideDist});
sample(dist, {guide: guideFn});

Where ``guideDist`` is a distribution object.
Where ``guideFn`` is a function that takes zero arguments, and returns
a distribution object.

For example::

sample(Cauchy(params), {guide: Gaussian(guideParams)});
sample(Cauchy(params), {
guide: function() {
return Gaussian(guideParams);
}
});
35 changes: 28 additions & 7 deletions docs/inference/methods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,34 @@ SMC

Default: ``'MH'``

.. describe:: importance

Controls the importance distribution used during inference.

Specifying an importance distribution can be useful when you
know something about the posterior distribution, as
specifying an importance distribution that is closer to the
posterior than the prior will improve the statistical
efficiency of inference.

This option accepts the following values:

* ``'default'``: When a random choice has a :ref:`guide
distribution <guides>` specified, use that as the
importance distribution. For all other random choices, use
the prior.

* ``'ignoreGuide'``: Use the prior as the importance
distribution for all random choices.

* ``'autoGuide'``: When a random choice has a :ref:`guide
distribution <guides>` specified, use that as the
importance distribution. For all other random choices,
automatically generate a mean-field guide and use that as
the importance distribution.

Default: ``'default'``

.. describe:: onlyMAP

When ``true``, only the sample with the highest score is
Expand All @@ -282,13 +310,6 @@ SMC

Infer({method: 'SMC', particles: 100, rejuvSteps: 5, model: model});

By default SMC uses the prior as the importance distribution. Other
distributions can be used by specifying :ref:`guide distributions
<guides>`. This can be useful when you know something about the
posterior distribution as specifying an importance distribution
that is closer to the posterior than the prior will improve the
statistical efficiency of inference.

Optimization
------------

Expand Down
92 changes: 91 additions & 1 deletion src/guide.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,100 @@ var params = require('./params/params');

var T = ad.tensor;

function notAllowed(fn) {
return function() {
throw new Error(fn + ' cannot be used within the guide.');
};
}

var sampleNotAllowed = notAllowed('sample');
var factorNotAllowed = notAllowed('factor');

function guideCoroutine(env) {
return {
sample: sampleNotAllowed,
factor: factorNotAllowed,
incrementalize: env.defaultCoroutine.incrementalize,
// Copy the entry address from the current coroutine so that
// parameter names continue to be relative to it.
a: env.coroutine.a,
// Use params/paramsSeen of the current coroutine so that params
// are fetch/tracked correctly.
params: env.coroutine.params,
paramsSeen: env.coroutine.paramsSeen,
// A flag used when creating parameters to check whether we're in
// a guide thunk. Note that this does the right thing if Infer is
// used within a guide. This can be checked from a webppl program
// using the `inGuide()` helper.
_guide: true
};
}

function runInCoroutine(coroutine, env, k, f) {
var prevCoroutine = env.coroutine;
env.coroutine = coroutine;
return f(function(s, val) {
env.coroutine = prevCoroutine;
return k(s, val);
});
}

function runThunk(thunk, env, s, a, k) {
if (!_.isFunction(thunk)) {
throw new Error('The guide is expected to be a function.');
}
// Run the thunk with the guide coroutine installed. Check the
// return value is a distribution before continuing.
return runInCoroutine(
guideCoroutine(env),
env, k,
function(k) {
return thunk(s, function(s2, dist) {
// Clear the stack now the thunk has returned.
return function() {
if (!dists.isDist(dist)) {
throw new Error('The guide did not return a distribution.');
}
return k(s2, dist);
};
}, a + '_guide');
});
}

function runIfThunk(s, k, a, env, maybeThunk, alternate) {
return maybeThunk ?
runThunk(maybeThunk, env, s, a, k) :
alternate(s, k, a);
}

// Convenient variations on runGuideThunk.

function runIfThunkElseNull(maybeThunk, env, s, a, k) {
return runIfThunk(s, k, a, env, maybeThunk, function(s, k, a) {
return k(s, null);
});
}

function runIfThunkElseAuto(maybeThunk, targetDist, env, s, a, k) {
return runIfThunk(s, k, a, env, maybeThunk, function(s, k, a) {
return k(s, independent(targetDist, a, env));
});
}

// Returns an independent guide distribution for the given target
// distribution, sample address pair. Guiding all choices with
// independent guide distributions and optimizing the elbo yields
// mean-field variational inference.

var autoGuideWarningIssued = false;

function independent(targetDist, sampleAddress, env) {

if (!autoGuideWarningIssued) {
autoGuideWarningIssued = true;
util.warn('Automatically generating guide for one or more choices.');
}

// Include the distribution name in the guide parameter name to
// avoid collisions when the distribution type changes between
// calls. (As a result of the distribution passed depending on a
Expand Down Expand Up @@ -246,5 +334,7 @@ function squishToInterval(interval) {
}

module.exports = {
independent: independent
independent: independent,
runIfThunkElseAuto: runIfThunkElseAuto,
runIfThunkElseNull: runIfThunkElseNull
};
6 changes: 3 additions & 3 deletions src/header.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,9 @@ var Infer = function(options, maybeFn) {
// Convenience function for creating maximum likelihood model
// parameters.
var modelParam = function(options) {
return sample(ImproperUniform(), {
guide: Delta({v: param(options)})
});
return sample(ImproperUniform(), {guide() {
return Delta({v: param(options)});
}});
};

// Convenience functions for building tensors out of scalars
Expand Down
1 change: 0 additions & 1 deletion src/headerUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ module.exports = function(env) {
return k(s, new Tensor(dims).fill(1));
};


// It is the responsibility of individual coroutines to implement
// data sub-sampling and to make use of the conditional independence
// information mapData provides. To do so, coroutines can implement
Expand Down
40 changes: 15 additions & 25 deletions src/inference/elbo.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -242,38 +242,28 @@ module.exports = function(env) {

sample: function(s, k, a, dist, options) {
options = options || {};
return guide.runIfThunkElseAuto(options.guide, dist, env, s, a, function(s, guideDist) {

var guideDist;
if (options.guide) {
guideDist = options.guide;
} else {
guideDist = guide.independent(dist, a, env);
if (this.step === 0 &&
this.opts.verbose &&
!this.mfWarningIssued) {
this.mfWarningIssued = true;
console.log('ELBO: Defaulting to mean-field for one or more choices.');
}
}
var ret = this.sampleGuide(guideDist, options);
var val = ret.val;

var ret = this.sampleGuide(guideDist, options);
var val = ret.val;
var logp = dist.score(val);
var logq = guideDist.score(val);
checkScoreIsFinite(logp, 'target');
checkScoreIsFinite(logq, 'guide');

var logp = dist.score(val);
var logq = guideDist.score(val);
checkScoreIsFinite(logp, 'target');
checkScoreIsFinite(logq, 'guide');
var m = top(this.mapDataStack).multiplier;

var m = top(this.mapDataStack).multiplier;
var node = new SampleNode(
this.prevNode, logp, logq,
ret.reparam, a, dist, guideDist, val, m, this.opts.debugWeights);

var node = new SampleNode(
this.prevNode, logp, logq,
ret.reparam, a, dist, guideDist, val, m, this.opts.debugWeights);
this.prevNode = node;
this.nodes.push(node);

this.prevNode = node;
this.nodes.push(node);
return k(s, val);

return k(s, val);
}.bind(this));
},

sampleGuide: function(dist, options) {
Expand Down
12 changes: 8 additions & 4 deletions src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@ module.exports = function(env) {
},

sample: function(s, k, a, dist, options) {
var distribution = this.opts.guide ?
(options && options.guide) || guide.independent(dist, a, env) :
dist;
return k(s, distribution.sample());
if (this.opts.guide) {
options = options || {};
return guide.runIfThunkElseAuto(options.guide, dist, env, s, a, function(s, guideDist) {
return k(s, guideDist.sample());
});
} else {
return k(s, dist.sample());
}
},

factor: function(s, k, a, score) {
Expand Down
70 changes: 45 additions & 25 deletions src/inference/smc.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ var Trace = require('../trace');
var assert = require('assert');
var CountAggregator = require('../aggregation/CountAggregator');
var ad = require('../ad');
var guide = require('../guide');

module.exports = function(env) {

var kernels = require('./kernels')(env);

var validImportanceOptVals = ['default', 'ignoreGuide', 'autoGuide'];

function SMC(s, k, a, wpplFn, options) {
util.throwUnlessOpts(options, 'SMC');
var options = util.mergeDefaults(options, {
Expand All @@ -21,10 +24,16 @@ module.exports = function(env) {
rejuvKernel: 'MH',
finalRejuv: true,
saveTraces: false,
ignoreGuide: false,
importance: 'default',
onlyMAP: false
});

if (!_.includes(validImportanceOptVals, options.importance)) {
var msg = options.importance + ' is not a valid importance option. ' +
'Valid options are: ' + validImportanceOptVals;
throw new Error(msg);
}

this.rejuvKernel = kernels.parseOptions(options.rejuvKernel);
this.rejuvSteps = options.rejuvSteps;

Expand All @@ -34,7 +43,7 @@ module.exports = function(env) {
this.numParticles = options.particles;
this.debug = options.debug;
this.saveTraces = options.saveTraces;
this.ignoreGuide = options.ignoreGuide;
this.importanceOpt = options.importance;
this.onlyMAP = options.onlyMAP;

this.particles = [];
Expand Down Expand Up @@ -62,31 +71,42 @@ module.exports = function(env) {
};

SMC.prototype.sample = function(s, k, a, dist, options) {
var _val, choiceScore, importanceScore;

if (options && _.has(options, 'guide') && !this.ignoreGuide) {
// Guide available.
var importanceDist = options.guide;
_val = importanceDist.sample();
choiceScore = dist.score(_val);
importanceScore = importanceDist.score(_val);
} else {
// No guide, sample from prior.
_val = dist.sample();
choiceScore = importanceScore = dist.score(_val);
}
options = options || {};
var thunk = (this.importanceOpt === 'ignoreGuide') ? undefined : options.guide;
return guide.runIfThunkElseNull(thunk, env, s, a, function(s, maybeDist) {

// maybeDist will be null if either the 'ignoreGuide' option is
// set, or no guide is specified in the program.

// Auto guide if requested.
var importanceDist =
!maybeDist && (this.importanceOpt === 'autoGuide') ?
guide.independent(dist, a, env) :
maybeDist;

var _val, choiceScore, importanceScore;
if (importanceDist) {
_val = importanceDist.sample();
choiceScore = dist.score(_val);
importanceScore = importanceDist.score(_val);
} else {
// No importance distribution, sample from prior.
_val = dist.sample();
choiceScore = importanceScore = dist.score(_val);
}

var particle = this.currentParticle();
particle.logWeight += ad.value(choiceScore) - ad.value(importanceScore);
var particle = this.currentParticle();
particle.logWeight += ad.value(choiceScore) - ad.value(importanceScore);

var val = this.adRequired && dist.isContinuous ? ad.lift(_val) : _val;
// Optimization: Choices are not required for PF without rejuvenation.
if (this.performRejuv || this.saveTraces) {
particle.trace.addChoice(dist, val, a, s, k, options);
} else {
particle.trace.score = ad.scalar.add(particle.trace.score, choiceScore);
}
return k(s, val);
var val = this.adRequired && dist.isContinuous ? ad.lift(_val) : _val;
// Optimization: Choices are not required for PF without rejuvenation.
if (this.performRejuv || this.saveTraces) {
particle.trace.addChoice(dist, val, a, s, k, options);
} else {
particle.trace.score = ad.scalar.add(particle.trace.score, choiceScore);
}
return k(s, val);
}.bind(this));
};

SMC.prototype.factor = function(s, k, a, score) {
Expand Down

0 comments on commit 5e3e103

Please sign in to comment.