Skip to content

Commit

Permalink
Merge pull request #813 from null-a/fix-798-v2
Browse files Browse the repository at this point in the history
Handle duplicate values in Categorical.
  • Loading branch information
stuhlmueller committed Apr 7, 2017
2 parents ab39ab5 + 76ccf0f commit debae41
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/dists.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -1502,21 +1502,35 @@ var Categorical = makeDistributionType({
nohelper: true,
mixins: [finiteSupport],
constructor: function() {
this.ixmap = _.fromPairs(this.params.vs.map(function(v, ix) {
return [util.serialize(v), ix];
}));
'use ad';
var ps = this.params.ps;
var vs = this.params.vs;
if (vs.length !== ad.value(ps).length) {
throw new Error('Parameters ps and vs should have the same length.');
}
if (vs.length === 0) {
throw new Error('Parameters ps and vs should have length > 0.');
}
var dist = {};
var norm = _.isArray(ps) ? sum(ps) : T.sumreduce(ps);
for (var i in vs) {
var val = vs[i];
var k = util.serialize(val);
if (!_.has(dist, k)) {
dist[k] = {val: val, prob: 0};
}
dist[k].prob += (_.isArray(ps) ? ps[i] : T.get(ps, i)) / norm;
}
this.marginal = new Marginal({dist: dist});
},
sample: function() {
var ix = discreteSample(toUnliftedArray(this.params.ps));
var vs = this.params.vs.map(ad.value);
return vs[ix];
return this.marginal.sample();
},
score: function(val) {
var ix = this.ixmap[util.serialize(val)];
return discreteScore(this.params.ps, ix);
return this.marginal.score(val);
},
support: function() {
return this.params.vs;
return this.marginal.support();
}
});

Expand Down

0 comments on commit debae41

Please sign in to comment.