Skip to content

Commit

Permalink
Merge pull request #864 from null-a/check-extra-args-at-sample
Browse files Browse the repository at this point in the history
More sample/dist arg checks
  • Loading branch information
stuhlmueller committed Jul 2, 2017
2 parents d9a195c + d8b2130 commit 4da45df
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
2 changes: 1 addition & 1 deletion scripts/distHeader
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var helpers = {
};

var t = _.mapValues({
ctor: 'var <%= name %> = function(params) { return util.jsnew(dists.<%= name %>, params); };',
ctor: 'var <%= name %> = dists.make<%= name %>;',
helper: [
'var <% print(downcaseInitial(name)) %> = function(<% print(fnParams(params)) %>) {',
' var params = util.isObject(arg0) ? arg0 : {<% print(args2Obj(params)) %>};',
Expand Down
22 changes: 21 additions & 1 deletion src/dists.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,26 @@ var distributions = {
Delta: Delta
};

// For each distribution type, we create a WebPPL function that
// creates instances of that type. We include the argument check here
// to avoid the expensive slicing of arguments that would be required
// if it were anywhere else. e.g. In the WebPPL distribution header or
// in the distribution's JS constructor.

var wpplFns = _.chain(distributions)
.mapValues(function(ctor) {
return function(s, k, a, params) {
if (arguments.length > 4) {
throw new Error('Too many arguments. Distributions take at most one argument.');
}
return k(s, new ctor(params));
};
})
.mapKeys(function(ctor, name) {
return 'make' + name;
})
.value();

module.exports = _.assign({
// rng
betaSample: betaSample,
Expand All @@ -1664,4 +1684,4 @@ module.exports = _.assign({
squishToProbSimplex: squishToProbSimplex,
isDist: isDist,
metadata: metadata
}, distributions);
}, distributions, wpplFns);
8 changes: 8 additions & 0 deletions src/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ module.exports = function(env) {
if (!dists.isDist(dist)) {
throw new Error('sample() expected a distribution but received \"' + JSON.stringify(dist) + '\".');
}
for (var name in options) {
if (name !== 'guide' &&
name !== 'driftKernel' &&
name !== 'noAutoGuide' &&
name !== 'reparam') {
throw new Error('Unknown option "' + name + '" passed to sample.');
}
}
return env.coroutine.sample(s, k, a, dist, options);
};

Expand Down
5 changes: 0 additions & 5 deletions src/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,6 @@ function jsthrow(obj) {
throw obj;
}

function jsnew(ctor, arg) {
return new ctor(arg);
}

// Equivalent to Number.isInteger(), which isn't available in the
// version of phantom.js used on Travis at the time of writing.
function isInteger(x) {
Expand Down Expand Up @@ -404,7 +400,6 @@ module.exports = {
warn: warn,
resetWarnings: resetWarnings,
error: error,
jsnew: jsnew,
jsthrow: jsthrow,
isInteger: isInteger,
isObject: isObject,
Expand Down

0 comments on commit 4da45df

Please sign in to comment.