-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #877 from null-a/mixture2
Add mixture distribution primitive
- Loading branch information
Showing
8 changed files
with
264 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
'use strict'; | ||
|
||
var _ = require('lodash'); | ||
var ad = require('../ad'); | ||
var base = require('./base'); | ||
var types = require('../types'); | ||
var numeric = require('../math/numeric'); | ||
var util = require('../util'); | ||
var Discrete = require('./discrete').Discrete; | ||
|
||
function continuousSupportEq(s1, s2) { | ||
return s1 === s2 || | ||
(s1 !== undefined && | ||
s2 !== undefined && | ||
s1.lower === s2.lower && | ||
s1.upper === s2.upper); | ||
} | ||
|
||
function unionDiscreteSupports(supports) { | ||
return _.chain(supports) | ||
.flatten() | ||
.uniqWith(supportElemEq) | ||
.value(); | ||
} | ||
|
||
function supportElemEq(x, y) { | ||
return util.serialize(x) === util.serialize(y); | ||
} | ||
|
||
var Mixture = base.makeDistributionType({ | ||
name: 'Mixture', | ||
desc: 'A finite mixture of distributions. ' + | ||
'The component distributions should be either all discrete or all continuous. ' + | ||
'All continuous distributions should share a common support.', | ||
params: [ | ||
{ | ||
name: 'dists', | ||
desc: 'array of component distributions' | ||
}, | ||
{ | ||
name: 'ps', | ||
desc: 'component probabilities (can be unnormalized)', | ||
type: types.nonNegativeVectorOrRealArray | ||
} | ||
], | ||
wikipedia: false, | ||
constructor: function() { | ||
var dists = this.params.dists; | ||
var ps = this.params.ps; | ||
|
||
if (!_.isArray(dists)) { | ||
throw new Error('Parameter dists should be an array.'); | ||
} | ||
|
||
if (dists.length !== ad.value(ps).length) { | ||
throw new Error('Parameters ps and dists should have the same length.'); | ||
} | ||
|
||
if (dists.length === 0) { | ||
throw new Error('Parameters ps and dists should be non-empty.'); | ||
} | ||
|
||
if (!_.every(dists, base.isDist)) { | ||
throw new Error('Parameter dists should be an array of distributions.'); | ||
} | ||
|
||
this.isContinuous = dists[0].isContinuous; | ||
var support_0 = this.isContinuous ? dists[0].support && dists[0].support() : undefined; | ||
|
||
for (var i = 1; i < dists.length; i++) { | ||
var dist_i = dists[i]; | ||
if (dist_i.isContinuous !== this.isContinuous) { | ||
throw new Error('Mixtures combining discrete and continuous distributions are not supported.'); | ||
} | ||
if (this.isContinuous) { | ||
var support_i = dist_i.support && dist_i.support(); | ||
if (!continuousSupportEq(support_0, support_i)) { | ||
throw new Error('All continuous distributions should have the same support.'); | ||
} | ||
} | ||
} | ||
|
||
if (this.isContinuous) { | ||
this.support = support_0 && _.constant(support_0); | ||
} else { | ||
this.support = function() { | ||
return unionDiscreteSupports(_.invokeMap(dists, 'support')); | ||
}; | ||
} | ||
|
||
this.indicatorDist = new Discrete({ps: ps}, true); | ||
}, | ||
sample: function() { | ||
var i = this.indicatorDist.sample(); | ||
return this.params.dists[i].sample(); | ||
}, | ||
score: function(val) { | ||
'use ad'; | ||
var dists = this.params.dists; | ||
var s = -Infinity; | ||
for (var i = 0; i < dists.length; i++) { | ||
s = numeric.logaddexp(s, this.indicatorDist.score(i) + dists[i].score(val)); | ||
} | ||
return s; | ||
} | ||
}); | ||
|
||
module.exports = { | ||
Mixture: Mixture | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"result": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
var approxEq = function(x, y) { | ||
return Math.abs(x - y) < 1e-8; | ||
}; | ||
|
||
var g1 = Gaussian({mu: ad.lift(0), sigma: 1}); | ||
var g2 = Gaussian({mu: 1, sigma: 2}); | ||
var g3 = DiagCovGaussian({mu: ad.lift(Vector([0])), sigma: Vector([1])}); | ||
var g4 = DiagCovGaussian({mu: Vector([1]), sigma: Vector([2])}); | ||
|
||
var discreteMixture = Mixture({ | ||
ps: [0.4,0.6], | ||
dists: [ | ||
Categorical({vs: [[0], [1]]}), | ||
Categorical({vs: [[1], [2], [3]]}) | ||
] | ||
}); | ||
|
||
var cases = [ | ||
|
||
// Scoring | ||
|
||
(function() { | ||
var x = 3; | ||
var m = Mixture({ps: [ad.lift(0.4), 0.6], dists: [g1, g2]}); | ||
var trueScore = Math.log(0.4 * Math.exp(g1.score(x)) + 0.6 * Math.exp(g2.score(x))); | ||
return approxEq(m.score(x), trueScore); | ||
})(), | ||
|
||
(function() { | ||
var x = 3; | ||
var m = Mixture({ps: ad.lift(Vector([0.4, 0.6])), dists: [g1, g2]}); | ||
var trueScore = Math.log(0.4 * Math.exp(g1.score(x)) + 0.6 * Math.exp(g2.score(x))); | ||
return approxEq(m.score(x), trueScore); | ||
})(), | ||
|
||
(function() { | ||
var x = 3; | ||
var m = Mixture({ps: [ad.lift(0.4), 0.6], dists: [g3, g4]}); | ||
var trueScore = Math.log(0.4 * Math.exp(g1.score(x)) + 0.6 * Math.exp(g2.score(x))); | ||
return approxEq(m.score(Vector([x])), trueScore); | ||
})(), | ||
|
||
(function() { | ||
return approxEq(discreteMixture.score([0]), Math.log(0.2)) && | ||
approxEq(discreteMixture.score([1]), Math.log(0.4)) && | ||
approxEq(discreteMixture.score([2]), Math.log(0.2)) && | ||
approxEq(discreteMixture.score([3]), Math.log(0.2)) && | ||
discreteMixture.score([4]) === -Infinity; | ||
})(), | ||
|
||
// Sampling | ||
|
||
(function() { | ||
var m = Mixture({ps: [1, 0], dists: [Gaussian({mu: -3, sigma: 1e-12}), g1]}); | ||
return approxEq(sample(m), -3); | ||
})(), | ||
|
||
(function() { | ||
var m = Mixture({ps: [1, 0], dists: [Delta({v: 'a'}), Bernoulli({p: 0.5})]}); | ||
return sample(m) === 'a'; | ||
})(), | ||
|
||
// Support | ||
|
||
(function() { | ||
var m = Mixture({ps: [0.5, 0.5], dists: [g1, g2]}); | ||
return m.support === undefined; | ||
})(), | ||
|
||
(function() { | ||
var m = Mixture({ps: [0.5, 0.5], dists: [Uniform({a: 0, b: 1}), Uniform({a: 0, b: 1})]}); | ||
return m.support().lower === 0 && m.support().upper === 1; | ||
})(), | ||
|
||
(function() { | ||
return _.isEqual(discreteMixture.support(), [[0], [1], [2], [3]]); | ||
})(), | ||
|
||
// `isContinuous` flag | ||
|
||
(function() { | ||
var m = Mixture({ps: [0.5, 0.5], dists: [Uniform({a: 0, b: 1}), Uniform({a: 0, b: 1})]}); | ||
return m.isContinuous; | ||
})(), | ||
|
||
(function() { | ||
return !discreteMixture.isContinuous; | ||
})() | ||
|
||
]; | ||
|
||
all(idF, cases); |