Skip to content

Commit

Permalink
Merge pull request #245 from null-a/var-factors
Browse files Browse the repository at this point in the history
Make rejuvenation work for the varFactors tests.
  • Loading branch information
stuhlmueller committed Oct 20, 2015
2 parents c187f9b + 911f03e commit 8b61553
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 63 deletions.
20 changes: 11 additions & 9 deletions src/inference/mhkernel.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module.exports = function(env) {
function MHKernel(k, oldTrace, options) {
var options = util.mergeDefaults(options, {
proposalBoundary: 0,
exitFactor: 0,
permissive: false
});

Expand All @@ -20,8 +21,8 @@ module.exports = function(env) {
this.k = k;
this.oldTrace = oldTrace;
this.reused = {};
this.exitAddress = options.exitAddress;
this.proposalBoundary = options.proposalBoundary;
this.exitFactor = options.exitFactor;

this.coroutine = env.coroutine;
env.coroutine = this;
Expand All @@ -45,10 +46,11 @@ module.exports = function(env) {
if (score === -Infinity) {
return this.cont(this.oldTrace, false);
}
this.trace.numFactors += 1;
this.trace.score += score;
if (this.exitAddress === a) {
this.trace.saveContinuation(s, k, a);
return env.exit(s);
if (this.trace.numFactors === this.exitFactor) {
this.trace.saveContinuation(s, k);
return this.exit(s, undefined, true);
}
return k(s);
};
Expand Down Expand Up @@ -81,13 +83,13 @@ module.exports = function(env) {
return k(s, val);
};

MHKernel.prototype.exit = function(s, val) {
if (!this.exitAddress) {
MHKernel.prototype.exit = function(s, val, earlyExit) {
if (!earlyExit) {
this.trace.complete(val);
} else {
// We're rejuvenating a particle - ensure that exitAddress was reached by
// checking that the continuation was saved.
assert(!this.trace.isComplete(), 'Particle missed exit address during rejuvenation.');
assert(this.trace.store);
assert(this.trace.k);
assert(!this.trace.isComplete());
}
var prob = acceptProb(this.trace, this.oldTrace, this.regenFrom, this.reused, this.proposalBoundary);
var accept = util.random() < prob;
Expand Down
114 changes: 70 additions & 44 deletions src/inference/smc.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module.exports = function(env) {
var options = util.mergeDefaults(options, {
particles: 100,
rejuvSteps: 0,
allowOutOfSyncRejuv: false
finalRejuv: true
});

if (!options.rejuvKernel) {
Expand All @@ -27,22 +27,24 @@ module.exports = function(env) {
this.rejuvSteps = options.rejuvSteps;
this.rejuvKernel = options.rejuvKernel;
this.performRejuv = this.rejuvSteps > 0;
this.allowOutOfSyncRejuv = options.allowOutOfSyncRejuv;
this.performFinalRejuv = this.performRejuv && options.finalRejuv;
this.numParticles = options.particles;
this.debug = options.debug;

this.particles = [];
this.completeParticles = [];
this.particleIndex = 0;

this.step = 0;

var exitK = function(s) {
return wpplFn(s, env.exit, a);
};

// Create initial particles.
for (var i = 0; i < this.numParticles; i++) {
var trace = new Trace();
trace.saveContinuation(_.clone(s), exitK, a);
trace.saveContinuation(_.clone(s), exitK);
this.particles.push(new Particle(trace));
}

Expand Down Expand Up @@ -75,7 +77,8 @@ module.exports = function(env) {
SMC.prototype.factor = function(s, k, a, score) {
// Update particle.
var particle = this.currentParticle();
particle.trace.saveContinuation(s, k, a);
particle.trace.numFactors += 1;
particle.trace.saveContinuation(s, k);
particle.trace.score += score;
particle.logWeight += score;
this.debugLog('(' + this.particleIndex + ') Factor: ' + a);
Expand Down Expand Up @@ -149,46 +152,51 @@ module.exports = function(env) {
return allParticles;
}

SMC.prototype.rejuvenateParticles = function(cont) {

SMC.prototype.rejuvenateParticles = function(particles, cont) {
if (!this.performRejuv) {
return cont();
}
assert(!this.particlesAreWeighted(), 'Cannot rejuvenate weighted particles.');
if (!this.allowOutOfSyncRejuv && this.particlesAreOutOfSync()) {
throw 'Cannot rejuvenate out of sync particles.';
return cont(particles);
}

assert(!this.particlesAreWeighted(particles), 'Cannot rejuvenate weighted particles.');

return util.cpsForEach(
function(p, i, ps, next) {
return this.rejuvenateParticle(next, i);
return this.rejuvenateParticle(next, p);
}.bind(this),
cont,
this.particles
function() {
return cont(particles);
},
particles
);
};

SMC.prototype.rejuvenateParticle = function(cont, i) {
var exitAddress = this.particles[i].trace.address;
assert.notStrictEqual(exitAddress, undefined);
var kernelOptions = {
exitAddress: exitAddress,
proposalBoundary: this.particles[i].proposalBoundary
};
SMC.prototype.rejuvenateParticle = function(cont, particle) {
var kernelOptions = { proposalBoundary: particle.proposalBoundary };
if (this.performRejuv) {
kernelOptions.exitFactor = this.step;
}
var kernel = _.partial(this.rejuvKernel, _, _, kernelOptions);
var chain = repeatKernel(this.rejuvSteps, kernel);
return chain(function(trace) {
this.particles[i].trace = trace;
particle.trace = trace;
return cont();
}.bind(this), this.particles[i].trace);
}, particle.trace);
};

SMC.prototype.particlesAreWeighted = function() {
var lw = _.first(this.particles).logWeight;
return _.any(this.particles, function(p) { return p.logWeight !== lw; });
SMC.prototype.particlesAreWeighted = function(particles) {
var lw = _.first(particles).logWeight;
return _.any(particles, function(p) { return p.logWeight !== lw; });
};

SMC.prototype.particlesAreOutOfSync = function() {
var a = _.first(this.particles).trace.address;
return _.any(this.particles, function(p) { return p.trace.address !== a; });
SMC.prototype.particlesAreInSync = function(particles) {
// All particles are either at the step^{th} factor statement, or
// at the exit having encountered < than step factor statements.
return _.all(particles, function(p) {
var trace = p.trace;
return ((trace.isComplete() && trace.numFactors < this.step) ||
(!trace.isComplete() && trace.numFactors === this.step));
}.bind(this));
};

SMC.prototype.sync = function() {
Expand All @@ -199,23 +207,42 @@ module.exports = function(env) {
this.advanceParticleIndex();
return this.runCurrentParticle();
} else {
this.debugLog('***** SYNC *****');

var resampledParticles = resampleParticles(this.allParticles());
this.step += 1;
this.debugLog('***** sync :: step = ' + this.step + ' *****');

// Resampling and rejuvenation are applied to all particles.
// Active and complete particles are combined here and
// re-partitioned after rejuvenation.
var allParticles = this.allParticles();
assert(this.particlesAreInSync(allParticles));
var resampledParticles = resampleParticles(allParticles);
assert.strictEqual(resampledParticles.length, this.numParticles);

var p = _.partition(resampledParticles, function(p) { return p.trace.isComplete(); });
this.completeParticles = p[0];
this.particles = p[1];
var numActiveParticles = _.reduce(resampledParticles, function(acc, p) {
return acc + (p.trace.isComplete() ? 0 : 1);
}, 0);

this.debugLog('After resampling: active = ' + p[1].length + ', complete = ' + p[0].length + '\n');

if (this.particles.length > 0) {
if (numActiveParticles > 0) {
// We still have active particles, wrap-around:
this.particleIndex = 0;
return this.rejuvenateParticles(this.runCurrentParticle.bind(this));
return this.rejuvenateParticles(resampledParticles, function(rejuvenatedParticles) {
assert(this.particlesAreInSync(rejuvenatedParticles));

var p = _.partition(rejuvenatedParticles, function(p) { return p.trace.isComplete(); });
this.completeParticles = p[0];
this.particles = p[1];
this.debugLog(p[1].length + ' active particles after resample/rejuv.\n');

if (this.particles.length > 0) {
return this.runCurrentParticle();
} else {
return this.finish();
}
}.bind(this));
} else {
// All particles complete.
this.particles = [];
this.completeParticles = resampledParticles;
return this.finish();
}
}
Expand All @@ -230,7 +257,7 @@ module.exports = function(env) {
SMC.prototype.exit = function(s, val) {
// Complete the trace.
this.currentParticle().trace.complete(val);
this.debugLog('(' + this.particleIndex + ') Exit');
this.debugLog('(' + this.particleIndex + ') Exit | Value: ' + val);
return this.sync();
};

Expand All @@ -242,18 +269,17 @@ module.exports = function(env) {

return util.cpsForEach(
function(particle, i, ps, k) {
assert.strictEqual(particle.logWeight, logAvgW, 'Expected un-weighted particles.');
if (!this.performRejuv) {
hist.add(particle.trace.value);
return k();
} else {
if (this.performFinalRejuv) {
// Final rejuvenation.
var chain = repeatKernel(
this.rejuvSteps,
sequenceKernels(
this.rejuvKernel,
tapKernel(function(trace) { hist.add(trace.value); })));
return chain(k, particle.trace);
} else {
hist.add(particle.trace.value);
return k();
}
}.bind(this),
function() {
Expand Down
23 changes: 13 additions & 10 deletions src/trace.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ var Trace = function() {
this.addressMap = {}; // Maps addresses => choices.
this.length = 0;
this.score = 0;
this.numFactors = 0; // The number of factors encountered so far.
};

Trace.prototype.choiceAtIndex = function(index) {
Expand All @@ -19,11 +20,10 @@ Trace.prototype.findChoice = function(address) {
return this.addressMap[address];
};

Trace.prototype.saveContinuation = function(s, k, a) {
Trace.prototype.saveContinuation = function(s, k) {
this.store = s;
this.k = k;
this.address = a;
//this.checkConsistency();
// this.checkConsistency();
};

Trace.prototype.addChoice = function(erp, params, val, address, store, continuation) {
Expand All @@ -45,26 +45,27 @@ Trace.prototype.addChoice = function(erp, params, val, address, store, continuat
// need if we regen from this choice.
score: this.score,
val: val,
store: _.clone(store)
store: _.clone(store),
numFactors: this.numFactors
};

this.choices.push(choice);
this.addressMap[address] = choice;
this.length += 1;
this.score += erp.score(params, val);
//this.checkConsistency();
// this.checkConsistency();
};

Trace.prototype.complete = function(value) {
// Called at coroutine exit.
assert.strictEqual(this.value, undefined);
this.value = value;
// Ensure any attempt to continue a completed trace fails in an obvious way.
this.k = this.store = this.address = undefined;
this.k = this.store = undefined;
};

Trace.prototype.isComplete = function() {
return this.k === undefined && this.store === undefined && this.address === undefined;
return this.k === undefined && this.store === undefined;
};

Trace.prototype.upto = function(i) {
Expand All @@ -77,7 +78,8 @@ Trace.prototype.upto = function(i) {
t.choices.forEach(function(choice) { t.addressMap[choice.address] = choice; });
t.length = t.choices.length;
t.score = this.choices[i].score;
//t.checkConsistency();
t.numFactors = this.choices[i].numFactors;
// t.checkConsistency();
return t;
};

Expand All @@ -91,7 +93,8 @@ Trace.prototype.copy = function() {
t.store = _.clone(this.store);
t.address = this.address;
t.value = this.value;
//t.checkConsistency();
t.numFactors = this.numFactors;
// t.checkConsistency();
return t;
};

Expand All @@ -101,7 +104,7 @@ Trace.prototype.checkConsistency = function() {
this.choices.forEach(function(choice) {
assert(_.has(this.addressMap, choice.address));
}, this);
assert(this.value === undefined || (this.k === undefined && this.store === undefined && this.address === undefined));
assert(this.value === undefined || (this.k === undefined && this.store === undefined));
};

module.exports = Trace;
2 changes: 2 additions & 0 deletions tests/test-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ var tests = [
store2: { hist: { tol: 0 }, args: { particles: 30, rejuvSteps: 30 } },
geometric: true,
drift: { mean: { tol: 0.3 }, std: { tol: 0.3 }, args: { particles: 1000, rejuvSteps: 15 } },
varFactors1: true,
varFactors2: true,
importance: true,
importance2: { args: { particles: 3000, rejuvSteps: 10 } },
importance3: true,
Expand Down

0 comments on commit 8b61553

Please sign in to comment.