Skip to content

Commit

Permalink
Merge pull request #823 from null-a/auto-guide
Browse files Browse the repository at this point in the history
Add auto-guides for RandomInteger and Binomial distributions.
  • Loading branch information
stuhlmueller committed Apr 18, 2017
2 parents 8c89709 + 177b150 commit 37cc9d2
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/guide.js
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ function spec(targetDist) {
} else if (targetDist instanceof dists.Beta) {
return betaSpec(targetDist);
} else if (targetDist instanceof dists.Discrete) {
return discreteSpec(targetDist);
} else if (targetDist instanceof dists.RandomInteger ||
targetDist instanceof dists.Binomial ||
targetDist instanceof dists.MultivariateGaussian ||
return discreteSpec(ad.value(targetDist.params.ps).length);
} else if (targetDist instanceof dists.RandomInteger) {
return discreteSpec(targetDist.params.n);
} else if (targetDist instanceof dists.Binomial) {
return discreteSpec(targetDist.params.n + 1);
} else if (targetDist instanceof dists.MultivariateGaussian ||
targetDist instanceof dists.Marginal ||
targetDist instanceof dists.SampleBasedMarginal) {
throwAutoGuideError(targetDist);
Expand Down Expand Up @@ -292,12 +294,11 @@ function gammaSpec(targetDist) {
};
}

function discreteSpec(targetDist) {
var d = ad.value(targetDist.params.ps).length;
function discreteSpec(dim) {
return {
type: dists.Discrete,
params: {
ps: {param: {dims: [d - 1, 1], squish: dists.squishToProbSimplex}}
ps: {param: {dims: [dim - 1, 1], squish: dists.squishToProbSimplex}}
}
};
}
Expand Down
7 changes: 7 additions & 0 deletions tests/test-data/stochastic/expected/binomial2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"hist": {
"0": 0.36,
"1": 0.48,
"2": 0.16
}
}
6 changes: 6 additions & 0 deletions tests/test-data/stochastic/expected/randomInteger2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"hist": {
"0": 0.8,
"1": 0.2
}
}
3 changes: 3 additions & 0 deletions tests/test-data/stochastic/models/binomial2.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
var model = function() {
return binomial(0.4, 2);
};
5 changes: 5 additions & 0 deletions tests/test-data/stochastic/models/randomInteger2.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
var model = function() {
var x = randomInteger(2);
factor(x === 0 ? 0 : Math.log(0.25));
return x;
};
18 changes: 18 additions & 0 deletions tests/test-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ var tests = [
deterministic: { hist: { exact: true } },
store: { hist: { exact: true } },
geometric: { args: { maxExecutions: 10 } },
binomial2: true,
delta: { args: { exact: true } },
discreteArr: true,
discreteVec: true,
categoricalArr: true,
categoricalVec: true,
randomInteger2: true,
cache: true,
withCaching: true,
earlyExit: { hist: { exact: true } },
Expand Down Expand Up @@ -674,6 +676,22 @@ var tests = [
verbose: false
}
},
randomInteger2: {
args: {
samples: 1000,
steps: 1000,
optMethod: {adam: {stepSize: 0.1}},
verbose: false
}
},
binomial2: {
args: {
samples: 1000,
steps: 1000,
optMethod: {adam: {stepSize: 0.1}},
verbose: false
}
},
mapData: true,
onlyMAP: {
mean: { tol: 0.1 },
Expand Down

0 comments on commit 37cc9d2

Please sign in to comment.