-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #775 from null-a/dream
Dream Learning
- Loading branch information
Showing
21 changed files
with
528 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
'use strict'; | ||
|
||
var util = require('../../util'); | ||
var paramStruct = require('../../params/struct'); | ||
|
||
// This estimator currently makes the following assumptions: | ||
|
||
// 1. The model includes no more than one `mapData`. | ||
|
||
// 2. Every evaluation of the observation function (associated with a | ||
// `mapData`) includes one or more calls to `observe`, and either: | ||
|
||
// 2a. There is exactly one call to `observe`, and the value yielded | ||
// to the observation function is the value passed to `observe`. For | ||
// example: | ||
|
||
// var model = function() { | ||
// mapData({data}, function(datum) { | ||
// observe(dist, datum); | ||
// }); | ||
// }; | ||
|
||
// 2b. There is more than one call to `observe`, the value yielded to | ||
// the observation function is an array, and successive observations | ||
// are passed successive elements of the array, starting from the | ||
// first element. For example: | ||
|
||
// var model = function() { | ||
// mapData({data}, function(arr) { | ||
// observe(dist, arr[0]); | ||
// observe(dist, arr[1]); | ||
// observe(dist, arr[2]); | ||
// }); | ||
// }; | ||
|
||
// 3. There are no factor statements. We assume we can generate | ||
// samples from the posterior predictive distribution directly by | ||
// forward sampling. If there were additional factors we'd need to | ||
// account for them with e.g. importance sampling. | ||
|
||
// 4. observe is only used within mapData. | ||
|
||
module.exports = function(env) { | ||
|
||
var dreamSample = require('./sample')(env).dreamSample; | ||
var dreamGradients = require('./gradients')(env); | ||
|
||
return function(wpplFn, s, a, options, state, step, cont) { | ||
var opts = util.mergeDefaults(options, { | ||
samples: 1 | ||
}); | ||
|
||
var objVal = 0; | ||
var grad = {}; | ||
|
||
return util.cpsLoop( | ||
opts.samples, | ||
// Loop body. | ||
function(i, next) { | ||
return dreamSample(s, function(s, record) { | ||
return dreamGradients(wpplFn, record, s, a, function(g, objVal_i) { | ||
paramStruct.addEq(grad, g); | ||
objVal += objVal_i; | ||
return next(); | ||
}); | ||
}, a, wpplFn); | ||
}, | ||
// Continuation. | ||
function() { | ||
paramStruct.divEq(grad, opts.samples); | ||
objVal /= opts.samples; | ||
return cont(grad, objVal); | ||
}); | ||
}; | ||
|
||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
'use strict'; | ||
'use ad'; | ||
|
||
var assert = require('assert'); | ||
var _ = require('lodash'); | ||
var guide = require('../../guide'); | ||
var ad = require('../../ad'); | ||
|
||
module.exports = function(env) { | ||
|
||
function dreamGradients(wpplFn, record, s, a, cont) { | ||
this.wpplFn = wpplFn; | ||
this.record = record; | ||
this.s = s; | ||
this.a = a; | ||
this.cont = cont; | ||
|
||
this.guideRequired = true; | ||
this.insideMapData = false; | ||
|
||
this.coroutine = env.coroutine; | ||
env.coroutine = this; | ||
} | ||
|
||
dreamGradients.prototype = { | ||
|
||
run: function() { | ||
return this.estimateGradient(function(grad, objVal) { | ||
env.coroutine = this.coroutine; | ||
return this.cont(grad, objVal); | ||
}.bind(this)); | ||
}, | ||
|
||
estimateGradient: function(cont) { | ||
|
||
this.paramsSeen = {}; | ||
this.logq = 0; | ||
|
||
return this.wpplFn(_.clone(this.s), function(s, val) { | ||
|
||
var objective = -this.logq; | ||
if (ad.isLifted(objective)) { | ||
objective.backprop(); | ||
} | ||
|
||
var grads = _.mapValues(this.paramsSeen, function(params) { | ||
return params.map(ad.derivative); | ||
}); | ||
|
||
return cont(grads, -ad.value(this.logq)); | ||
|
||
}.bind(this), this.a); | ||
|
||
}, | ||
|
||
sample: function(s, k, a, dist, options) { | ||
options = options || {}; | ||
var choice = this.record.trace.findChoice(a); | ||
assert.ok(choice !== undefined, 'dream: No entry for this choice in the trace.'); | ||
var val = choice.val; | ||
|
||
if (this.insideMapData) { | ||
return guide.getDist( | ||
options.guide, options.noAutoGuide, dist, env, s, a, | ||
function(s, guideDist) { | ||
if (!guideDist) { | ||
throw new Error('dream: No guide distribution specified.'); | ||
} | ||
this.logq += guideDist.score(val); | ||
return k(s, val); | ||
}.bind(this)); | ||
} | ||
else { | ||
return k(s, val); | ||
} | ||
}, | ||
|
||
factor: function(s, k, a, score) { | ||
// This will only be called by the default implementation of | ||
// observe. (Because we checked that factor isn't called during | ||
// the sampling phase, and since we're reusing choices from the | ||
// trace, the execution here will follow the same path through | ||
// the program.) | ||
return k(s); | ||
}, | ||
|
||
mapDataFetch: function(data, opts, a) { | ||
if (this.insideMapData) { | ||
throw new Error('dream: nested mapData is not supported by this estimator.'); | ||
} | ||
this.insideMapData = true; | ||
return {data: this.record.data, ix: null, address: a + '_dream'}; | ||
}, | ||
|
||
mapDataFinal: function() { | ||
this.insideMapData = false; | ||
}, | ||
|
||
incrementalize: env.defaultCoroutine.incrementalize, | ||
constructor: dreamGradients | ||
|
||
}; | ||
|
||
return function() { | ||
var coroutine = Object.create(dreamGradients.prototype); | ||
dreamGradients.apply(coroutine, arguments); | ||
return coroutine.run(); | ||
}; | ||
|
||
}; |
Oops, something went wrong.