Skip to content

Commit

Permalink
Merge pull request #876 from null-a/numeric
Browse files Browse the repository at this point in the history
Pull more fns into numeric module.
  • Loading branch information
stuhlmueller committed Jul 26, 2017
2 parents 86c16ee + a39b501 commit a4e93d5
Show file tree
Hide file tree
Showing 15 changed files with 86 additions and 80 deletions.
17 changes: 3 additions & 14 deletions src/aggregation/ScoreAggregator.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,7 @@ var assert = require('assert');
var _ = require('lodash');
var dists = require('../dists');
var util = require('../util');

function logaddexp(a, b) {
if (a === -Infinity) {
return b;
} else if (b === -Infinity) {
return a;
} else if (a > b) {
return Math.log(1 + Math.exp(b - a)) + a;
} else {
return Math.log(1 + Math.exp(a - b)) + b;
}
}
var numeric = require('../math/numeric');

var ScoreAggregator = function() {
this.dist = {};
Expand All @@ -34,13 +23,13 @@ ScoreAggregator.prototype.add = function(value, score) {
if (this.dist[key] === undefined) {
this.dist[key] = { score: -Infinity, val: value };
}
this.dist[key].score = logaddexp(this.dist[key].score, score);
this.dist[key].score = numeric.logaddexp(this.dist[key].score, score);
};

function normalize(dist) {
// Note, this also maps dist from log space into probability space.
var logNorm = _.reduce(dist, function(acc, obj) {
return logaddexp(acc, obj.score);
return numeric.logaddexp(acc, obj.score);
}, -Infinity);
return _.mapValues(dist, function(obj) {
return { val: obj.val, prob: Math.exp(obj.score - logNorm) };
Expand Down
2 changes: 1 addition & 1 deletion src/dists/discrete.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ var numeric = require('../math/numeric');
var T = ad.tensor;

function sample(theta) {
var thetaSum = util.sum(theta);
var thetaSum = numeric._sum(theta);
var x = util.random() * thetaSum;
var k = theta.length;
var probAccum = 0;
Expand Down
3 changes: 2 additions & 1 deletion src/dists/kde.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ var types = require('../types');
var util = require('../util');
var Tensor = require('../tensor');
var stats = require('../math/statistics');
var numeric = require('../math/numeric');
var gaussian = require('./gaussian');
var diagCovGaussian = require('./diagCovGaussian');

Expand Down Expand Up @@ -102,7 +103,7 @@ var KDE = base.makeDistributionType({
var kernel = this.kernel;
return data.reduce(
function(acc, x) {
return util.logaddexp(acc, kernel.score(x, width, val));
return numeric.logaddexp(acc, kernel.score(x, width, val));
},
-Infinity) - Math.log(n);
}
Expand Down
2 changes: 1 addition & 1 deletion src/dists/multinomial.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function zeros(n) {
}

function sample(theta, n) {
// var thetaSum = util.sum(theta);
// var thetaSum = numeric._sum(theta);
var a = zeros(theta.length);
for (var i = 0; i < n; i++) {
a[discrete.sample(theta)]++;
Expand Down
5 changes: 3 additions & 2 deletions src/inference/asyncpf.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

var _ = require('lodash');
var util = require('../util');
var numeric = require('../math/numeric');
var CountAggregator = require('../aggregation/CountAggregator');

module.exports = function(env) {
Expand Down Expand Up @@ -114,8 +115,8 @@ module.exports = function(env) {
var currWeight = this.activeParticle.weight;
var denom = lk.length + currMultiplicity; // k - 1 + Ckn
var prevWBar = lk[lk.length - 1].wbar;
var wbar = -Math.log(denom) + util.logsumexp([Math.log(lk.length) + prevWBar,
Math.log(currMultiplicity) + currWeight]);
var wbar = -Math.log(denom) + numeric._logsumexp([Math.log(lk.length) + prevWBar,
Math.log(currMultiplicity) + currWeight]);
if (wbar > 0) throw new Error('Positive weight!!'); // sanity check
var logRatio = currWeight - wbar;
var numChildrenAndWeight = [];
Expand Down
3 changes: 2 additions & 1 deletion src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

var _ = require('lodash');
var util = require('../util');
var numeric = require('../math/numeric');
var CountAggregator = require('../aggregation/CountAggregator');
var ad = require('../ad');
var guide = require('../guide');
Expand Down Expand Up @@ -104,7 +105,7 @@ module.exports = function(env) {
function() {
var dist = hist.toDist();
if (!opts.guide) {
dist.normalizationConstant = util.logsumexp(logWeights) - Math.log(opts.samples);
dist.normalizationConstant = numeric._logsumexp(logWeights) - Math.log(opts.samples);
}
return k(s, dist);
}
Expand Down
3 changes: 2 additions & 1 deletion src/inference/mhkernel.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
var _ = require('lodash');
var assert = require('assert');
var util = require('../util');
var numeric = require('../math/numeric');
var ad = require('../ad');

module.exports = function(env) {
Expand Down Expand Up @@ -215,7 +216,7 @@ module.exports = function(env) {
var score = ad.value(proposalDist.score(regenChoice.val));

// Rest of the trace.
score += util.sum(toTrace.choices.slice(this.regenFrom + 1).map(function(choice) {
score += numeric._sum(toTrace.choices.slice(this.regenFrom + 1).map(function(choice) {
return this.reused.hasOwnProperty(choice.address) ? 0 : ad.value(choice.dist.score(choice.val));
}, this));

Expand Down
3 changes: 2 additions & 1 deletion src/inference/smc.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

var _ = require('lodash');
var util = require('../util');
var numeric = require('../math/numeric');
var discrete = require('../dists/discrete');
var Trace = require('../trace');

Expand Down Expand Up @@ -154,7 +155,7 @@ module.exports = function(env) {
}
// Residual resampling following Liu 2008; p. 72, section 3.4.4
var m = particles.length;
var logW = util.logsumexp(_.map(particles, 'logWeight'));
var logW = numeric._logsumexp(_.map(particles, 'logWeight'));
var logAvgW = logW - Math.log(m);
if (logAvgW === -Infinity) {
// do not return, execution continues
Expand Down
53 changes: 52 additions & 1 deletion src/math/numeric.ad.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
'use strict';

var _ = require('lodash');
var ad = require('../ad');

var LOG_PI = 1.1447298858494002;
var LOG_2PI = 1.8378770664093453;

// By convention, non-adified versions of functions are prefixed with
// an underscore.

function sum(xs) {
'use ad';
return xs.reduce(function(a, b) { return a + b; }, 0);
}

function _sum(xs) {
if (xs.length === 0) {
return 0.0;
} else {
var total = _.reduce(xs,
function(a, b) {
return a + b;
});
return total;
}
}

function product(xs) {
'use ad';
var result = 1;
for (var i = 0, n = xs.length; i < n; i++) {
result *= xs[i];
}
return result;
}

function fact(x) {
'use ad';
var t = 1;
Expand Down Expand Up @@ -47,11 +72,37 @@ function squishToProbSimplex(x) {
return ad.tensor.softmax(u);
}

function logaddexp(a, b) {
'use ad';
if (a === -Infinity) {
return b;
} else if (b === -Infinity) {
return a;
} else if (a > b) {
return Math.log(1 + Math.exp(b - a)) + a;
} else {
return Math.log(1 + Math.exp(a - b)) + b;
}
}

function _logsumexp(a) {
var m = Math.max.apply(null, a);
var sum = 0;
for (var i = 0; i < a.length; ++i) {
sum += (a[i] === -Infinity ? 0 : Math.exp(a[i] - m));
}
return m + Math.log(sum);
}

module.exports = {
LOG_PI: LOG_PI,
LOG_2PI: LOG_2PI,
sum: sum,
_sum: _sum,
product: product,
fact: fact,
lnfact: lnfact,
squishToProbSimplex: squishToProbSimplex
squishToProbSimplex: squishToProbSimplex,
logaddexp: logaddexp,
_logsumexp: _logsumexp
};
1 change: 0 additions & 1 deletion src/transforms/adify.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ var generate = require('escodegen').generate;
var build = require('ast-types').builders;
var _ = require('lodash');
var ad = require('./ad').ad;
var util = require('../util');

function isMarkedForGlobalTransform(ast) {
assert.ok(ast.type === 'Program');
Expand Down
3 changes: 2 additions & 1 deletion src/types.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

var _ = require('lodash');
var util = require('./util');
var numeric = require('./math/numeric');
var interval = require('./math/interval');

var isInterval = interval.isInterval;
Expand Down Expand Up @@ -136,7 +137,7 @@ var probabilityArray = function() {
name: 'probabilityArray',
desc: 'real array with elements that sum to one',
check: function(val) {
return baseType.check(val) && Math.abs(1 - util.sum(val)) < tol;
return baseType.check(val) && Math.abs(1 - numeric._sum(val)) < tol;
}
};
};
Expand Down
48 changes: 4 additions & 44 deletions src/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ var assert = require('assert');
var seedrandom = require('seedrandom');
var ad = require('./ad');
var Tensor = require('./tensor');
var numeric = require('./math/numeric');

var rng = Math.random;

// Re-export sum from this module, as expected by webppl-viz.
var sum = numeric._sum;

var trampolineRunners = {
web: function(yieldEvery) {
yieldEvery = yieldEvery || 100;
Expand Down Expand Up @@ -83,47 +87,6 @@ function asArray(arg) {
return arg ? [].concat(arg) : [];
}

function sum(xs) {
if (xs.length === 0) {
return 0.0;
} else {
var total = _.reduce(xs,
function(a, b) {
return a + b;
});
return total;
}
}

function product(xs) {
var result = 1;
for (var i = 0, n = xs.length; i < n; i++) {
result *= xs[i];
}
return result;
}

function logsumexp(a) {
var m = Math.max.apply(null, a);
var sum = 0;
for (var i = 0; i < a.length; ++i) {
sum += (a[i] === -Infinity ? 0 : Math.exp(a[i] - m));
}
return m + Math.log(sum);
}

function logaddexp(a, b) {
if (a === -Infinity) {
return b;
} else if (b === -Infinity) {
return a;
} else if (a > b) {
return Math.log(1 + Math.exp(b - a)) + a;
} else {
return Math.log(1 + Math.exp(a - b)) + b;
}
}

var deleteIndex = function(arr, i) {
return arr.slice(0, i).concat(arr.slice(i + 1))
}
Expand Down Expand Up @@ -396,16 +359,13 @@ module.exports = {
histStd: histStd,
histsApproximatelyEqual: histsApproximatelyEqual,
gensym: gensym,
logsumexp: logsumexp,
logaddexp: logaddexp,
deleteIndex: deleteIndex,
makeGensym: makeGensym,
prettyJSON: prettyJSON,
runningInBrowser: runningInBrowser,
mergeDefaults: mergeDefaults,
getValAndOpts: getValAndOpts,
sum: sum,
product: product,
asArray: asArray,
serialize: serialize,
deserialize: deserialize,
Expand Down
4 changes: 2 additions & 2 deletions tests/test-data/sampler/beta.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
var _ = require('lodash');
var assert = require('assert');
var beta = require('../../../src/dists/beta');
var util = require('../../../src/util');
var numeric = require('../../../src/math/numeric');
var statistics = require('../../../src/math/statistics');

var ln = Math.log,
Expand Down Expand Up @@ -42,7 +42,7 @@ module.exports = {
var a = params[0];
var b = params[1];
// https://en.wikipedia.org/wiki/Beta_distribution#Higher_moments
return util.product(_.range(0, n - 1).map(function(k) { return (a + k) / (a + b + k) }))
return numeric.product(_.range(0, n - 1).map(function(k) { return (a + k) / (a + b + k) }))
},
// mostly HT https://en.wikipedia.org/wiki/Gamma_distribution
populationStatisticFunctions: {
Expand Down
4 changes: 2 additions & 2 deletions tests/test-data/sampler/gamma.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
var _ = require('lodash');
var assert = require('assert');
var gamma = require('../../../src/dists/gamma');
var util = require('../../../src/util');
var numeric = require('../../../src/math/numeric');
var statistics = require('../../../src/math/statistics');

var ln = Math.log,
Expand Down Expand Up @@ -41,7 +41,7 @@ module.exports = {
// HT http://ocw.mit.edu/courses/mathematics/
// 18-443-statistics-for-applications-fall-2006/lecture-notes/lecture6.pdf
// (but NB: they use shape, rate whereas we have shape, scale)
return util.product(_.range(0, n - 1).map(function(k) { return shape + k })) * pow(scale, n)
return numeric.product(_.range(0, n - 1).map(function(k) { return shape + k })) * pow(scale, n)
},
// mostly HT https://en.wikipedia.org/wiki/Gamma_distribution
populationStatisticFunctions: {
Expand Down

0 comments on commit a4e93d5

Please sign in to comment.