Skip to content

Commit

Permalink
Merge pull request #745 from null-a/guide-thunks-anywhere
Browse files Browse the repository at this point in the history
Allow guide thunks anywhere
  • Loading branch information
stuhlmueller committed Jan 11, 2017
2 parents 7a8fdc9 + 4b11d70 commit 84fd4eb
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 27 deletions.
2 changes: 2 additions & 0 deletions docs/globalstore.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.. _globalstore:

The Global Store
================

Expand Down
21 changes: 21 additions & 0 deletions docs/guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,24 @@ For example::
return Gaussian(guideParams);
}
});

Note that such functions will only be called when using an inference
strategy that makes use of the guide.

In some situations, it is convenient to be able to specify part of a
guide computation outside of the functions passed to ``sample``. This
can be accomplished with the ``guide`` function, which takes a
function of zero arguments representing the computation::

guide(function() {
// Some guide computation.
});

As with the functions passed to ``sample``, the function passed to
``guide`` will only be called when required for inference.

It's important to note that ``guide`` does not return the value of the
computation. Instead, the :ref:`global store <globalstore>` should be
used to pass results to subsequent guide computations. This
arrangement encourages a programming style in which there is
separation between the model and the guide.
35 changes: 20 additions & 15 deletions src/guide.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,34 +57,38 @@ function runThunk(thunk, env, s, a, k) {
guideCoroutine(env),
env, k,
function(k) {
return thunk(s, function(s2, dist) {
return thunk(s, function(s2, val) {
// Clear the stack now the thunk has returned.
return function() {
if (!dists.isDist(dist)) {
throw new Error('The guide did not return a distribution.');
}
return k(s2, dist);
return k(s2, val);
};
}, a + '_guide');
});
}

function runIfThunk(s, k, a, env, maybeThunk, alternate) {
function runDistThunk(thunk, env, s, a, k) {
return runThunk(thunk, env, s, a, function(s2, dist) {
if (!dists.isDist(dist)) {
throw new Error('The guide did not return a distribution.');
}
return k(s2, dist);
});
}

function runDistThunkCond(s, k, a, env, maybeThunk, alternate) {
return maybeThunk ?
runThunk(maybeThunk, env, s, a, k) :
runDistThunk(maybeThunk, env, s, a, k) :
alternate(s, k, a);
}

// Convenient variations on runGuideThunk.

function runIfThunkElseNull(maybeThunk, env, s, a, k) {
return runIfThunk(s, k, a, env, maybeThunk, function(s, k, a) {
function getDist(maybeThunk, env, s, a, k) {
return runDistThunkCond(s, k, a, env, maybeThunk, function(s, k, a) {
return k(s, null);
});
}

function runIfThunkElseAuto(maybeThunk, targetDist, env, s, a, k) {
return runIfThunk(s, k, a, env, maybeThunk, function(s, k, a) {
function getDistOrAuto(maybeThunk, targetDist, env, s, a, k) {
return runDistThunkCond(s, k, a, env, maybeThunk, function(s, k, a) {
return k(s, independent(targetDist, a, env));
});
}
Expand Down Expand Up @@ -329,6 +333,7 @@ function squishToInterval(interval) {

module.exports = {
independent: independent,
runIfThunkElseAuto: runIfThunkElseAuto,
runIfThunkElseNull: runIfThunkElseNull
runThunk: runThunk,
getDist: getDist,
getDistOrAuto: getDistOrAuto
};
14 changes: 13 additions & 1 deletion src/headerUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ var Tensor = require('./tensor');
var LRU = require('lru-cache');
var ad = require('./ad');
var assert = require('assert');
var runThunk = require('./guide').runThunk;

module.exports = function(env) {

Expand Down Expand Up @@ -154,6 +155,16 @@ module.exports = function(env) {
}
}

function guide(s, k, a, thunk) {
if (env.coroutine.guideRequired) {
return runThunk(thunk, env, s, a, function(s2, val) {
return k(s2);
});
} else {
return k(s);
}
}

return {
display: display,
cache: cache,
Expand All @@ -162,7 +173,8 @@ module.exports = function(env) {
_addr: _addr,
zeros: zeros,
ones: ones,
mapData: mapData
mapData: mapData,
guide: guide
};

};
3 changes: 2 additions & 1 deletion src/inference/elbo.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ module.exports = function(env) {
this.wpplFn = wpplFn;
this.s = s;
this.a = a;
this.guideRequired = true;

// Initialize mapData state.
this.mapDataStack = [{multiplier: 1}];
Expand Down Expand Up @@ -242,7 +243,7 @@ module.exports = function(env) {

sample: function(s, k, a, dist, options) {
options = options || {};
return guide.runIfThunkElseAuto(options.guide, dist, env, s, a, function(s, guideDist) {
return guide.getDistOrAuto(options.guide, dist, env, s, a, function(s, guideDist) {

var ret = this.sampleGuide(guideDist, options);
var val = ret.val;
Expand Down
1 change: 1 addition & 0 deletions src/inference/eubo.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ module.exports = function(env) {
this.wpplFn = wpplFn;
this.s = s;
this.a = a;
this.guideRequired = true;

this.coroutine = env.coroutine;
env.coroutine = this;
Expand Down
3 changes: 2 additions & 1 deletion src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ module.exports = function(env) {
this.s = s;
this.k = k;
this.a = a;
this.guideRequired = this.opts.guide;

this.factorWarningIssued = false;

Expand Down Expand Up @@ -65,7 +66,7 @@ module.exports = function(env) {
sample: function(s, k, a, dist, options) {
if (this.opts.guide) {
options = options || {};
return guide.runIfThunkElseAuto(options.guide, dist, env, s, a, function(s, guideDist) {
return guide.getDistOrAuto(options.guide, dist, env, s, a, function(s, guideDist) {
return k(s, guideDist.sample());
});
} else {
Expand Down
3 changes: 2 additions & 1 deletion src/inference/smc.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ module.exports = function(env) {
this.debug = options.debug;
this.saveTraces = options.saveTraces;
this.importanceOpt = options.importance;
this.guideRequired = options.importance !== 'ignoreGuide';
this.onlyMAP = options.onlyMAP;

this.particles = [];
Expand Down Expand Up @@ -73,7 +74,7 @@ module.exports = function(env) {
SMC.prototype.sample = function(s, k, a, dist, options) {
options = options || {};
var thunk = (this.importanceOpt === 'ignoreGuide') ? undefined : options.guide;
return guide.runIfThunkElseNull(thunk, env, s, a, function(s, maybeDist) {
return guide.getDist(thunk, env, s, a, function(s, maybeDist) {

// maybeDist will be null if either the 'ignoreGuide' option is
// set, or no guide is specified in the program.
Expand Down
1 change: 0 additions & 1 deletion tests/test-data/deterministic/expected/smc.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{
"result": [
[0],
0,
1
]
Expand Down
7 changes: 0 additions & 7 deletions tests/test-data/deterministic/models/smc.wppl
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
[
// Check that the guide is ignored.
Infer({method: 'SMC', particles: 1, importance: 'ignoreGuide', model() {
return sample(Delta({v: 0}), {guide() {
assert.ok(false);
}});
}}).support(),

// Check (indirectly) that a guide is automatically generated, by
// checking that a parameter is created.

Expand Down
5 changes: 5 additions & 0 deletions tests/test-data/stochastic/expected/guideThunks.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"hist": {
"true": 1
}
}
5 changes: 5 additions & 0 deletions tests/test-data/stochastic/expected/noGuideThunks.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"hist": {
"true": 1
}
}
17 changes: 17 additions & 0 deletions tests/test-data/stochastic/models/guideThunks.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
var model = function() {
globalStore.t1 = false;
globalStore.t2 = false;

guide(function() {
globalStore.t1 = true;
});
assert.ok(globalStore.t1, 'Guide thunk 1 did not run.');

sample(Bernoulli({p: 0.5}), {guide() {
globalStore.t2 = true;
return Bernoulli({p: Math.sigmoid(param())});
}});
assert.ok(globalStore.t2, 'Guide thunk 2 did not run.');

return true;
};
11 changes: 11 additions & 0 deletions tests/test-data/stochastic/models/noGuideThunks.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
var model = function() {
guide(function() {
assert.ok(false, 'Guide thunk 1 ran unexpectedly.');
});

sample(Bernoulli({p: 0.5}), {guide() {
assert.ok(false, 'Guide thunk 2 ran unexpectedly.');
}});

return true;
};
9 changes: 9 additions & 0 deletions tests/test-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ var tests = [
nestedEnum7: { mean: { tol: 0.075 }, std: { tol: 0.05 } },
nestedEnum8: { mean: { tol: 0.075 }, std: { tol: 0.05 } },
nestedEnumWithFactor: { mean: { tol: 0.075 }, std: { tol: 0.05 } },
guideThunks: {
hist: { exact: true },
args: { particles: 100 }
},
noGuideThunks: {
hist: { exact: true },
args: { particles: 100, importance: 'ignoreGuide' }
},
guidedFlip: true,
mapData: true
}
Expand Down Expand Up @@ -559,6 +567,7 @@ var tests = [
args: { verbose: false, checkGradients: false }
},
withCaching: true,
guideThunks: { hist: { exact: true } },
gaussianMean: true,
guidedFlip: true,
guidedGaussian: {
Expand Down

0 comments on commit 84fd4eb

Please sign in to comment.