Skip to content

Commit

Permalink
Merge pull request #801 from null-a/fix-789
Browse files Browse the repository at this point in the history
Make param work from within nested Enumerate.
  • Loading branch information
stuhlmueller committed Mar 27, 2017
2 parents 5fd6417 + ca65221 commit 7023e0f
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 17 deletions.
9 changes: 2 additions & 7 deletions src/guide.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@ function guideCoroutine(env) {
sample: sampleNotAllowed,
factor: factorNotAllowed,
incrementalize: env.defaultCoroutine.incrementalize,
// Copy the entry address from the current coroutine so that
// parameter names continue to be relative to it.
a: env.coroutine.a,
// Use paramsSeen of the current coroutine so that params are
// tracked correctly.
paramsSeen: env.coroutine.paramsSeen,
coroutine: env.coroutine,
// A flag used when creating parameters to check whether we're in
// a guide thunk. Note that this does the right thing if Infer is
// used within a guide. This can be checked from a webppl program
Expand Down Expand Up @@ -105,7 +100,7 @@ function independent(targetDist, sampleAddress, env) {
// avoid collisions when the distribution type changes between
// calls. (As a result of the distribution passed depending on a
// random choice.)
var relativeAddress = util.relativizeAddress(env, sampleAddress);
var relativeAddress = util.relativizeAddress(params.baseAddress(env), sampleAddress);
var baseName = relativeAddress + '$mf$' + targetDist.meta.name + '$';

var distSpec = spec(targetDist);
Expand Down
1 change: 1 addition & 0 deletions src/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ module.exports = function(env) {
var args = [s, k, a].concat(args);
return fn.apply(global, args);
},
isParamBase: true,
a: '' // Entry address. Enables relative addressing.
};

Expand Down
2 changes: 2 additions & 0 deletions src/inference/dream/gradients.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ module.exports = function(env) {
this.cont = cont;

this.guideRequired = true;
this.isParamBase = true;

this.insideMapData = false;

this.coroutine = env.coroutine;
Expand Down
2 changes: 2 additions & 0 deletions src/inference/dream/sample.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ module.exports = function(env) {
this.record = {trace: trace, data: []};

this.guideRequired = true;
this.isParamBase = true;

this.insideMapData = false;

this.coroutine = env.coroutine;
Expand Down
1 change: 1 addition & 0 deletions src/inference/elbo.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ module.exports = function(env) {
this.s = s;
this.a = a;
this.guideRequired = true;
this.isParamBase = true;

// Initialize mapData state.
this.mapDataStack = [{multiplier: 1}];
Expand Down
3 changes: 2 additions & 1 deletion src/inference/eubo.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ module.exports = function(env) {
this.s = s;
this.a = a;
this.guideRequired = true;
this.isParamBase = true;

this.coroutine = env.coroutine;
env.coroutine = this;
Expand Down Expand Up @@ -120,7 +121,7 @@ module.exports = function(env) {
throw new Error('EUBO: No guide distribution to optimize.');
}

var rel = util.relativizeAddress(env, a);
var rel = util.relativizeAddress(this.a, a);
var guideVal = this.trace.findChoice(this.trace.baseAddress + rel).val;
assert.notStrictEqual(guideVal, undefined);

Expand Down
1 change: 1 addition & 0 deletions src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module.exports = function(env) {

// Indicate that guide thunks should run.
this.guideRequired = sampleGuide;
this.isParamBase = true;

this.score = 0;
this.logWeight = 0;
Expand Down
1 change: 1 addition & 0 deletions src/inference/smc.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ module.exports = function(env) {
this.saveTraces = options.saveTraces;
this.importanceOpt = options.importance;
this.guideRequired = options.importance !== 'ignoreGuide';
this.isParamBase = true;
this.onlyMAP = options.onlyMAP;

this.particles = [];
Expand Down
4 changes: 3 additions & 1 deletion src/params/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ module.exports = function(env) {
}

var dims = options.dims;
var name = _.has(options, 'name') ? options.name : util.relativizeAddress(env, a);
var name = _.has(options, 'name') ?
options.name :
util.relativizeAddress(params.baseAddress(env), a);

if (params.exists(name)) {
return finish(s);
Expand Down
34 changes: 32 additions & 2 deletions src/params/params.js
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function fetch(name, env) {
}

var paramTable = get();
var paramsSeen = env.coroutine.paramsSeen;
var paramsSeen = getParamsSeen(env);

// If we're outside of optimization, just return the value of the
// parameter, unlifted.
Expand All @@ -116,6 +116,35 @@ function fetch(name, env) {
}
}

function findCoroutine(predicate, coroutine) {
if (predicate(coroutine)) {
return coroutine;
} else if (_.has(coroutine, 'coroutine')) {
return findCoroutine(predicate, coroutine.coroutine);
} else {
return null;
}
}

function getParamsSeen(env) {
var coroutine = findCoroutine(_.property('paramsSeen'), env.coroutine);
return coroutine ? coroutine.paramsSeen : null;
}

// Returns the base address used when automatically generating
// parameter names based on relative stack addresses. The strategy is
// to walk the coroutine stack starting from the current coroutine,
// looking for the first coroutine with the isParamBase flag set. The
// entry address of the coroutine found this way is returned. This is
// expected to always find a coroutine, since env.defaultCoroutine has
// the flag set.
function baseAddress(env) {
var baseCoroutine = findCoroutine(_.property('isParamBase'), env.coroutine);
assert.ok(baseCoroutine, 'Could not find base coroutine.');
assert.ok(_.has(baseCoroutine, 'a'), 'Entry address not saved on coroutine.');
return baseCoroutine.a;
}

module.exports = {
get: get,
set: set,
Expand All @@ -125,5 +154,6 @@ module.exports = {
sync: sync,
exists: exists,
create: create,
fetch: fetch
fetch: fetch,
baseAddress: baseAddress
};
7 changes: 1 addition & 6 deletions src/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,7 @@ function tensorEqDims(t1, t2) {
return true;
}

function relativizeAddress(env, address) {
// Takes the env and a full stack address and returns a new address
// relative to the entry address of the current coroutine. This
// requires each coroutine to save its entry address as `this.a`.
assert.ok(_.has(env.coroutine, 'a'), 'Entry address not saved on coroutine.');
var baseAddress = env.coroutine.a;
function relativizeAddress(baseAddress, address) {
assert.ok(address.slice(0, baseAddress.length) === baseAddress, 'Address prefix mismatch.');
return address.slice(baseAddress.length);
}
Expand Down
6 changes: 6 additions & 0 deletions tests/test-data/stochastic/expected/enumGuide.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"hist": {
"true": 0.8,
"false": 0.2
}
}
17 changes: 17 additions & 0 deletions tests/test-data/stochastic/models/enumGuide.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
var model = function() {
var flips = mapData({data: [0.8, 0.2]}, function(p) {
return sample(Bernoulli({p}), {guide() {
// The guide returned here is expected to be equivalent to:
// Bernoulli({p: Math.sigmoid(param())});

// The correct behavior relies on optimization and automatic
// parameter naming correctly handling the enumeration of the
// guide distribution.

return Infer({method: 'enumerate', model() {
return flip(Math.sigmoid(param()));
}});
}});
});
return first(flips);
};
7 changes: 7 additions & 0 deletions tests/test-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,13 @@ var tests = [
onlyMAP: true,
verbose: false
}
},
enumGuide: {
args: {
steps: 6000,
samples: 10000,
verbose: false
}
}
}
},
Expand Down

0 comments on commit 7023e0f

Please sign in to comment.