Skip to content

Commit

Permalink
Merge pull request #815 from null-a/fwd-sample-helpers
Browse files Browse the repository at this point in the history
Add forward and forwardGuide helpers.
  • Loading branch information
stuhlmueller committed Apr 10, 2017
2 parents b969e55 + a59f7c0 commit 16ecef5
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 7 deletions.
15 changes: 15 additions & 0 deletions docs/functions/other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ Other

marginalize(dist, 'a') // => Marginal with p(true) = 0.9, p(false) = 0.1

.. js:function:: forward(model)

Evaluates function of zero arguments ``model``, ignoring any
:ref:`factor <factor>` statements.

Also see: :ref:`Forward Sampling <forward_sampling>`

.. js:function:: forwardGuide(model)

Evaluates function of zero arguments ``model``, ignoring any
``factor`` statements, and sampling from the :ref:`guide <guides>`
at each random choice.

Also see: :ref:`Forward Sampling <forward_sampling>`

.. js:function:: mapObject(fn, obj)

Returns the object obtained by mapping the function ``fn`` over the
Expand Down
2 changes: 2 additions & 0 deletions docs/inference/methods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ Optimization

Infer({method: 'optimize', samples: 100, steps: 100, model: model});

.. _forward_sampling:

Forward Sampling
----------------

Expand Down
13 changes: 12 additions & 1 deletion src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,20 @@ module.exports = function(env) {
);
}

function extractVal(k) {
return function(s, obj) {
return k(s, obj.val);
};
}

return {
ForwardSample: ForwardSample,
runForward: runForward
forward: function(s, k, a, model) {
return runForward(s, extractVal(k), a, model, false);
},
forwardGuide: function(s, k, a, model) {
return runForward(s, extractVal(k), a, model, true);
}
};

};
7 changes: 3 additions & 4 deletions src/params/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function defaultInit(mu, sigma) {

module.exports = function(env) {

var runForward = require('../inference/forwardSample')(env).runForward;
var forward = require('../inference/forwardSample')(env).forward;

var dimsForScalarParam = [1];

Expand Down Expand Up @@ -86,8 +86,7 @@ module.exports = function(env) {
return init(s, k, a, dims);
};

var next = function(k, ret) {
var initialVal = ret.val;
var next = function(k, initialVal) {
params.create(name, initialVal);
if (!_.isEqual(dims, initialVal.dims)) {
var msg = 'The init function did not return a tensor with the expected shape.';
Expand All @@ -96,7 +95,7 @@ module.exports = function(env) {
return finish(s);
};

return runForward(s, next, a, initThunk);
return forward(s, next, a, initThunk);
}

function finish(s) {
Expand Down
3 changes: 3 additions & 0 deletions tests/test-data/deterministic/expected/forward.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"result": [0, 1]
}
11 changes: 11 additions & 0 deletions tests/test-data/deterministic/models/forward.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
var model = function() {
factor(0);
return sample(Delta({v: 0}), {guide() {
return Delta({v: 1});
}});
};

[
forward(model),
forwardGuide(model)
];
4 changes: 2 additions & 2 deletions tests/test-data/deterministic/models/observeParams.wppl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
sample(Infer({ method: 'forward', samples: 1, guide: true, model: function() {
forwardGuide(function() {
return observe(
Discrete({ ps: [1] }),
undefined,
{ guide: function() { return Delta({ v: 'ok' }); }});
}}));
});

0 comments on commit 16ecef5

Please sign in to comment.