Skip to content

Commit

Permalink
Merge pull request #775 from null-a/dream
Browse files Browse the repository at this point in the history
Dream Learning
  • Loading branch information
stuhlmueller committed Feb 21, 2017
2 parents 32527f7 + 31f5881 commit 9a8e801
Show file tree
Hide file tree
Showing 21 changed files with 528 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ src/dists.wppl
src/inference/enumerate.js
src/inference/elbo.js
src/inference/eubo.js
src/inference/dream/gradients.js
src/aggregation/ScoreAggregator.js
23 changes: 21 additions & 2 deletions src/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ try {
var incrementalmh = require('./inference/incrementalmh');
var optimize = require('./inference/optimize');
var forwardSample = require('./inference/forwardSample');
var dreamSample = require('./inference/dream/sample');
var headerUtils = require('./headerUtils');
var params = require('./params/header');
var Query = require('./query').Query;
Expand All @@ -51,7 +52,6 @@ try {

module.exports = function(env) {


// Inference interface

env.defaultCoroutine = {
Expand Down Expand Up @@ -90,6 +90,24 @@ module.exports = function(env) {
return env.coroutine.factor(s, k, a, score);
};

// If observation value is given then factor accordingly,
// otherwise sample a new value.
// The value is passed to the continuation.
env.observe = function(s, k, a, dist, val) {
if (typeof env.coroutine.observe === 'function') {
return env.coroutine.observe(s, k, a, dist, val);
} else {
if (val !== undefined) {
var factorK = function(s) {
return k(s, val);
};
return env.factor(s, factorK, a, dist.score(val));
} else {
return env.sample(s, k, a, dist);
}
}
};

env.sampleWithFactor = function(s, k, a, dist, scoreFn) {
if (typeof env.coroutine.sampleWithFactor === 'function') {
return env.coroutine.sampleWithFactor(s, k, a, dist, scoreFn);
Expand Down Expand Up @@ -134,6 +152,7 @@ module.exports = function(env) {
factor: env.factor,
sample: env.sample,
sampleWithFactor: env.sampleWithFactor,
observe: env.observe,
incrementalize: env.incrementalize,
query: env.query
});
Expand All @@ -152,7 +171,7 @@ module.exports = function(env) {
// Inference functions and header utils
var headerModules = [
enumerate, asyncpf, mcmc, incrementalmh, pmcmc,
smc, rejection, optimize, forwardSample,
smc, rejection, optimize, forwardSample, dreamSample,
headerUtils, params
];
headerModules.forEach(function(mod) {
Expand Down
9 changes: 0 additions & 9 deletions src/header.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,6 @@ var condition = function(bool) {
factor(bool ? 0 : -Infinity);
};

var observe = function(dist, val) {
if (val !== undefined) {
factor(dist.score(val));
return val;
} else {
return sample(dist);
}
};

var MH = function(wpplFn, samples, burn) {
return MCMC(wpplFn, { samples: samples, burn: burn });
};
Expand Down
43 changes: 28 additions & 15 deletions src/headerUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,20 @@ module.exports = function(env) {
// one or more of the following methods:

// mapDataFetch: Called when mapData is entered, providing an
// opportunity to perform book-keeping etc. When sub-sampling data
// this method should return an array of indices indicating the data
// to be mapped over. Alternatively, null can be returned to
// indicate that all data should be used.
// opportunity to perform book-keeping etc. The method should return
// an object with data, ix and (optional) address properties.

// data: The array that will be mapped over.

// ix: An array of integers of the same length as data, where each
// entry indicates the position at which the corresponding entry
// in data can be found in the original data array. This is used
// to ensure that corresponding data items and stack addresses are
// used when applying the observation function. For convenience,
// null can be returned as a short hand for _.range(data.length).

// address: When present, mapData behaves as though it was called
// from this address.

// mapDataEnter/mapDataLeave: Called before/after every application
// of the observation function.
Expand All @@ -112,24 +122,27 @@ module.exports = function(env) {
throw new Error('mapData: No data given.');
}

var batchSize = opts.batchSize !== undefined ? opts.batchSize : data.length;
if (batchSize < 0 || batchSize > data.length) {
throw new Error('mapData: Invalid batchSize.');
}
var ret = env.coroutine.mapDataFetch ?
env.coroutine.mapDataFetch(data, opts, a) :
{data: data, ix: null};

var ix = ret.ix;
var finalData = ret.data;
var address = ret.address || a;

var ix = env.coroutine.mapDataFetch ?
env.coroutine.mapDataFetch(data, batchSize, a) :
null;
assert.ok(ix === null ||
(_.isArray(ix) && (ix.length === finalData.length)),
'Unexpected value returned by mapDataFetch.');

assert.ok(ix === null || _.isArray(ix));
var doReturn = ix === null; // We return undefined when sub-sampling data.
// We return undefined when sub-sampling data etc.
var doReturn = finalData === data;

return cpsMapData(s, function(s, v) {
if (env.coroutine.mapDataFinal) {
env.coroutine.mapDataFinal(a);
}
return k(s, doReturn ? v : undefined);
}, a, data, ix, obsFn);
}, address, finalData, ix, obsFn);
}

function cpsMapData(s, k, a, data, indices, f, acc, i) {
Expand All @@ -151,7 +164,7 @@ module.exports = function(env) {
return function() {
return cpsMapData(s, k, a, data, indices, f, acc.concat([v]), i + 1);
};
}, a.concat('_$$' + ix), data[ix], ix);
}, a.concat('_$$' + ix), data[i], ix);
}
}

Expand Down
76 changes: 76 additions & 0 deletions src/inference/dream/estimator.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
'use strict';

var util = require('../../util');
var paramStruct = require('../../params/struct');

// This estimator currently makes the following assumptions:

// 1. The model includes no more than one `mapData`.

// 2. Every evaluation of the observation function (associated with a
// `mapData`) includes one or more calls to `observe`, and either:

// 2a. There is exactly one call to `observe`, and the value yielded
// to the observation function is the value passed to `observe`. For
// example:

// var model = function() {
// mapData({data}, function(datum) {
// observe(dist, datum);
// });
// };

// 2b. There is more than one call to `observe`, the value yielded to
// the observation function is an array, and successive observations
// are passed successive elements of the array, starting from the
// first element. For example:

// var model = function() {
// mapData({data}, function(arr) {
// observe(dist, arr[0]);
// observe(dist, arr[1]);
// observe(dist, arr[2]);
// });
// };

// 3. There are no factor statements. We assume we can generate
// samples from the posterior predictive distribution directly by
// forward sampling. If there were additional factors we'd need to
// account for them with e.g. importance sampling.

// 4. observe is only used within mapData.

module.exports = function(env) {

var dreamSample = require('./sample')(env).dreamSample;
var dreamGradients = require('./gradients')(env);

return function(wpplFn, s, a, options, state, step, cont) {
var opts = util.mergeDefaults(options, {
samples: 1
});

var objVal = 0;
var grad = {};

return util.cpsLoop(
opts.samples,
// Loop body.
function(i, next) {
return dreamSample(s, function(s, record) {
return dreamGradients(wpplFn, record, s, a, function(g, objVal_i) {
paramStruct.addEq(grad, g);
objVal += objVal_i;
return next();
});
}, a, wpplFn);
},
// Continuation.
function() {
paramStruct.divEq(grad, opts.samples);
objVal /= opts.samples;
return cont(grad, objVal);
});
};

};
110 changes: 110 additions & 0 deletions src/inference/dream/gradients.ad.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
'use strict';
'use ad';

var assert = require('assert');
var _ = require('lodash');
var guide = require('../../guide');
var ad = require('../../ad');

module.exports = function(env) {

function dreamGradients(wpplFn, record, s, a, cont) {
this.wpplFn = wpplFn;
this.record = record;
this.s = s;
this.a = a;
this.cont = cont;

this.guideRequired = true;
this.insideMapData = false;

this.coroutine = env.coroutine;
env.coroutine = this;
}

dreamGradients.prototype = {

run: function() {
return this.estimateGradient(function(grad, objVal) {
env.coroutine = this.coroutine;
return this.cont(grad, objVal);
}.bind(this));
},

estimateGradient: function(cont) {

this.paramsSeen = {};
this.logq = 0;

return this.wpplFn(_.clone(this.s), function(s, val) {

var objective = -this.logq;
if (ad.isLifted(objective)) {
objective.backprop();
}

var grads = _.mapValues(this.paramsSeen, function(params) {
return params.map(ad.derivative);
});

return cont(grads, -ad.value(this.logq));

}.bind(this), this.a);

},

sample: function(s, k, a, dist, options) {
options = options || {};
var choice = this.record.trace.findChoice(a);
assert.ok(choice !== undefined, 'dream: No entry for this choice in the trace.');
var val = choice.val;

if (this.insideMapData) {
return guide.getDist(
options.guide, options.noAutoGuide, dist, env, s, a,
function(s, guideDist) {
if (!guideDist) {
throw new Error('dream: No guide distribution specified.');
}
this.logq += guideDist.score(val);
return k(s, val);
}.bind(this));
}
else {
return k(s, val);
}
},

factor: function(s, k, a, score) {
// This will only be called by the default implementation of
// observe. (Because we checked that factor isn't called during
// the sampling phase, and since we're reusing choices from the
// trace, the execution here will follow the same path through
// the program.)
return k(s);
},

mapDataFetch: function(data, opts, a) {
if (this.insideMapData) {
throw new Error('dream: nested mapData is not supported by this estimator.');
}
this.insideMapData = true;
return {data: this.record.data, ix: null, address: a + '_dream'};
},

mapDataFinal: function() {
this.insideMapData = false;
},

incrementalize: env.defaultCoroutine.incrementalize,
constructor: dreamGradients

};

return function() {
var coroutine = Object.create(dreamGradients.prototype);
dreamGradients.apply(coroutine, arguments);
return coroutine.run();
};

};

0 comments on commit 9a8e801

Please sign in to comment.