Skip to content

Commit

Permalink
Merge pull request #287 from null-a/histogram-everywhere
Browse files Browse the repository at this point in the history
Use Histogram everywhere
  • Loading branch information
stuhlmueller committed Jan 20, 2016
2 parents 4b8c735 + df94878 commit a905fb2
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 114 deletions.
44 changes: 40 additions & 4 deletions src/aggregation.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,48 @@ Histogram.prototype.add = function(value) {
var value = untapify(value);
var k = util.serialize(value);
if (this.hist[k] === undefined) {
this.hist[k] = { prob: 0, val: value };
this.hist[k] = { count: 0, val: value };
}
this.hist[k].prob += 1;
this.hist[k].count += 1;
};

function normalizeHist(hist) {
var totalCount = _.reduce(hist, function(acc, obj) {
return acc + obj.count;
}, 0);
return _.mapObject(hist, function(obj) {
return { val: obj.val, prob: obj.count / totalCount };
});
}

Histogram.prototype.toERP = function() {
return erp.makeMarginalERP(util.logHist(this.hist));
return erp.makeMarginalERP(normalizeHist(this.hist));
};

var Distribution = function() {
this.dist = {};
};

Distribution.prototype.add = function(value, score) {
var k = util.serialize(value);
if (this.dist[k] === undefined) {
this.dist[k] = { score: -Infinity, val: value };
}
this.dist[k].score = util.logsumexp([this.dist[k].score, score]);
};

function normalizeDist(dist) {
// Note, this also maps dist from log space into probability space.
var logNorm = _.reduce(dist, function(acc, obj) {
return util.logsumexp([acc, obj.score]);
}, -Infinity);
return _.mapObject(dist, function(obj) {
return { val: obj.val, prob: Math.exp(obj.score - logNorm) };
});
}

Distribution.prototype.toERP = function() {
return erp.makeMarginalERP(normalizeDist(this.dist));
};

