Skip to content

Commit

Permalink
Merge pull request #773 from null-a/only-map-fwd-sample
Browse files Browse the repository at this point in the history
Add onlyMAP option to forward sampling method
  • Loading branch information
stuhlmueller committed Feb 14, 2017
2 parents d5c974a + 63113aa commit 13ba93d
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 13 deletions.
14 changes: 14 additions & 0 deletions docs/inference/methods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,13 @@ Optimization

Default: ``1``

.. describe:: onlyMAP

When ``true``, only the sample with the highest score is
retained. The marginal is a delta distribution on this value.

Default: ``false``

In addition, all of the options supported by :ref:`Optimize
<optimize>` are also supported here.

Expand Down Expand Up @@ -368,6 +375,13 @@ Forward Sampling

Default: ``false``

.. describe:: onlyMAP

When ``true``, only the sample with the highest score is
retained. The marginal is a delta distribution on this value.

Default: ``false``

Example usage::

Infer({method: 'forward', model: model});
Expand Down
2 changes: 1 addition & 1 deletion src/header.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ var SampleGuide = function(wpplFn, options) {

var OptimizeThenSample = function(wpplFn, options) {
Optimize(wpplFn, _.omit(options, 'samples'));
var opts = _.pick(options, 'samples', 'verbose');
var opts = _.pick(options, 'samples', 'onlyMAP', 'verbose');
return SampleGuide(wpplFn, opts);
};

Expand Down
17 changes: 12 additions & 5 deletions src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module.exports = function(env) {
this.opts = util.mergeDefaults(options, {
samples: 1,
guide: false, // true = sample guide, false = sample target
onlyMAP: false,
verbose: false
});

Expand All @@ -34,18 +35,19 @@ module.exports = function(env) {

run: function() {

var hist = new CountAggregator();
var hist = new CountAggregator(this.opts.onlyMAP);
var logWeights = []; // Save total factor weights

return util.cpsLoop(
this.opts.samples,

// Loop body.
function(i, next) {
this.score = 0;
this.logWeight = 0;
return this.wpplFn(_.clone(this.s), function(s, val) {
logWeights.push(this.logWeight);
hist.add(val);
hist.add(val, this.score);
return next();
}.bind(this), this.a);
}.bind(this),
Expand All @@ -64,16 +66,21 @@ module.exports = function(env) {
},

sample: function(s, k, a, dist, options) {
var cont = function(s, dist) {
var val = dist.sample();
this.score += dist.score(val);
return k(s, val);
}.bind(this);

if (this.opts.guide) {
options = options || {};
return guide.getDist(
options.guide, options.noAutoGuide, dist, env, s, a,
function(s, maybeGuideDist) {
var d = maybeGuideDist || dist;
return k(s, d.sample());
return cont(s, maybeGuideDist || dist);
});
} else {
return k(s, dist.sample());
return cont(s, dist);
}
},

Expand Down
3 changes: 2 additions & 1 deletion tests/test-data/stochastic/expected/onlyMAP.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"mean": 1
"mean": 1,
"std": 0
}
43 changes: 37 additions & 6 deletions tests/test-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ var tests = [
mixed4: true,
bivariateGaussian: true,
indirectDependency: true,
guidedFlip: true
guidedFlip: true,
onlyMAP: {
mean: { tol: 0.1 },
std: { tol: 0 },
args: { samples: 150, onlyMAP: true }
}
}
},
{
Expand Down Expand Up @@ -112,7 +117,11 @@ var tests = [
variableSupport: true,
query: true,
query2: { hist: { tol: 0.1, exactSupport: true } },
onlyMAP: { mean: { tol: 0.1 }, args: { samples: 150, onlyMAP: true } },
onlyMAP: {
mean: { tol: 0.1 },
std: { tol: 0 },
args: { samples: 150, onlyMAP: true }
},
nestedEnum1: { mean: { tol: 0.1 }, std: { tol: 0.075 } },
nestedEnum2: { mean: { tol: 0.1 }, std: { tol: 0.075 } },
nestedEnum3: { mean: { tol: 0.1 }, std: { tol: 0.075 } },
Expand Down Expand Up @@ -222,7 +231,11 @@ var tests = [
store: { hist: { tol: 0 }, args: { particles: 100 } },
store2: { hist: { tol: 0 }, args: { particles: 100 } },
notapes: { hist: { tol: 0 }, args: { samples: 100 } },
onlyMAP: { mean: { tol: 0.1 }, args: { particles: 150, onlyMAP: true } },
onlyMAP: {
mean: { tol: 0.1 },
std: { tol: 0 },
args: { particles: 150, onlyMAP: true }
},
gaussianMean: { mean: { tol: 0.3 }, std: { tol: 0.3 }, args: { particles: 10000 } },
varFactors1: { args: { particles: 5000 } },
varFactors2: true,
Expand Down Expand Up @@ -372,7 +385,11 @@ var tests = [
variableSupport: true,
query: true,
query2: { hist: { tol: 0.1, exactSupport: true } },
onlyMAP: { mean: { tol: 0.1 }, args: { samples: 150, onlyMAP: true } },
onlyMAP: {
mean: { tol: 0.1 },
std: { tol: 0 },
args: { samples: 150, onlyMAP: true }
},
nestedEnum1: { mean: { tol: 0.1 }, std: { tol: 0.075 } },
nestedEnum2: { mean: { tol: 0.1 }, std: { tol: 0.075 } },
nestedEnum3: { mean: { tol: 0.1 }, std: { tol: 0.075 } },
Expand Down Expand Up @@ -408,7 +425,11 @@ var tests = [
variableSupport: true,
query: true,
query2: { hist: { tol: 0.1, exactSupport: true } },
onlyMAP: { mean: { tol: 0.1 }, args: { samples: 150, kernel: 'HMC', onlyMAP: true } },
onlyMAP: {
mean: { tol: 0.1 },
std: { tol: 0 },
args: { samples: 150, kernel: 'HMC', onlyMAP: true }
},
mixed1: true,
mixed1Factor: true,
mixed2: {
Expand Down Expand Up @@ -631,7 +652,17 @@ var tests = [
verbose: false
}
},
mapData: true
mapData: true,
onlyMAP: {
mean: { tol: 0.1 },
std: { tol: 0 },
args: {
steps: 10000,
samples: 150,
onlyMAP: true,
verbose: false
}
}
}
}
];
Expand Down

0 comments on commit 13ba93d

Please sign in to comment.