Skip to content

Commit

Permalink
Merge pull request #282 from null-a/fix-util-expectation
Browse files Browse the repository at this point in the history
Updates to expectation and std util funcs.
  • Loading branch information
stuhlmueller committed Jan 8, 2016
2 parents c092ddd + faa2dc9 commit e9d30f5
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 55 deletions.
1 change: 1 addition & 0 deletions src/erp.js
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ function makeMarginalERP(marginal) {
});

dist.MAP = function() {return mapEst};
dist.hist = marginal;
return dist;
}

Expand Down
6 changes: 3 additions & 3 deletions src/inference/mh-diagnostics/diagnostics.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
var util = require('../../util');
var stats = require('../../statistics');
var fs = require('fs');
var os = require('os');

Expand All @@ -15,8 +15,8 @@ function geweke(traces, first, last, intervals) {
for (var i = 0; i < traces.length / 2; i = i + Math.floor((traces.length / 2) / (intervals - 1))) {
var firstSlice = traces.slice(i, i + Math.floor(first * (end - i)));
var lastSlice = traces.slice(Math.floor(end - last * (end - i)), traces.length);
var mu = (util.mean(firstSlice) - util.mean(lastSlice));
var zscore = mu / Math.sqrt(Math.pow(util.std(firstSlice), 2) + Math.pow(util.std(lastSlice), 2));
var mu = (stats.mean(firstSlice) - stats.mean(lastSlice));
var zscore = mu / Math.sqrt(stats.variance(firstSlice) + stats.variance(lastSlice));
zscores.push([i, zscore]);
}
return zscores;
Expand Down
59 changes: 19 additions & 40 deletions src/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,6 @@ function product(xs) {
return result;
}

function normalizeHist(hist) {
var normHist = {};
var Z = sum(_.values(hist));
_.each(hist, function(val, key) {
normHist[key] = hist[key] / Z;
});
return normHist;
}

var logHist = function(hist) {
return _.mapObject(hist, function(x) {
return {prob: Math.log(x.prob), val: x.val}
Expand Down Expand Up @@ -143,7 +134,22 @@ function cpsIterate(n, initial, func, cont) {
function() { return cont(val); });
}

function histsApproximatelyEqual(hist, expectedHist, tolerance) {
function histExpectation(hist, func) {
var f = func || _.identity;
return _.reduce(hist, function(acc, obj) {
return acc + obj.prob * f(obj.val);
}, 0);
}

function histStd(hist) {
var m = histExpectation(hist);
return Math.sqrt(histExpectation(hist, function(x) {
return Math.pow(x - m, 2);
}));
}

function histsApproximatelyEqual(actualHist, expectedHist, tolerance) {
var hist = _.mapObject(actualHist, function(obj) { return obj.prob; });
var allOk = (expectedHist !== undefined);
_.each(
expectedHist,
Expand All @@ -159,32 +165,6 @@ function histsApproximatelyEqual(hist, expectedHist, tolerance) {
return allOk;
}

function expectation(hist, func) {
var f = func == undefined ? function(x) {return x;} : func;
if (_.isArray(hist)) {
return sum(hist) / hist.length;
} else {
var expectedValue = sum(_.mapObject(hist, function(v, x) {
return f(x) * v;
}));
return expectedValue;
}
}

function std(hist) {
var mu = expectation(hist);
if (_.isArray(hist)) {
var variance = expectation(hist.map(function(x) {
return Math.pow(x - mu, 2);
}));
} else {
var variance = sum(_.mapObject(hist, function(v, x) {
return v * Math.pow(mu - x, 2);
}));
}
return Math.sqrt(variance);
}

function mergeDefaults(options, defaults) {
return _.defaults(options ? _.clone(options) : {}, defaults);
}
Expand Down Expand Up @@ -250,18 +230,17 @@ module.exports = {
cpsForEach: cpsForEach,
cpsLoop: cpsLoop,
cpsIterate: cpsIterate,
expectation: expectation,
gensym: gensym,
histExpectation: histExpectation,
histStd: histStd,
histsApproximatelyEqual: histsApproximatelyEqual,
gensym: gensym,
logsumexp: logsumexp,
logHist: logHist,
deleteIndex: deleteIndex,
makeGensym: makeGensym,
normalizeArray: normalizeArray,
normalizeHist: normalizeHist,
prettyJSON: prettyJSON,
runningInBrowser: runningInBrowser,
std: std,
mergeDefaults: mergeDefaults,
sum: sum,
product: product,
Expand Down
15 changes: 3 additions & 12 deletions tests/test-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ var wpplRunInference = function(modelName, testDef) {

var performTest = function(modelName, testDef, test) {
var result = wpplRunInference(modelName, testDef);
result.hist = getHist(result.erp);
var expectedResults = helpers.loadExpected(testDataDir, modelName);

_.each(expectedResults, function(expected, testName) {
Expand All @@ -330,13 +329,13 @@ var getInferenceArgs = function(testDef, model) {

var testFunctions = {
hist: function(test, result, expected, args) {
test.ok(util.histsApproximatelyEqual(result.hist, expected, args.tol));
test.ok(util.histsApproximatelyEqual(result.erp.hist, expected, args.tol));
},
mean: function(test, result, expected, args) {
helpers.testWithinTolerance(test, util.expectation(result.hist), expected, args.tol, 'mean');
helpers.testWithinTolerance(test, util.histExpectation(result.erp.hist), expected, args.tol, 'mean');
},
std: function(test, result, expected, args) {
helpers.testWithinTolerance(test, util.std(result.hist), expected, args.tol, 'std');
helpers.testWithinTolerance(test, util.histStd(result.erp.hist), expected, args.tol, 'std');
},
logZ: function(test, result, expected, args) {
if (args.check) {
Expand All @@ -355,14 +354,6 @@ var testFunctions = {
}
};

var getHist = function(erp) {
var hist = {};
erp.support().forEach(function(value) {
hist[util.serialize(value)] = Math.exp(erp.score([], value));
});
return util.normalizeHist(hist);
};

var generateTestCases = function(seed) {
_.each(tests, function(testDef) {
exports[testDef.name] = {};
Expand Down
41 changes: 41 additions & 0 deletions tests/test-statistics.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
'use strict';

var stats = require('../src/statistics');

module.exports = {

testMean: {

test1: function(test) {
test.strictEqual(stats.mean([0, 3, 9]), 4);
test.done();
},
test2: function(test) {
test.strictEqual(stats.mean(new Float64Array([0, 3, 9])), 4);
test.done();
},
test3: function(test) {
test.throws(function() { stats.mean([]); });
test.done();
}

},

testStandardDeviation: {

test1: function(test) {
test.strictEqual(stats.sd([0, 1, 2]), Math.sqrt(2 / 3));
test.done();
},
test2: function(test) {
test.strictEqual(stats.sd(new Float64Array([0, 1, 2])), Math.sqrt(2 / 3));
test.done();
},
test3: function(test) {
test.throws(function() { stats.sd([]); });
test.done();
}

}

};

0 comments on commit e9d30f5

Please sign in to comment.