Skip to content

Commit

Permalink
Merge pull request #1596 from pymc-devs/fix_bounds2
Browse files Browse the repository at this point in the history
Make bound broadcast again.
  • Loading branch information
jsalvatier committed Dec 19, 2016
2 parents 1a1a36e + 6c7f127 commit d9e57e8
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 29 deletions.
2 changes: 1 addition & 1 deletion pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def logp(self, value):
lower = self.lower
upper = self.upper
return bound(-tt.log(upper - lower),
value >= lower, value <= upper)
value >= lower, value <= upper)


class Flat(Continuous):
Expand Down
4 changes: 2 additions & 2 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class ZeroInflatedPoisson(Discrete):
Often used to model the number of events occurring in a fixed period
of time when the times at which events occur are independent.
.. math::
.. math::
f(x \mid \theta, \psi) = \left\{ \begin{array}{l}
(1-\psi) + \psi e^{-\theta}, \text{if } x = 0 \\
Expand Down Expand Up @@ -503,7 +503,7 @@ class ZeroInflatedNegativeBinomial(Discrete):
The Zero-inflated version of the Negative Binomial (NB).
The NB distribution describes a Poisson random variable
whose rate parameter is gamma distributed.
whose rate parameter is gamma distributed.
.. math::
Expand Down
30 changes: 24 additions & 6 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,43 @@

from .special import gammaln, multigammaln


def bound(logp, *conditions, **kwargs):
    """
    Bounds a log probability density with several conditions.

    Parameters
    ----------
    logp : float
        Log probability (scalar or tensor) to be bounded.
    *conditions : booleans
        Conditions that must all hold for ``logp`` to be returned.
    broadcast_conditions : bool (optional, default=True)
        If True, broadcasts logp to match the largest shape of the conditions.
        This is used e.g. in DiscreteUniform where logp is a scalar constant and the shape
        is specified via the conditions.
        If False, will return the same shape as logp.
        This is used e.g. in Multinomial where broadcasting can lead to differences in the logp.

    Returns
    -------
    logp with elements set to -inf where any condition is False
    """
    broadcast_conditions = kwargs.pop('broadcast_conditions', True)
    # Fail loudly on typos such as `broadcast_condition=False`; a bare
    # kwargs.get() would silently ignore them and silently broadcast.
    if kwargs:
        raise TypeError('bound() got unexpected keyword arguments: %s'
                        % ', '.join(sorted(kwargs)))

    if broadcast_conditions:
        alltrue = alltrue_elemwise
    else:
        alltrue = alltrue_scalar

    return tt.switch(alltrue(conditions), logp, -np.inf)


def alltrue_elemwise(vals):
    """Elementwise logical AND over the conditions in *vals*.

    Each condition is coerced to {0, 1} via ``1 * c``; the running product
    is nonzero exactly where every condition holds, and shapes broadcast
    together under the usual multiplication rules.
    """
    product = 1
    for condition in vals:
        product = product * (1 * condition)
    return product


def alltrue_scalar(vals):
    """Reduce *vals* to one scalar boolean: true iff every element of
    every condition is truthy (no broadcasting of the result)."""
    per_condition = [tt.all(1 * val) for val in vals]
    return tt.all(per_condition)


Expand Down
15 changes: 11 additions & 4 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def logp(self, value):
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
+ gammaln(tt.sum(a, axis=-1)),
tt.all(value >= 0), tt.all(value <= 1),
k > 1, tt.all(a > 0))
k > 1, tt.all(a > 0),
broadcast_conditions=False)


class Multinomial(Discrete):
Expand Down Expand Up @@ -323,7 +324,9 @@ def logp(self, x):
tt.all(tt.eq(tt.sum(x, axis=-1, keepdims=True), n)),
tt.all(p <= 1),
tt.all(tt.eq(tt.sum(p, axis=-1), 1)),
tt.all(tt.ge(n, 0)))
tt.all(tt.ge(n, 0)),
broadcast_conditions=False
)


def posdef(AA):
Expand Down Expand Up @@ -443,7 +446,9 @@ def logp(self, X):
- 2 * multigammaln(n / 2., p)) / 2,
matrix_pos_def(X),
tt.eq(X, X.T),
n > (p - 1))
n > (p - 1),
broadcast_conditions=False
)


def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testval=None):
Expand Down Expand Up @@ -605,4 +610,6 @@ def logp(self, x):
return bound(result,
tt.all(X <= 1), tt.all(X >= -1),
matrix_pos_def(X),
n > 0)
n > 0,
broadcast_conditions=False
)
110 changes: 94 additions & 16 deletions pymc3/tests/test_dist_math.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,114 @@
import numpy as np
import theano.tensor as tt
import pymc3 as pm

from ..distributions.dist_math import alltrue
from ..distributions import Discrete
from ..distributions.dist_math import bound, factln, alltrue_elemwise, alltrue_scalar


