Skip to content

Commit

Permalink
Merge pull request #842 from null-a/optional-dist-params
Browse files Browse the repository at this point in the history
Support optional distribution parameters
  • Loading branch information
stuhlmueller committed May 2, 2017
2 parents 4fa25d8 + 9b9138d commit 2842a38
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/primitive-distributions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
* ps: probabilities (can be unnormalized) *(vector or real array [0, Infinity))*
* vs: support *(any array)*

Distribution over elements of ``vs`` with ``P(vs[i])`` proportional to ``ps[i]``
Distribution over elements of ``vs`` with ``P(vs[i])`` proportional to ``ps[i]``. ``ps`` may be omitted, in which case a uniform distribution over ``vs`` is returned.

`Wikipedia entry <https://en.wikipedia.org/wiki/Categorical_distribution>`__

Expand Down
49 changes: 35 additions & 14 deletions src/dists.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@
//
// - All distributions of a particular type should share the same set
// of parameters.
//
// Optional parameters:
//
// The default Dist constructor checks that all parameters described
// in a distribution's definition are present. The `optional` flag can
// be set on a parameter to skip this check. In such cases a custom
// constructor must be defined to fill in the value of the parameter
// when omitted. More specifically, the constructor should extend the
// `params` object to include an appropriate default value. This
// ensures that the condition described above (regarding all types
// sharing the same set of parameters) is met. Care should also be
// taken not to modify the `params` object originally passed by the
// user.

'use strict';

Expand Down Expand Up @@ -176,26 +189,25 @@ function makeDistributionType(options) {
}
return param.type;
});
var parameterOptionalFlags = _.map(options.params, 'optional');
var extraConstructorFn = options.constructor;

// Note that Chrome uses the name of this local variable in the
// output of `console.log` when it's called on a distribution that
// uses the default constructor.
var dist = function(params) {

if (params === undefined && parameterNames.length > 0) {
throw new Error('Parameters not supplied to ' + this.meta.name + ' distribution.');
}
params = params || {};
parameterNames.forEach(function(p, i) {
if (!params.hasOwnProperty(p)) {
throw new Error('Parameter \"' + p + '\" missing from ' + this.meta.name + ' distribution.');
}

var type = parameterTypes[i];
if (type && !type.check(ad.valueRec(params[p]))) {
throw new Error('Parameter \"' + p + '\" should be of type "' + type.desc + '".');
if (params.hasOwnProperty(p)) {
var type = parameterTypes[i];
if (type && !type.check(ad.valueRec(params[p]))) {
throw new Error('Parameter \"' + p + '\" should be of type "' + type.desc + '".');
}
} else {
if (!parameterOptionalFlags[i]) {
throw new Error('Parameter \"' + p + '\" missing from ' + this.meta.name + ' distribution.');
}
}

}, this);
this.params = params;
if (extraConstructorFn !== undefined) {
Expand Down Expand Up @@ -1506,14 +1518,23 @@ function printMarginal(dist) {

var Categorical = makeDistributionType({
name: 'Categorical',
desc: 'Distribution over elements of ``vs`` with ``P(vs[i])`` proportional to ``ps[i]``',
desc: 'Distribution over elements of ``vs`` with ``P(vs[i])`` proportional to ``ps[i]``. ' +
'``ps`` may be omitted, in which case a uniform distribution over ``vs`` is returned.',
params: [
{name: 'ps', desc: 'probabilities (can be unnormalized)', type: types.nonNegativeVectorOrRealArray},
{name: 'ps', desc: 'probabilities (can be unnormalized)',
type: types.nonNegativeVectorOrRealArray, optional: true},
{name: 'vs', desc: 'support', type: types.array(types.any)}],
wikipedia: true,
mixins: [finiteSupport],
constructor: function() {
'use ad';
// Add default for ps when omitted.
if (this.params.ps === undefined) {
this.params = {
ps: _.fill(Array(this.params.vs.length), 1),
vs: this.params.vs
};
}
var ps = this.params.ps;
var vs = this.params.vs;
if (vs.length !== ad.value(ps).length) {
Expand Down
2 changes: 1 addition & 1 deletion tests/test-data/deterministic/expected/expectation.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"result": [0.9, 0.9, 0.9, 0.9]
"result": [0.9, 0.9, 0.9, 0.9, 0.5]
}
3 changes: 2 additions & 1 deletion tests/test-data/deterministic/models/expectation.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
expectation(Discrete({ps: [3, 5, 2]}), idF),
expectation(Discrete({ps: Vector([3, 5, 2])}), idF),
expectation(Categorical({ps: [3, 5, 2], vs: [0, 1, 2]}), idF),
expectation(Categorical({ps: Vector([3, 5, 2]), vs: [0, 1, 2]}), idF)
expectation(Categorical({ps: Vector([3, 5, 2]), vs: [0, 1, 2]}), idF),
expectation(Categorical({vs: [0, 1]}), idF)
];

0 comments on commit 2842a38

Please sign in to comment.