Skip to content

Commit

Permalink
Add Gompertz distribution (#1551)
Browse files Browse the repository at this point in the history
* Added Gompertz distribution class, sampling utilities

* reparameterized in terms of a and b (loc and shape) as in cited paper

* updated log_prob

* updated log_prob

* addressed comments

* Update continuous.py

missed two

* expm1/log1p, but correctly this time

* nit

* math

* distributions.rst updated

* formatting

* variance approximation

* removed unnecessary comment

* tests + docs

* fixed math

* TeX typo

* reparameterized as requested

* add mean method for Gompertz

* fix merge conflict

* fix lint

* use raw string

* Update requirements.txt

add pillow to dependencies

* fix some links

* add pylab dependency

* add matplotlib

* remove debug code

* remove debug info stuff for provenance

* add missing *

* add missing module

* Update setup.py

* Update docs for Gompertz

---------

Co-authored-by: jackpotrykus <jackpotrykus@gmail.com>
Co-authored-by: jackpot-nfer <73846181+jackpot-nfer@users.noreply.github.com>
Co-authored-by: Jack Potrykus <50253393+jackpotrykus@users.noreply.github.com>
  • Loading branch information
4 people committed Mar 14, 2023
1 parent 9a43b19 commit 849d4cf
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 13 deletions.
3 changes: 3 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ jax
jaxlib
jaxns==1.0.0
Jinja2<3.1
matplotlib
multipledispatch
nbsphinx==0.8.9
numpy
optax
pillow
pylab-sdk
pyyaml
readthedocs-sphinx-search==0.1.2
sphinx<5
Expand Down
24 changes: 16 additions & 8 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,6 @@ Gamma
:show-inheritance:
:member-order: bysource

Gumbel
^^^^^^
.. autoclass:: numpyro.distributions.continuous.Gumbel
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

GaussianCopula
^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.copula.GaussianCopula
Expand All @@ -200,6 +192,22 @@ GaussianRandomWalk
:show-inheritance:
:member-order: bysource

Gompertz
^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.Gompertz
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Gumbel
^^^^^^
.. autoclass:: numpyro.distributions.continuous.Gumbel
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

HalfCauchy
^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.HalfCauchy
Expand Down
4 changes: 4 additions & 0 deletions docs/source/svi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ We offer a brief overview of the three most commonly used ELBO implementations i
:show-inheritance:
:member-order: bysource

.. autodata:: numpyro.infer.svi.SVIState

.. autodata:: numpyro.infer.svi.SVIRunResult

ELBO
----

Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Exponential,
Gamma,
GaussianRandomWalk,
Gompertz,
Gumbel,
HalfCauchy,
HalfNormal,
Expand Down Expand Up @@ -140,6 +141,7 @@
"Geometric",
"GeometricLogits",
"GeometricProbs",
"Gompertz",
"Gumbel",
"HalfCauchy",
"HalfNormal",
Expand Down
58 changes: 58 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.special import (
betaln,
expi,
expit,
gammainc,
gammaln,
Expand Down Expand Up @@ -668,6 +669,63 @@ def cdf(self, x):
return 1 - self.base_dist.cdf(1 / x)


class Gompertz(Distribution):
    r"""Gompertz Distribution.

    A distribution supported on the positive real line that is closely
    related to the Gumbel distribution.  The parameterization follows the
    Wikipedia entry for the Gompertz distribution
    (https://en.wikipedia.org/wiki/Gompertz_distribution), except that the
    parameter "eta" is called ``concentration`` here and the parameter "b"
    is called ``rate`` (rather than a scale parameter as in the Wikipedia
    description).

    In terms of ``concentration`` (``con``) and ``rate``, the CDF is

    .. math::
        F(x) = 1 - \exp \left\{ - \text{con} * \left [ \exp\{x * rate \} - 1 \right ] \right\}
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.positive
    reparametrized_params = ["concentration", "rate"]

    def __init__(self, concentration, rate=1.0, *, validate_args=None):
        self.concentration, self.rate = promote_shapes(concentration, rate)
        batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(rate))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # Inverse-transform sampling: push uniform draws through the quantile
        # function (icdf).
        assert is_prng_key(key)
        shape = sample_shape + self.batch_shape + self.event_shape
        uniforms = random.uniform(key, shape=shape)
        return self.icdf(uniforms)

    @validate_sample
    def log_prob(self, value):
        # log f(x) = log(con) + log(rate) + rate*x - con * (exp(rate*x) - 1),
        # using expm1 for numerical accuracy near zero.
        z = value * self.rate
        log_normalizer = jnp.log(self.concentration) + jnp.log(self.rate)
        return log_normalizer + z - self.concentration * jnp.expm1(z)

    def cdf(self, value):
        # F(x) = 1 - exp(-con * (exp(rate*x) - 1)), written via expm1.
        inner = jnp.expm1(value * self.rate)
        return -jnp.expm1(-self.concentration * inner)

    def icdf(self, q):
        # Invert the CDF: x = log(1 - log(1 - q) / con) / rate, via log1p.
        scaled = -jnp.log1p(-q) / self.concentration
        return jnp.log1p(scaled) / self.rate

    @property
    def mean(self):
        # E[X] = -exp(con) * Ei(-con) / rate, with Ei the exponential integral.
        return -jnp.exp(self.concentration) * expi(-self.concentration) / self.rate


class Gumbel(Distribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"""
A :func:`~collections.namedtuple` consisting of the following fields:
- **params** - the optimized parameters.
- **state** - the last :class:`SVIState`
- **state** - the last :data:`SVIState`
- **losses** - the losses collected at every step.
"""

Expand Down Expand Up @@ -337,7 +337,7 @@ def run(
:return: a namedtuple with fields `params` and `losses` where `params`
holds the optimized values at :class:`numpyro.param` sites,
and `losses` is the collected loss during the process.
:rtype: SVIRunResult
:rtype: :data:`SVIRunResult`
"""

if num_steps < 1:
Expand Down
5 changes: 2 additions & 3 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax._src.pjit import pjit_p
from jax.api_util import flatten_fun, shaped_abstractify
import jax.core as core
from jax.interpreters.partial_eval import debug_info, trace_to_jaxpr_dynamic
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
from jax.interpreters.xla import xla_call_p
import jax.linear_util as lu
Expand Down Expand Up @@ -41,9 +41,8 @@ def eval_provenance(fn, **kwargs):
# XXX: we split out the process of abstract evaluation and provenance tracking
# for simplicity. In principle, they can be merged so that we only need to walk
# through the equations once.
info = debug_info(fn, in_tree, True, "eval_shape")
jaxpr, avals_out, _ = trace_to_jaxpr_dynamic(
lu.wrap_init(wrapped_fun.call_wrapped, {}), avals, info
lu.wrap_init(wrapped_fun.call_wrapped, {}), avals
)

# get provenances of flatten kwargs
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@
"funsor>=0.4.1",
"graphviz",
"jaxns==1.0.0",
"matplotlib",
"optax>=0.0.6",
"pylab-sdk", # jaxns dependency
"pyyaml", # flax dependency
"requests", # pylab dependency
"tensorflow_probability>=0.17.0",
],
"examples": [
Expand Down
4 changes: 4 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ def get_sp_dist(jax_dist):
),
T(dist.GaussianCopulaBeta, 2.0, 1.5, np.eye(3)),
T(dist.GaussianCopulaBeta, 2.0, 1.5, np.full((5, 3, 3), np.eye(3))),
T(dist.Gompertz, np.array([1.7]), np.array([[2.0], [3.0]])),
T(dist.Gompertz, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
T(dist.Gumbel, 0.0, 1.0),
T(dist.Gumbel, 0.5, 2.0),
T(dist.Gumbel, np.array([0.0, 0.5]), np.array([1.0, 2.0])),
Expand Down Expand Up @@ -1742,6 +1744,8 @@ def get_min_shape(ix, batch_shape):
assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
if isinstance(d_jax, dist.CAR):
pytest.skip("CAR distribution does not have `variance` implemented.")
if isinstance(d_jax, dist.Gompertz):
pytest.skip("Gompertz distribution does not have `variance` implemented.")
if jnp.all(jnp.isfinite(d_jax.variance)):
assert_allclose(
jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.05, atol=1e-2
Expand Down

0 comments on commit 849d4cf

Please sign in to comment.