Skip to content

Commit

Permalink
Make observe pass options to sample (#794)
Browse files Browse the repository at this point in the history
* Make observe pass options to sample

* Add test for observe parameters
  • Loading branch information
stuhlmueller authored and null-a committed Mar 22, 2017
1 parent 6d3b353 commit 9197bd9
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,17 @@ module.exports = function(env) {
// If observation value is given then factor accordingly,
// otherwise sample a new value.
// The value is passed to the continuation.
env.observe = function(s, k, a, dist, val) {
env.observe = function(s, k, a, dist, val, options) {
if (typeof env.coroutine.observe === 'function') {
return env.coroutine.observe(s, k, a, dist, val);
return env.coroutine.observe(s, k, a, dist, val, options);
} else {
if (val !== undefined) {
var factorK = function(s) {
return k(s, val);
};
return env.factor(s, factorK, a, dist.score(val));
} else {
return env.sample(s, k, a, dist);
return env.sample(s, k, a, dist, options);
}
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/inference/dream/sample.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ module.exports = function(env) {
throw new Error('dream: factor not supported, use observe instead.');
},

observe: function(s, k, a, dist) {
observe: function(s, k, a, dist, options) {
if (!this.insideMapData) {
throw new Error('dream: observe can only be used within mapData with this estimator.');
}
Expand Down
3 changes: 3 additions & 0 deletions tests/test-data/deterministic/expected/observeParams.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"result": "ok"
}
6 changes: 6 additions & 0 deletions tests/test-data/deterministic/models/observeParams.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
sample(Infer({ method: 'forward', samples: 1, guide: true, model: function() {
return observe(
Discrete({ ps: [1] }),
undefined,
{ guide: function() { return Delta({ v: 'ok' }); }});
}}));

0 comments on commit 9197bd9

Please sign in to comment.