Skip to content

Commit

Permalink
Merge pull request #872 from null-a/kde
Browse files Browse the repository at this point in the history
Add KDE distribution
  • Loading branch information
stuhlmueller committed Jul 12, 2017
2 parents 0a4a6b3 + e4aa07b commit d2d0a5c
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/functions/other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,8 @@ Other
.. js:function:: error(msg)

Halts execution of the program and prints ``msg`` to the console.

.. js:function:: kde(marginal[, kernelWidth])

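Constructs a :js:func:`KDE` distribution from a sample-based
marginal distribution. ``kernelWidth`` is optional; when it is
omitted, :js:func:`KDE` selects a kernel width automatically.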
9 changes: 9 additions & 0 deletions docs/primitive-distributions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@

`Wikipedia entry <https://en.wikipedia.org/wiki/Normal_distribution>`__

.. js:function:: KDE({data: ..., width: ...})

* data: data array
* width: kernel width

A distribution based on a kernel density estimate of ``data``. A Gaussian kernel is used, and both real and vector valued data are supported. When the data are vector valued, ``width`` should be a vector specifying the kernel width for each dimension of the data. When ``width`` is omitted, Silverman's rule of thumb is used to select a kernel width. This rule assumes the data are approximately Gaussian distributed. When this assumption does not hold, a ``width`` should be specified in order to obtain sensible results.

`Wikipedia entry <https://en.wikipedia.org/wiki/Kernel_density_estimation>`__

.. js:function:: Laplace({location: ..., scale: ...})

* location: *(real)*
Expand Down
18 changes: 12 additions & 6 deletions src/aggregation/ScoreAggregator.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@ var _ = require('lodash');
var dists = require('../dists');
var util = require('../util');

function logsumexp(a, b) {
assert.ok(a !== -Infinity || b !== -Infinity);
var m = Math.max(a, b);
return Math.log(Math.exp(a - m) + Math.exp(b - m)) + m;
// Returns log(exp(a) + exp(b)) for two values given in log space.
//
// A -Infinity operand represents log(0) and contributes nothing, so the
// other operand is returned unchanged (this also makes the result
// -Infinity when both operands are -Infinity).
function logaddexp(a, b) {
  if (a === -Infinity) {
    return b;
  }
  if (b === -Infinity) {
    return a;
  }
  // Factor out the larger operand so that the argument passed to exp()
  // is non-positive.
  var aIsLarger = a > b;
  var hi = aIsLarger ? a : b;
  var lo = aIsLarger ? b : a;
  return Math.log(1 + Math.exp(lo - hi)) + hi;
}

var ScoreAggregator = function() {
Expand All @@ -28,13 +34,13 @@ ScoreAggregator.prototype.add = function(value, score) {
if (this.dist[key] === undefined) {
this.dist[key] = { score: -Infinity, val: value };
}
this.dist[key].score = logsumexp(this.dist[key].score, score);
this.dist[key].score = logaddexp(this.dist[key].score, score);
};

function normalize(dist) {
// Note, this also maps dist from log space into probability space.
var logNorm = _.reduce(dist, function(acc, obj) {
return logsumexp(acc, obj.score);
return logaddexp(acc, obj.score);
}, -Infinity);
return _.mapValues(dist, function(obj) {
return { val: obj.val, prob: Math.exp(obj.score - logNorm) };
Expand Down
102 changes: 101 additions & 1 deletion src/dists.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var util = require('./util');
var assert = require('assert');
var inspect = require('util').inspect;
var types = require('./types');
var stats = require('./math/statistics');

var T = ad.tensor;

Expand Down Expand Up @@ -1530,6 +1531,104 @@ function printMarginal(dist) {
.join('\n');
}

// The implementation of defaultWidth for both Gaussian kernels uses
// Silverman's rule of thumb:
// https://en.wikipedia.org/wiki/Multivariate_kernel_density_estimation#Rule_of_thumb

// Table of kernels available to the KDE distribution. Each entry bundles
// the type used to recognise a data point, the type used to validate the
// kernel width, the sample/score functions of the kernel, and a function
// that derives a default width from the data.
var kdeKernels = {
  // Kernel for real valued data.
  gaussian: {
    dataType: types.unboundedReal,
    widthType: types.positiveReal,
    sample: gaussianSample,
    score: gaussianScore,
    defaultWidth: function(data) {
      // width = 1.06 * sd * n^(-1/5)
      return 1.06 * stats.sd(data) * Math.pow(data.length, -0.2);
    }
  },
  // Kernel for vector valued data, with one width per dimension.
  mvGaussian: {
    dataType: types.unboundedVector,
    widthType: types.positiveVectorCB,
    sample: diagCovGaussianSample,
    score: diagCovGaussianScore,
    defaultWidth: function(data) {
      var count = data.length;
      var dim = data[0].dims[0];

      // Element-wise mean of the data points.
      var total = data.reduce(function(sum, point) {
        return sum.add(point);
      });
      var mean = total.div(count);

      // Element-wise standard deviation (dividing by n), accumulated
      // starting from a fresh tensor with the dims of a data point.
      var squaredDeviations = data.reduce(function(sum, point) {
        var deviation = point.sub(mean);
        return sum.add(deviation.pow(2));
      }, new Tensor(data[0].dims));
      var sd = squaredDeviations.div(count).sqrt();

      // width = sd * (4 / (d + 2))^(1 / (d + 4)) * n^(-1 / (d + 4))
      var dimFactor = Math.pow(4 / (dim + 2), 1 / (dim + 4));
      var countFactor = Math.pow(count, -1 / (dim + 4));
      return sd.mul(dimFactor * countFactor);
    }
  }
};

// Kernel density estimate distribution. The kernel is chosen from
// kdeKernels according to the type of the first data point.
var KDE = makeDistributionType({
  name: 'KDE',
  desc: 'A distribution based on a kernel density estimate of ``data``. ' +
      'A Gaussian kernel is used, and both real and vector valued data are supported. ' +
      'When the data are vector valued, ``width`` should be a vector specifying the kernel ' +
      'width for each dimension of the data. ' +
      'When ``width`` is omitted, Silverman\'s rule of thumb ' +
      'is used to select a kernel width. This rule assumes the data are ' +
      'approximately Gaussian distributed. When this assumption does not hold, a ``width`` ' +
      'should be specified in order to obtain sensible results.',
  params: [
    {name: 'data', desc: 'data array'},
    {name: 'width', desc: 'kernel width', optional: true}
  ],
  wikipedia: 'Kernel_density_estimation',
  nohelper: true,
  mixins: [continuousSupport],
  // Validates params, selects this.kernel and fills in params.width when
  // it was not supplied.
  constructor: function() {
    var params = this.params;
    var data = params.data;

    // Reject anything that is not an array with at least one element.
    var hasData = _.isArray(data) && !_.isEmpty(data);
    if (!hasData) {
      throw new Error('Parameter "data" should be a non-empty array.');
    }

    // The array is assumed to be homogeneous, so only its first element
    // is used to decide which kernel applies.
    var first = data[0];
    var kernel = _.find(kdeKernels, function(candidate) {
      return candidate.dataType.check(first);
    });
    this.kernel = kernel;
    if (!kernel) {
      throw new Error('Parameter "data" should be an array of reals or vectors.');
    }

    // Fall back to the kernel's default width when none was given.
    if (params.width === undefined) {
      params.width = kernel.defaultWidth(data);
    }

    // Validate the width, whether supplied or computed.
    if (!kernel.widthType.check(params.width)) {
      throw new Error('Parameter "width" should be of type ' + kernel.widthType.desc);
    }
  },
  // Picks one data point using util.random() and draws from the kernel
  // placed at that point.
  sample: function() {
    var data = this.params.data;
    var index = Math.floor(util.random() * data.length);
    return this.kernel.sample(data[index], this.params.width);
  },
  // Log of the average, over all data points, of the kernel density at
  // val, accumulated in log space.
  score: function(val) {
    var data = this.params.data;
    var width = this.params.width;
    var kernel = this.kernel;
    var accumulate = function(logTotal, point) {
      return util.logaddexp(logTotal, kernel.score(point, width, val));
    };
    var logSum = data.reduce(accumulate, -Infinity);
    return logSum - Math.log(data.length);
  }
});

var Categorical = makeDistributionType({
name: 'Categorical',
Expand Down Expand Up @@ -1643,7 +1742,8 @@ var distributions = {
Marginal: Marginal,
SampleBasedMarginal: SampleBasedMarginal,
Categorical: Categorical,
Delta: Delta
Delta: Delta,
KDE: KDE
};

// For each distribution type, we create a WebPPL function that
Expand Down
9 changes: 9 additions & 0 deletions src/header.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,12 @@ var marginalize = function(dist, project) {
}});
}
};

// Builds a KDE distribution from the sample values of a sample based
// marginal. The width is passed on to KDE only when it was supplied.
var kde = function(marginal, width) {
  var isSampleBased = dists.isDist(marginal) &&
      marginal.meta.name === 'SampleBasedMarginal';
  if (!isSampleBased) {
    error('kde expects a sample based marginal as its first argument.');
  }
  var data = _.map(marginal.samples, 'value');
  var params = width === undefined ? {data} : {data, width};
  return KDE(params);
};
1 change: 1 addition & 0 deletions src/types.js
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ module.exports = {
unboundedVector: vector(parseInterval('(-Infinity, Infinity)')),
nonNegativeVector: vector(parseInterval('[0, Infinity)')),
positiveVector: vector(parseInterval('(0, Infinity)')),
positiveVectorCB: vector(parseInterval('(0, Infinity)'), true),
unitIntervalVector: vector(parseInterval('[0, 1]')),
unboundedVectorOrRealArray: vectorOrRealArray(parseInterval('(-Infinity, Infinity)')),
nonNegativeVectorOrRealArray: vectorOrRealArray(parseInterval('[0, Infinity)')),
Expand Down
13 changes: 13 additions & 0 deletions src/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,18 @@ function logsumexp(a) {
return m + Math.log(sum);
}

// Returns log(exp(a) + exp(b)) for two values given in log space.
//
// - A -Infinity operand represents log(0), so the other operand is
//   returned unchanged (the result is -Infinity when both are -Infinity).
// - Equal operands return a + log(2). This also covers the case where
//   both are +Infinity, for which the general formula would evaluate
//   exp(Infinity - Infinity) and produce NaN.
// - Otherwise the larger operand is factored out and Math.log1p is used,
//   so that the contribution of a much smaller operand is not rounded
//   away inside `1 + exp(...)`.
// - NaN operands propagate to a NaN result.
function logaddexp(a, b) {
  if (a === -Infinity) {
    return b;
  }
  if (b === -Infinity) {
    return a;
  }
  var hi = Math.max(a, b);
  var lo = Math.min(a, b);
  if (hi === lo) {
    return hi + Math.LN2;
  }
  return hi + Math.log1p(Math.exp(lo - hi));
}

var deleteIndex = function(arr, i) {
return arr.slice(0, i).concat(arr.slice(i + 1))
}
Expand Down Expand Up @@ -385,6 +397,7 @@ module.exports = {
histsApproximatelyEqual: histsApproximatelyEqual,
gensym: gensym,
logsumexp: logsumexp,
logaddexp: logaddexp,
deleteIndex: deleteIndex,
makeGensym: makeGensym,
prettyJSON: prettyJSON,
Expand Down
3 changes: 3 additions & 0 deletions tests/test-data/deterministic/expected/kde.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"result": true
}
50 changes: 50 additions & 0 deletions tests/test-data/deterministic/models/kde.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// True when x and y differ by less than 0.001 in absolute value.
var approxEq = function(x, y) {
  var tolerance = 0.001;
  var gap = Math.abs(x - y);
  return gap < tolerance;
};

// Each entry is an immediately invoked function that evaluates to a
// boolean indicating whether one KDE check passed.
var cases = [
// Real valued data with an explicit width: the score at 4 is compared
// against a precomputed value.
(function() {
var d = KDE({data: [1, 2], width: 1.5});
return approxEq(d.score(4), -2.621);
})(),

// Single real data point with a very small width: a sample should land
// (approximately) on that data point.
(function() {
var kde = KDE({data: [5], width: 0.00001});
return approxEq(sample(kde), 5);
})(),

// Width omitted for real valued data: checks the width that the
// distribution fills in for itself.
(function() {
var kde = KDE({data: [1, 2]});
return approxEq(kde.params.width, 0.461);
})(),

// Vector valued data with an explicit width vector: the score is
// compared against a precomputed value.
(function() {
var data = [Vector([0, 1]), Vector([2, 4])];
var width = Vector([1, 1.5]);
var kde = KDE({data, width});
return approxEq(kde.score(Vector([-1, -2])), -5.436);
})(),

// Single vector data point with very small widths. Note that x0 and x1
// are read from two separate samples.
(function() {
var kde = KDE({data: [Vector([5, 6])], width: Vector([0.00001, 0.00001])});
var x0 = T.get(sample(kde), 0);
var x1 = T.get(sample(kde), 1);
return approxEq(x0, 5) && approxEq(x1, 6);
})(),

// Width omitted for one-dimensional vector data: the expected value is
// the same as in the real valued default-width case above.
(function() {
var kde = KDE({data: [Vector([1]), Vector([2])]});
var width = T.get(kde.params.width, 0);
return approxEq(width, 0.461);
})(),

// The kde helper applied to a forward-sampled marginal holding a single
// sample of delta(5), again with a very small width.
(function() {
var m = Infer({method: 'forward', samples: 1, model() {
return delta(5);
}});
var x = sample(kde(m, 0.00001));
return approxEq(x, 5);
})()
];

// Combine the per-case booleans into the program's result (all and idF
// are not defined in this file; presumably this is true only when every
// case is true -- confirm against the WebPPL header).
all(idF, cases);
23 changes: 23 additions & 0 deletions tests/test-util.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ module.exports = {

},

// Tests for util.logaddexp. Results are mapped back out of log space with
// Math.exp and compared using testAlmostEqual with 1e-6 as its last argument.
testLogAddExp: {
// log(1) and log(2) combine to log(3).
test1: function(test) {
testAlmostEqual(test, Math.exp(util.logaddexp(Math.log(1), Math.log(2))), 3, 1e-6);
test.done();
},
// Same as test1 with the arguments swapped (larger argument first).
test2: function(test) {
testAlmostEqual(test, Math.exp(util.logaddexp(Math.log(2), Math.log(1))), 3, 1e-6);
test.done();
},
// -Infinity as the first argument leaves the second argument unchanged.
test3: function(test) {
testAlmostEqual(test, Math.exp(util.logaddexp(-Infinity, Math.log(1))), 1, 1e-6);
test.done();
},
// -Infinity as the second argument leaves the first argument unchanged.
test4: function(test) {
testAlmostEqual(test, Math.exp(util.logaddexp(Math.log(1), -Infinity)), 1, 1e-6);
test.done();
},
// Both arguments -Infinity: the result exponentiates to 0.
test5: function(test) {
testAlmostEqual(test, Math.exp(util.logaddexp(-Infinity, -Infinity)), 0, 1e-6);
test.done();
}
},

testCpsIterate: {

test1: function(test) {
Expand Down

0 comments on commit d2d0a5c

Please sign in to comment.