Skip to content

Commit

Permalink
Merge pull request #1715 from AustinRochford/add-discrete-weibull-dis…
Browse files Browse the repository at this point in the history
…tribution

Add discrete Weibull distribution
  • Loading branch information
AustinRochford committed Jan 25, 2017
2 parents 10eaf04 + 9e15318 commit 834e87f
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 9 deletions.
4 changes: 3 additions & 1 deletion pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .discrete import Binomial
from .discrete import BetaBinomial
from .discrete import Bernoulli
from .discrete import DiscreteWeibull
from .discrete import Poisson
from .discrete import NegativeBinomial
from .discrete import ConstantDist
Expand Down Expand Up @@ -125,5 +126,6 @@
'SkewNormal',
'Mixture',
'NormalMixture',
'Triangular'
'Triangular',
'DiscreteWeibull'
]
65 changes: 59 additions & 6 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial

from functools import partial
import numpy as np
import theano
import theano.tensor as tt
Expand All @@ -8,10 +7,10 @@
from .dist_math import bound, factln, binomln, betaln, logpow
from .distribution import Discrete, draw_values, generate_samples, reshape_sampled

__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'Poisson',
'NegativeBinomial', 'ConstantDist', 'Constant', 'ZeroInflatedPoisson',
'ZeroInflatedNegativeBinomial', 'DiscreteUniform', 'Geometric',
'Categorical']
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'DiscreteWeibull',
'Poisson', 'NegativeBinomial', 'ConstantDist', 'Constant',
'ZeroInflatedPoisson', 'ZeroInflatedNegativeBinomial',
'DiscreteUniform', 'Geometric', 'Categorical']


class Binomial(Discrete):
Expand Down Expand Up @@ -164,6 +163,60 @@ def logp(self, value):
p >= 0, p <= 1)


class DiscreteWeibull(Discrete):
R"""Discrete Weibull log-likelihood
The discrete Weibull distribution is a flexible model of count data that
can handle both over- and under-dispersion.
.. math:: f(x \mid q, \beta) = q^{x^{\beta}} - q^{(x + 1)^{\beta}}
======== ======================
Support :math:`x \in \mathbb{N}_0`
Mean :math:`\mu = \sum_{x = 1}^{\infty} q^{x^{\beta}}`
Variance :math:`2 \sum_{x = 1}^{\infty} x q^{x^{\beta}} - \mu - \mu^2`
======== ======================
"""
def __init__(self, q, beta, *args, **kwargs):
super(DiscreteWeibull, self).__init__(*args, defaults=['median'], **kwargs)

self.q = q
self.beta = beta

self.median = self._ppf(0.5)

def logp(self, value):
q = self.q
beta = self.beta

return bound(tt.log(tt.power(q, tt.power(value, beta)) - tt.power(q, tt.power(value + 1, beta))),
0 <= value,
0 < q, q < 1,
0 < beta)

def _ppf(self, p):
"""
The percentile point function (the inverse of the cumulative
distribution function) of the discrete Weibull distribution.
"""
q = self.q
beta = self.beta

return (tt.ceil(tt.power(tt.log(1 - p) / tt.log(q), 1. / beta)) - 1).astype('int64')

def _random(self, q, beta, size=None):
p = np.random.uniform(size=size)

return np.ceil(np.power(np.log(1 - p) / np.log(q), 1. / beta)) - 1

def random(self, point=None, size=None, repeat=None):
q, beta = draw_values([self.q, self.beta], point=point)

return generate_samples(self._random, q, beta,
dist_shape=self.shape,
size=size)


class Poisson(Discrete):
R"""
Poisson log-likelihood.
Expand Down
11 changes: 10 additions & 1 deletion pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
Bound, Uniform, Triangular, Binomial, Wishart, SkewNormal)
Bound, Uniform, Triangular, Binomial, Wishart, SkewNormal,
DiscreteWeibull)
from ..distributions import continuous, multivariate
from numpy import array, inf, log, exp
from numpy.testing import assert_almost_equal
Expand Down Expand Up @@ -210,6 +211,10 @@ def logpow(v, p):
return np.choose(v == 0, [p * np.log(v), 0])


def discrete_weibull_logpmf(value, q, beta):
return np.log(np.power(q, np.power(value, beta)) - np.power(q, np.power(value + 1, beta)))


def dirichlet_logpdf(value, a):
return (-betafn(a) + logpow(value, a - 1).sum(-1)).sum()

Expand Down Expand Up @@ -481,6 +486,10 @@ def test_bernoulli(self):
self.pymc3_matches_scipy(Bernoulli, Bool, {'p': Unit},
lambda value, p: sp.bernoulli.logpmf(value, p))

def test_discrete_weibull(self):
self.pymc3_matches_scipy(DiscreteWeibull, Nat,
{'q': Unit, 'beta': Rplusdunif}, discrete_weibull_logpmf)

def test_poisson(self):
self.pymc3_matches_scipy(Poisson, Nat, {'mu': Rplus},
lambda value, mu: sp.poisson.logpmf(value, mu))
Expand Down
16 changes: 15 additions & 1 deletion pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pymc3 as pm
from .helpers import SeededTest
from .test_distributions import (build_model, Domain, product, R, Rplus, Rplusbig, Unit, Nat,
from .test_distributions import (build_model, Domain, product, R, Rplus, Rplusbig, Rplusdunif, Unit, Nat,
NatSmall, I, Simplex, Vector, PdMatrix)


Expand Down Expand Up @@ -299,6 +299,11 @@ class TestBernoulli(BaseTestCases.BaseTestCase):
params = {'p': 0.5}


class TestDiscreteWeibull(BaseTestCases.BaseTestCase):
distribution = pm.DiscreteWeibull
params = {'q': 0.25, 'beta': 2.}


class TestPoisson(BaseTestCases.BaseTestCase):
distribution = pm.Poisson
params = {'mu': 1.}
Expand Down Expand Up @@ -486,6 +491,15 @@ def ref_rand(size, lower, upper):
pymc3_random_discrete(pm.DiscreteUniform, {'lower': -NatSmall, 'upper': NatSmall},
ref_rand=ref_rand)

def test_discrete_weibull(self):
def ref_rand(size, q, beta):
u = np.random.uniform(size=size)

return np.ceil(np.power(np.log(1 - u) / np.log(q), 1. / beta)) - 1

pymc3_random_discrete(pm.DiscreteWeibull, {'q': Unit, 'beta': Rplusdunif},
ref_rand=ref_rand)

def test_categorical(self):
# Don't make simplex too big. You have been warned.
for s in [2, 3, 4]:
Expand Down

0 comments on commit 834e87f

Please sign in to comment.