Skip to content

Commit

Permalink
Merge pull request #750 from longouyang/add-marginalize-helper
Browse files Browse the repository at this point in the history
Add marginalization helper
  • Loading branch information
stuhlmueller committed Jan 28, 2017
2 parents 57db442 + 4bd2452 commit d812c0b
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 0 deletions.
24 changes: 24 additions & 0 deletions docs/functions/other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,30 @@ Other

expectation(Categorical({ps: [.2, .8], vs: [0, 1]})); // => 0.8

.. js:function:: marginalize(dist, project)

Marginalizes out certain variables in a distribution. ``project``
can be either a function or a string. Using it as a function:

::

var dist = Infer({model: function() {
var a = flip(0.9);
var b = flip();
var c = flip();
return {a: a, b: b, c: c};
}});

marginalize(dist, function(x) {
return x.a;
}) // => Marginal with p(true) = 0.9, p(false) = 0.1

Using it as a string:

::

marginalize(dist, 'a') // => Marginal with p(true) = 0.9, p(false) = 0.1

.. js:function:: mapObject(fn, obj)

Returns the object obtained by mapping the function ``fn`` over the
Expand Down
12 changes: 12 additions & 0 deletions src/header.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,15 @@ var Tensor = function(dims, arr) {
var t = ad.tensor.fromScalars(arr);
return ad.tensor.reshape(t, dims);
};

var marginalize = function(dist, project) {
if (_.isString(project)) {
return Infer({model: function() {
return sample(dist)[project]
}})
} else {
return Infer({model: function() {
return project(sample(dist));
}});
}
};
3 changes: 3 additions & 0 deletions tests/test-data/deterministic/expected/marginalize.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"result": [5, 5]
}
13 changes: 13 additions & 0 deletions tests/test-data/deterministic/models/marginalize.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
var dist = Infer({model: function() {
var a = uniformDraw([4, 5, 6]);
var b = flip();
var c = flip(0.9);
return {a: a, b: b, c: c};
}});

[
expectation(marginalize(dist, 'a')),
expectation(marginalize(dist, function(x) {
return x.a;
}))
]

0 comments on commit d812c0b

Please sign in to comment.