var MAP = function(retainSamples) {
Expand All @@ -31,7 +66,7 @@ var MAP = function(retainSamples) {
MAP.prototype.add = function(value, score) {
var value = untapify(value);
if (this.retainSamples) {
this.samples.push(value);
this.samples.push({ value: value, score: score });
}
if (score > this.max.score) {
this.max.value = value;
Expand Down Expand Up @@ -61,5 +96,6 @@ function untapify(x) {

module.exports = {
Histogram: Histogram,
Distribution: Distribution,
MAP: MAP
};
41 changes: 12 additions & 29 deletions src/erp.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -610,29 +610,15 @@ function multinomialSample(theta) {
return k - 1;
}

// Make a discrete ERP from a {val: prob, etc.} object (unormalized).
// Make a discrete ERP from a normalized {val: ..., prob: ...} object.
function makeMarginalERP(marginal) {
assert.ok(_.size(marginal) > 0);
// Normalize distribution:
var norm = -Infinity;
var supp = [];
for (var v in marginal) {if (marginal.hasOwnProperty(v)) {
var d = marginal[v];
norm = util.logsumexp([norm, d.prob]);
supp.push(d.val);
}}
var mapEst = {val: undefined, prob: 0};
for (v in marginal) {if (marginal.hasOwnProperty(v)) {
var dd = marginal[v];
var nprob = dd.prob - norm;
var nprobS = Math.exp(nprob)
if (nprobS > mapEst.prob)
mapEst = {val: dd.val, prob: nprobS};
marginal[v].prob = nprobS;
}}

var norm = _.reduce(marginal, function(acc, obj) { return acc + obj.prob; }, 0);
assert.ok(Math.abs(1 - norm) < 1e-8, 'Expected marginal to be normalized.');
var support = _.map(marginal, function(obj) {
return obj.val;
});
// Make an ERP from marginal:
var dist = new ERP({
return new ERP({
sample: function(params) {
var x = util.random();
var probAccum = 0;
Expand All @@ -647,19 +633,16 @@ function makeMarginalERP(marginal) {
return marginal[i].val;
},
score: function(params, val) {
var lk = marginal[util.serialize(val)];
return lk ? Math.log(lk.prob) : -Infinity;
var obj = marginal[util.serialize(val)];
return obj ? Math.log(obj.prob) : -Infinity;
},
support: function(params) {
return supp;
return support;
},
parameterized: false,
name: 'marginal'
name: 'marginal',
hist: marginal
});

dist.MAP = function() {return mapEst};
dist.hist = marginal;
return dist;
}

// note: ps is expected to be normalized
Expand Down
10 changes: 4 additions & 6 deletions src/inference/asyncpf.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

var _ = require('underscore');
var util = require('../util');
var erp = require('../erp');
var Histogram = require('../aggregation').Histogram;

module.exports = function(env) {

Expand Down Expand Up @@ -51,7 +51,7 @@ module.exports = function(env) {

this.obsWeights = {};
this.exitedParticles = 0;
this.hist = {};
this.hist = new Histogram();

// Move old coroutine out of the way and install this as current handler.
this.k = k;
Expand Down Expand Up @@ -163,14 +163,12 @@ module.exports = function(env) {
this.activeParticle.weight += Math.log(this.activeParticle.multiplicity);
this.exitedParticles += 1;

var k = util.serialize(retval);
if (this.hist[k] === undefined) this.hist[k] = {prob: 0, val: retval};
this.hist[k].prob += 1;
this.hist.add(retval);

if (this.exitedParticles < this.numParticles) {
return this.run();
} else {
var dist = erp.makeMarginalERP(util.logHist(this.hist));
var dist = this.hist.toERP();

var lastFactorIndex = this.activeParticle.factorIndex;
var olk = this.obsWeights[lastFactorIndex];
Expand Down
16 changes: 4 additions & 12 deletions src/inference/enumerate.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

var _ = require('underscore');
var PriorityQueue = require('priorityqueuejs');
var erp = require('../erp');
var util = require('../util');
var Distribution = require('../aggregation').Distribution;

module.exports = function(env) {

function Enumerate(store, k, a, wpplFn, maxExecutions, Q) {
this.score = 0; // Used to track the score of the path currently being explored
this.marginal = {}; // We will accumulate the marginal distribution here
this.marginal = new Distribution(); // We will accumulate the marginal distribution here
this.numCompletedExecutions = 0;
this.store = store; // will be reinstated at the end
this.k = k;
Expand Down Expand Up @@ -120,13 +120,7 @@ module.exports = function(env) {

Enumerate.prototype.exit = function(s, retval) {
// We have reached an exit of the computation. Accumulate probability into retval bin.
var r = util.serialize(retval);
if (this.score !== -Infinity) {
if (this.marginal[r] === undefined) {
this.marginal[r] = {val: retval, prob: -Infinity};
}
this.marginal[r].prob = util.logsumexp([this.marginal[r].prob, this.score])
}
this.marginal.add(retval, this.score);

// Increment the completed execution counter
this.numCompletedExecutions++;
Expand All @@ -135,12 +129,10 @@ module.exports = function(env) {
if (this.queue.size() > 0 && (this.numCompletedExecutions < this.maxExecutions)) {
return this.nextInQueue();
} else {
var marginal = this.marginal;
var dist = erp.makeMarginalERP(marginal);
// Reinstate previous coroutine:
env.coroutine = this.coroutine;
// Return from enumeration by calling original continuation with original store:
return this.k(this.store, dist);
return this.k(this.store, this.marginal.toERP());
}
};

Expand Down
46 changes: 7 additions & 39 deletions src/inference/incrementalmh.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
var _ = require('underscore');
var assert = require('assert');
var util = require('../util');
var erp = require('../erp');
var Hashtable = require('../hashtable').Hashtable
var Query = require('../query').Query;
var aggregation = require('../aggregation');

module.exports = function(env) {

Expand Down Expand Up @@ -768,12 +768,10 @@ module.exports = function(env) {
this.s = s;
this.a = a;

this.onlyMAP = onlyMAP;
if (justSample)
this.returnSamps = [];
else
this.returnHist = {};
this.MAP = { val: undefined, score: -Infinity };
this.aggregator = (justSample || onlyMAP) ?
new aggregation.MAP(justSample) :
new aggregation.Histogram();

this.totalIterations = numIterations;
this.acceptedProps = 0;
this.lag = lag;
Expand Down Expand Up @@ -902,22 +900,7 @@ module.exports = function(env) {
if (val === env.query)
val = this.query.getTable();
// add val to hist:
if (!this.onlyMAP) {
if (this.returnSamps)
this.returnSamps.push({score: this.score, value: val})
else {
var stringifiedVal = util.serialize(val);
if (this.returnHist[stringifiedVal] === undefined) {
this.returnHist[stringifiedVal] = { prob: 0, val: val };
}
this.returnHist[stringifiedVal].prob += 1;
}
}
// also update the MAP
if (this.score > this.MAP.score) {
this.MAP.score = this.score;
this.MAP.value = val;
}
this.aggregator.add(val, this.score);
}

if (DEBUG >= 6) {
Expand Down Expand Up @@ -948,21 +931,6 @@ module.exports = function(env) {
}
}
} else {
var hist;
if (this.returnSamps || this.onlyMAP) {
hist = {};
hist[util.serialize(this.MAP.value)] = { prob: 1, val: this.MAP.value };
} else {
hist = this.returnHist;
}
var dist = erp.makeMarginalERP(util.logHist(hist));
if (this.returnSamps) {
if (this.onlyMAP)
this.returnSamps.push(this.MAP);
dist.samples = this.returnSamps;
}
dist.MAP = this.MAP.value;

// Reinstate previous coroutine:
var k = this.k;
env.coroutine = this.oldCoroutine;
Expand All @@ -973,7 +941,7 @@ module.exports = function(env) {
}

// Return by calling original continuation:
return k(this.oldStore, dist);
return k(this.oldStore, this.aggregator.toERP());
}
};

Expand Down
13 changes: 4 additions & 9 deletions src/inference/pmcmc.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
var _ = require('underscore');
var erp = require('../erp');
var util = require('../util')
var Histogram = require('../aggregation').Histogram;

module.exports = function(env) {

Expand Down Expand Up @@ -34,7 +35,7 @@ module.exports = function(env) {
this.address = a;
this.numParticles = numParticles;
this.resetParticles();
this.returnHist = {};
this.hist = new Histogram();
}

PMCMC.prototype.run = function() {
Expand Down Expand Up @@ -168,11 +169,7 @@ module.exports = function(env) {
if (this.sweep > 0) {
this.particles.concat(this.retainedParticle).forEach(
function(particle) {
var k = util.serialize(particle.value);
if (this.returnHist[k] === undefined) {
this.returnHist[k] = {prob: 0, val: particle.value};
}
this.returnHist[k].prob += 1;
this.hist.add(particle.value);
}.bind(this));
}

Expand All @@ -188,13 +185,11 @@ module.exports = function(env) {
return this.activeContinuationWithStore();

} else {
var dist = erp.makeMarginalERP(util.logHist(this.returnHist));

// Reinstate previous coroutine:
env.coroutine = this.oldCoroutine;

// Return from particle filter by calling original continuation:
return this.k(this.oldStore, dist);
return this.k(this.oldStore, this.hist.toERP());

}
}
Expand Down
12 changes: 4 additions & 8 deletions src/inference/rejection.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
var erp = require('../erp');
var assert = require('assert');
var util = require('../util')
var Histogram = require('../aggregation').Histogram;

module.exports = function(env) {

Expand All @@ -21,7 +22,7 @@ module.exports = function(env) {
this.wpplFn = wpplFn;
this.maxScore = maxScore === undefined ? 0 : maxScore
this.incremental = incremental;
this.hist = {};
this.hist = new Histogram();
this.numSamples = numSamples;
this.oldCoroutine = env.coroutine;
env.coroutine = this;
Expand Down Expand Up @@ -63,18 +64,13 @@ module.exports = function(env) {

if (this.scoreSoFar > this.threshold) {
// Accept.
var r = util.serialize(retval);
if (this.hist[r] === undefined) {
this.hist[r] = { prob: 0, val: retval };
}
this.hist[r].prob += 1;
this.hist.add(retval);
this.numSamples -= 1;
}

if (this.numSamples === 0) {
var dist = erp.makeMarginalERP(util.logHist(this.hist));
env.coroutine = this.oldCoroutine;
return this.k(this.s, dist);
return this.k(this.s, this.hist.toERP());
} else {
return this.run();
}
Expand Down
7 changes: 0 additions & 7 deletions src/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ function product(xs) {
return result;
}

var logHist = function(hist) {
return _.mapObject(hist, function(x) {
return {prob: Math.log(x.prob), val: x.val}
});
};

function logsumexp(a) {
var m = Math.max.apply(null, a);
var sum = 0;
Expand Down Expand Up @@ -227,7 +221,6 @@ module.exports = {
histsApproximatelyEqual: histsApproximatelyEqual,
gensym: gensym,
logsumexp: logsumexp,
logHist: logHist,
deleteIndex: deleteIndex,
makeGensym: makeGensym,
prettyJSON: prettyJSON,
Expand Down

0 comments on commit a905fb2

Please sign in to comment.