Skip to content

Commit

Permalink
Merge pull request #877 from null-a/mixture2
Browse files Browse the repository at this point in the history
Add mixture distribution primitive
  • Loading branch information
stuhlmueller committed Jul 28, 2017
2 parents a4e93d5 + 28ba319 commit fd46677
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 32 deletions.
7 changes: 7 additions & 0 deletions docs/primitive-distributions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@

`Wikipedia entry <https://en.wikipedia.org/wiki/Logit-normal_distribution>`__

.. js:function:: Mixture({dists: ..., ps: ...})

* dists: array of component distributions
* ps: component probabilities (can be unnormalized) *(vector or real array [0, Infinity))*

A finite mixture of distributions. The component distributions should be either all discrete or all continuous. All continuous distributions should share a common support.

.. js:function:: Multinomial({ps: ..., n: ...})

* ps: probabilities *(real array with elements that sum to one)*
Expand Down
31 changes: 19 additions & 12 deletions src/dists/base.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,27 @@ function makeDistributionType(options) {
// Note that Chrome uses the name of this local variable in the
// output of `console.log` when it's called on a distribution that
// uses the default constructor.
var dist = function(params) {

// The option to skip parameter checks is only used internally. It
// makes it possible to avoid performing checks multiple times when
// one distribution uses another distribution internally.

var dist = function(params, skipParamChecks) {
params = params || {};
parameterNames.forEach(function(p, i) {
if (params.hasOwnProperty(p)) {
var type = parameterTypes[i];
if (type && !type.check(ad.valueRec(params[p]))) {
throw new Error('Parameter \"' + p + '\" should be of type "' + type.desc + '".');
}
} else {
if (!parameterOptionalFlags[i]) {
throw new Error('Parameter \"' + p + '\" missing from ' + this.meta.name + ' distribution.');
if (!skipParamChecks) {
parameterNames.forEach(function(p, i) {
if (params.hasOwnProperty(p)) {
var type = parameterTypes[i];
if (type && !type.check(ad.valueRec(params[p]))) {
throw new Error('Parameter \"' + p + '\" should be of type "' + type.desc + '".');
}
} else {
if (!parameterOptionalFlags[i]) {
throw new Error('Parameter \"' + p + '\" missing from ' + this.meta.name + ' distribution.');
}
}
}
}, this);
}, this);
}
this.params = params;
if (extraConstructorFn !== undefined) {
extraConstructorFn.call(this);
Expand Down
51 changes: 31 additions & 20 deletions src/dists/discrete.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ var util = require('../util');
var numeric = require('../math/numeric');
var T = ad.tensor;

function sample(theta) {
var thetaSum = numeric._sum(theta);
function sample(theta, thetaSum) {
if (thetaSum === undefined) {
thetaSum = numeric._sum(theta);
}
var x = util.random() * thetaSum;
var k = theta.length;
var probAccum = 0;
Expand All @@ -22,24 +24,15 @@ function sample(theta) {
return k - 1;
}

function score(ps, i) {
var scoreFn = _.isArray(ps) ? scoreArray : scoreVector;
return scoreFn(ps, i);
}

function scoreVector(probs, val) {
function scoreVector(val, probs, norm) {
'use ad';
var _probs = ad.value(probs);
var d = _probs.dims[0];
return inSupport(val, d) ?
Math.log(T.get(probs, val) / T.sumreduce(probs)) :
-Infinity;
return Math.log(T.get(probs, val) / norm);

}

function scoreArray(probs, val) {
function scoreArray(val, probs, norm) {
'use ad';
var d = probs.length;
return inSupport(val, d) ? Math.log(probs[val] / numeric.sum(probs)) : -Infinity;
return Math.log(probs[val] / norm);
}

function inSupport(val, dim) {
Expand All @@ -60,15 +53,33 @@ var Discrete = base.makeDistributionType({
],
wikipedia: 'Categorical_distribution',
mixins: [base.finiteSupport],
constructor: function() {
// Compute the norm here, as it's required for both sampling and
// scoring.
if (_.isArray(this.params.ps)) {
this.norm = numeric.sum(this.params.ps);
this.scoreFn = scoreArray;
this.dim = this.params.ps.length;
}
else {
this.norm = T.sumreduce(this.params.ps);
this.scoreFn = scoreVector;
this.dim = ad.value(this.params.ps).length;
}
},
sample: function() {
return sample(toUnliftedArray(this.params.ps));
return sample(toUnliftedArray(this.params.ps), ad.value(this.norm));
},
score: function(val) {
return score(this.params.ps, val);
if (inSupport(val, this.dim)) {
return this.scoreFn(val, this.params.ps, this.norm);
}
else {
return -Infinity;
}
},
support: function() {
// This does the right thing for arrays and vectors.
return _.range(ad.value(this.params.ps).length);
return _.range(this.dim);
}
});

Expand Down
1 change: 1 addition & 0 deletions src/dists/index.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var distributions = _.chain(
['LogisticNormal', require('./logisticNormal')],
['LogitNormal', require('./logitNormal')],
['Marginal', require('./marginal')],
['Mixture', require('./mixture')],
['Multinomial', require('./multinomial')],
['MultivariateBernoulli', require('./multivariateBernoulli')],
['MultivariateGaussian', require('./multivariateGaussian')],
Expand Down
110 changes: 110 additions & 0 deletions src/dists/mixture.ad.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
'use strict';

var _ = require('lodash');
var ad = require('../ad');
var base = require('./base');
var types = require('../types');
var numeric = require('../math/numeric');
var util = require('../util');
var Discrete = require('./discrete').Discrete;

function continuousSupportEq(s1, s2) {
return s1 === s2 ||
(s1 !== undefined &&
s2 !== undefined &&
s1.lower === s2.lower &&
s1.upper === s2.upper);
}

function unionDiscreteSupports(supports) {
return _.chain(supports)
.flatten()
.uniqWith(supportElemEq)
.value();
}

function supportElemEq(x, y) {
return util.serialize(x) === util.serialize(y);
}

var Mixture = base.makeDistributionType({
name: 'Mixture',
desc: 'A finite mixture of distributions. ' +
'The component distributions should be either all discrete or all continuous. ' +
'All continuous distributions should share a common support.',
params: [
{
name: 'dists',
desc: 'array of component distributions'
},
{
name: 'ps',
desc: 'component probabilities (can be unnormalized)',
type: types.nonNegativeVectorOrRealArray
}
],
wikipedia: false,
constructor: function() {
var dists = this.params.dists;
var ps = this.params.ps;

if (!_.isArray(dists)) {
throw new Error('Parameter dists should be an array.');
}

if (dists.length !== ad.value(ps).length) {
throw new Error('Parameters ps and dists should have the same length.');
}

if (dists.length === 0) {
throw new Error('Parameters ps and dists should be non-empty.');
}

if (!_.every(dists, base.isDist)) {
throw new Error('Parameter dists should be an array of distributions.');
}

this.isContinuous = dists[0].isContinuous;
var support_0 = this.isContinuous ? dists[0].support && dists[0].support() : undefined;

for (var i = 1; i < dists.length; i++) {
var dist_i = dists[i];
if (dist_i.isContinuous !== this.isContinuous) {
throw new Error('Mixtures combining discrete and continuous distributions are not supported.');
}
if (this.isContinuous) {
var support_i = dist_i.support && dist_i.support();
if (!continuousSupportEq(support_0, support_i)) {
throw new Error('All continuous distributions should have the same support.');
}
}
}

if (this.isContinuous) {
this.support = support_0 && _.constant(support_0);
} else {
this.support = function() {
return unionDiscreteSupports(_.invokeMap(dists, 'support'));
};
}

this.indicatorDist = new Discrete({ps: ps}, true);
},
sample: function() {
var i = this.indicatorDist.sample();
return this.params.dists[i].sample();
},
score: function(val) {
'use ad';
var dists = this.params.dists;
var s = -Infinity;
for (var i = 0; i < dists.length; i++) {
s = numeric.logaddexp(s, this.indicatorDist.score(i) + dists[i].score(val));
}
return s;
}
});

module.exports = {
Mixture: Mixture
};
1 change: 1 addition & 0 deletions src/guide.js
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ function spec(targetDist) {
} else if (targetDist instanceof dists.Binomial) {
return discreteSpec(targetDist.params.n + 1);
} else if (targetDist instanceof dists.MultivariateGaussian ||
targetDist instanceof dists.Mixture ||
targetDist instanceof dists.Marginal ||
targetDist instanceof dists.SampleBasedMarginal) {
throwAutoGuideError(targetDist);
Expand Down
3 changes: 3 additions & 0 deletions tests/test-data/deterministic/expected/mixture.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"result": true
}
92 changes: 92 additions & 0 deletions tests/test-data/deterministic/models/mixture.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
var approxEq = function(x, y) {
return Math.abs(x - y) < 1e-8;
};

var g1 = Gaussian({mu: ad.lift(0), sigma: 1});
var g2 = Gaussian({mu: 1, sigma: 2});
var g3 = DiagCovGaussian({mu: ad.lift(Vector([0])), sigma: Vector([1])});
var g4 = DiagCovGaussian({mu: Vector([1]), sigma: Vector([2])});

var discreteMixture = Mixture({
ps: [0.4,0.6],
dists: [
Categorical({vs: [[0], [1]]}),
Categorical({vs: [[1], [2], [3]]})
]
});

var cases = [

// Scoring

(function() {
var x = 3;
var m = Mixture({ps: [ad.lift(0.4), 0.6], dists: [g1, g2]});
var trueScore = Math.log(0.4 * Math.exp(g1.score(x)) + 0.6 * Math.exp(g2.score(x)));
return approxEq(m.score(x), trueScore);
})(),

(function() {
var x = 3;
var m = Mixture({ps: ad.lift(Vector([0.4, 0.6])), dists: [g1, g2]});
var trueScore = Math.log(0.4 * Math.exp(g1.score(x)) + 0.6 * Math.exp(g2.score(x)));
return approxEq(m.score(x), trueScore);
})(),

(function() {
var x = 3;
var m = Mixture({ps: [ad.lift(0.4), 0.6], dists: [g3, g4]});
var trueScore = Math.log(0.4 * Math.exp(g1.score(x)) + 0.6 * Math.exp(g2.score(x)));
return approxEq(m.score(Vector([x])), trueScore);
})(),

(function() {
return approxEq(discreteMixture.score([0]), Math.log(0.2)) &&
approxEq(discreteMixture.score([1]), Math.log(0.4)) &&
approxEq(discreteMixture.score([2]), Math.log(0.2)) &&
approxEq(discreteMixture.score([3]), Math.log(0.2)) &&
discreteMixture.score([4]) === -Infinity;
})(),

// Sampling

(function() {
var m = Mixture({ps: [1, 0], dists: [Gaussian({mu: -3, sigma: 1e-12}), g1]});
return approxEq(sample(m), -3);
})(),

(function() {
var m = Mixture({ps: [1, 0], dists: [Delta({v: 'a'}), Bernoulli({p: 0.5})]});
return sample(m) === 'a';
})(),

// Support

(function() {
var m = Mixture({ps: [0.5, 0.5], dists: [g1, g2]});
return m.support === undefined;
})(),

(function() {
var m = Mixture({ps: [0.5, 0.5], dists: [Uniform({a: 0, b: 1}), Uniform({a: 0, b: 1})]});
return m.support().lower === 0 && m.support().upper === 1;
})(),

(function() {
return _.isEqual(discreteMixture.support(), [[0], [1], [2], [3]]);
})(),

// `isContinuous` flag

(function() {
var m = Mixture({ps: [0.5, 0.5], dists: [Uniform({a: 0, b: 1}), Uniform({a: 0, b: 1})]});
return m.isContinuous;
})(),

(function() {
return !discreteMixture.isContinuous;
})()

];

all(idF, cases);

0 comments on commit fd46677

Please sign in to comment.