Skip to content

Commit

Permalink
Improve params.register() (#747)
Browse files Browse the repository at this point in the history
* Drop callback.

This is only used by daipp. This will be addressed later.

* Assume that getParams returns unlifted values.

* Rename variable.

* Move ad.lift outside of conditional.

* Only lift parameters when doing optimization.

* Improve commentary.

* Better name for function parameter.
  • Loading branch information
null-a authored and stuhlmueller committed Jan 6, 2017
1 parent 33d7b1c commit dc41f3b
Showing 1 changed file with 19 additions and 28 deletions.
47 changes: 19 additions & 28 deletions src/params/params.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
'use strict';

var assert = require('assert');
var _ = require('lodash');
var fs = require('fs');
var ad = require('../ad');
Expand Down Expand Up @@ -72,55 +73,45 @@ function set(params, k) {
}


function register(env, name, getParams, setParams) {

// getParams is expected to be a function which is used to
// initialize parameters the first time they are encoutered. At
// present I consider it to be `register` responsibility to
// perform lifting of params, so ideally `getParams` would not
// return lifted params. However, in the case of NN, `getParams`
// returns params already lifted. Hence, `getParams()` is replaced
// with `getParams().map(ad.value)` throughout this function.
function register(env, name, initParams) {

var paramTable = get();
var paramsSeen = env.coroutine.paramsSeen;

if (paramsSeen && _.has(paramsSeen, name)) {

// We've already lifted these params during this execution.
// We've already lifted these parameters during this execution.
// Re-use ad graph nodes.

return paramsSeen[name];

} else {

// This is the first time we've encounter these params during
// this execution. we will lift params at this point.

var params;

// Get parameter values from the store, or initialize if this is a
// new parameter.
var _params;
if (_.has(paramTable, name)) {
// Seen on previous execution. Fetch from store and lift.
params = paramTable[name].map(ad.lift);
// Parameters already initialized. Fetch values from store.
_params = paramTable[name];
} else {
// Never seen. Fetch initial values, add to store and lift.
var prms = getParams().map(ad.value);
paramTable[name] = prms;
params = prms.map(ad.lift);
// Never seen. Fetch initial values and add to store.
_params = initParams();
assert.ok(_.every(_params, _.negate(ad.isLifted)),
'initParams unexpectedly returned a lifted value.');
paramTable[name] = _params;
}

if (paramsSeen) {
// Lift parameters if the current coroutine is tracking
// parameters for optimization.
var params = _params.map(ad.lift);
paramsSeen[name] = params;
return params;
} else {
return _params;
}

// Callback with the fresh ad graph nodes.
if (setParams) {
setParams(params);
}

return params;
}

}


Expand Down

0 comments on commit dc41f3b

Please sign in to comment.