def test_alltrue():
assert alltrue([]).eval()
assert alltrue([True]).eval()
assert alltrue([tt.ones(10)]).eval()
assert alltrue([tt.ones(10),
def test_bound():
    """bound() must knock out elements -- not whole tensors -- where a
    condition fails, broadcasting scalar conditions over logp."""
    # All conditions hold: logp passes through unchanged.
    dense = tt.ones((10, 10))
    assert np.all(bound(dense, tt.ones((10, 10))).eval() == dense.eval())

    # No condition holds: every element becomes -inf.
    assert np.all(bound(dense, tt.zeros((10, 10))).eval()
                  == (-np.inf * dense).eval())

    # A plain Python bool broadcasts over the whole tensor.
    assert np.all(bound(dense, True).eval() == dense.eval())

    # A 1-d mask zeroes out only the failing entries.
    vec = tt.ones(3)
    mask = np.array([1, 0, 1])
    assert not np.all(bound(vec, mask).eval() == 1)
    assert np.prod(bound(vec, mask).eval()) == -np.inf

    # Same behaviour for a 2-d mask.
    mat = tt.ones((2, 3))
    mask2d = np.array([[1, 1, 1], [1, 0, 1]])
    assert not np.all(bound(mat, mask2d).eval() == 1)
    assert np.prod(bound(mat, mask2d).eval()) == -np.inf

def test_alltrue_scalar():
    """alltrue_scalar collapses any mix of bools, numpy arrays and theano
    tensors to a single boolean."""
    truthy_cases = [
        [],
        [True],
        [tt.ones(10)],
        [tt.ones(10), 5 * tt.ones(101)],
        [np.ones(10), 5 * tt.ones(101)],
        [np.ones(10), True, 5 * tt.ones(101)],
        [np.array([1, 2, 3]), True, 5 * tt.ones(101)],
    ]
    for conditions in truthy_cases:
        assert alltrue_scalar(conditions).eval()

    falsy_cases = [
        [False],
        [tt.zeros(10)],
        [True, False],
        [np.array([0, -1]), tt.ones(60)],
        [np.ones(10), False, 5 * tt.ones(101)],
    ]
    for conditions in falsy_cases:
        assert not alltrue_scalar(conditions).eval()


def test_alltrue_shape():
    # Mixed scalar/vector conditions must still reduce to a 0-d scalar.
    conditions = [True, tt.ones(10), tt.zeros(5)]
    assert alltrue_scalar(conditions).eval().shape == ()

class MultinomialA(Discrete):
    """Test fixture: multinomial whose logp passes *raw* (un-reduced)
    comparisons to bound() and relies on broadcast_conditions=False."""

    def __init__(self, n, p, *args, **kwargs):
        super(MultinomialA, self).__init__(*args, **kwargs)

        # n: total count; p: event probabilities (theano tensor).
        self.n = n
        self.p = p

    def logp(self, value):
        n = self.n
        p = self.p

        # NOTE(review): conditions are deliberately NOT wrapped in tt.all();
        # with broadcast_conditions=False the result must still match a
        # variant that reduces each condition explicitly.
        return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
                     value >= 0,
                     0 <= p, p <= 1,
                     tt.isclose(p.sum(), 1),
                     broadcast_conditions=False
                     )


class MultinomialB(Discrete):
    """Test fixture: multinomial whose logp reduces every condition with
    tt.all() before passing it to bound()."""

    def __init__(self, n, p, *args, **kwargs):
        super(MultinomialB, self).__init__(*args, **kwargs)

        # n: total count; p: event probabilities (theano tensor).
        self.n = n
        self.p = p

    def logp(self, value):
        n = self.n
        p = self.p

        # NOTE(review): each condition is explicitly reduced with tt.all();
        # the logp must equal the un-reduced variant when
        # broadcast_conditions=False.
        return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
                     tt.all(value >= 0),
                     tt.all(0 <= p), tt.all(p <= 1),
                     tt.isclose(p.sum(), 1),
                     broadcast_conditions=False
                     )


def test_multinomial_bound():
    """Two multinomial variants -- raw vs tt.all()-reduced bound conditions,
    both with broadcast_conditions=False -- must assign identical logp."""
    observed = np.array([1, 5])
    total = observed.sum()

    with pm.Model() as model_a:
        prob_a = pm.Dirichlet('p', np.ones(2))
        MultinomialA('x', total, prob_a, observed=observed)

    with pm.Model() as model_b:
        prob_b = pm.Dirichlet('p', np.ones(2))
        MultinomialB('x', total, prob_b, observed=observed)

    point = {'p_stickbreaking_': [0]}
    assert np.isclose(model_a.logp(point), model_b.logp(point))

0 comments on commit d9e57e8

Please sign in to comment.