add log normal negative binomial distributions #3010

Merged · 16 commits · Jan 27, 2022
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -232,6 +232,13 @@ LKJCorrCholesky
:undoc-members:
:show-inheritance:

LogNormalNegativeBinomial
-------------------------
.. autoclass:: pyro.distributions.LogNormalNegativeBinomial
:members:
:undoc-members:
:show-inheritance:

Logistic
--------
.. autoclass:: pyro.distributions.Logistic
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
@@ -49,6 +49,7 @@
from pyro.distributions.improper_uniform import ImproperUniform
from pyro.distributions.inverse_gamma import InverseGamma
from pyro.distributions.lkj import LKJ, LKJCorrCholesky
from pyro.distributions.log_normal_negative_binomial import LogNormalNegativeBinomial
from pyro.distributions.logistic import Logistic, SkewLogistic
from pyro.distributions.mixture import MaskedMixture
from pyro.distributions.multivariate_studentt import MultivariateStudentT
@@ -124,6 +125,7 @@
"LKJCorrCholesky",
"LinearHMM",
"Logistic",
"LogNormalNegativeBinomial",
"MaskedDistribution",
"MaskedMixture",
"MixtureOfDiagNormals",
152 changes: 152 additions & 0 deletions pyro/distributions/log_normal_negative_binomial.py
@@ -0,0 +1,152 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, lazy_property

from pyro.distributions.torch import NegativeBinomial
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape
from pyro.ops.special import get_quad_rule


class LogNormalNegativeBinomial(TorchDistribution):
r"""
A three-parameter generalization of the Negative Binomial distribution [1].
It can be understood as a continuous mixture of Negative Binomial distributions
in which we inject Normally-distributed noise into the logits of the Negative
Binomial distribution:

.. math::

\begin{eqnarray}
&\rm{LNNB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell, \rm{multiplicative\_noise\_scale}=\sigma) = \\
&\int d\epsilon \mathcal{N}(\epsilon | 0, \sigma)
\rm{NB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell + \epsilon)
\end{eqnarray}

where :math:`y \ge 0` is a non-negative integer. Thus while a Negative Binomial distribution
can be formulated as a Poisson distribution with a Gamma-distributed rate, this distribution
adds an additional level of variability by also modulating the rate by Log Normally-distributed
multiplicative noise.

This distribution has a mean given by

.. math::
\mathbb{E}[y] = \nu e^{\ell + \tfrac{1}{2}\sigma^2} = e^{\ell + \log \nu + \tfrac{1}{2}\sigma^2}

and a variance given by

.. math::
\rm{Var}[y] = \mathbb{E}[y] + \left( e^{\sigma^2} (1 + 1/\nu) - 1 \right) \left( \mathbb{E}[y] \right)^2

Thus while a given mean and variance together uniquely characterize a Negative Binomial distribution, there is a
one-dimensional family of Log Normal Negative Binomial distributions with a given mean and variance.

Note that in some applications it may be useful to parameterize the logits as

.. math::
\ell = \ell^\prime - \log \nu - \tfrac{1}{2}\sigma^2

so that the mean is given by :math:`\mathbb{E}[y] = e^{\ell^\prime}` and does not depend on :math:`\nu`
and :math:`\sigma`, which serve to determine the higher moments.

References:

[1] "Lognormal and Gamma Mixed Negative Binomial Regression,"
Mingyuan Zhou, Lingbo Li, David Dunson, and Lawrence Carin.

:param total_count: non-negative number of negative Bernoulli trials. The variance decreases
as `total_count` increases.
:type total_count: float or torch.Tensor
:param torch.Tensor logits: Event log-odds for probabilities of success for underlying
Negative Binomial distribution.
:param torch.Tensor multiplicative_noise_scale: Controls the level of the injected Normal logit noise.
:param int num_quad_points: Number of quadrature points used to compute the (approximate) `log_prob`.
Defaults to 8.
"""
arg_constraints = {
"total_count": constraints.greater_than_eq(0),
"logits": constraints.real,
"multiplicative_noise_scale": constraints.positive,
}
support = constraints.nonnegative_integer

def __init__(
self,
total_count,
logits,
multiplicative_noise_scale,
*,
num_quad_points=8,
validate_args=None,
):
if num_quad_points < 1:
raise ValueError("num_quad_points must be positive.")

total_count, logits, multiplicative_noise_scale = broadcast_all(
total_count, logits, multiplicative_noise_scale
)

self.quad_points, self.log_weights = get_quad_rule(num_quad_points, logits)
quad_logits = (
logits.unsqueeze(-1)
+ multiplicative_noise_scale.unsqueeze(-1) * self.quad_points
)
self.nb_dist = NegativeBinomial(
total_count=total_count.unsqueeze(-1), logits=quad_logits
)

self.multiplicative_noise_scale = multiplicative_noise_scale
self.total_count = total_count
self.logits = logits
self.num_quad_points = num_quad_points

batch_shape = broadcast_shape(
multiplicative_noise_scale.shape, self.nb_dist.batch_shape[:-1]
)
event_shape = torch.Size()

super().__init__(batch_shape, event_shape, validate_args)

def log_prob(self, value):
nb_log_prob = self.nb_dist.log_prob(value.unsqueeze(-1))
return torch.logsumexp(self.log_weights + nb_log_prob, axis=-1)

def sample(self, sample_shape=torch.Size()):
raise NotImplementedError

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(type(self), _instance)
batch_shape = torch.Size(batch_shape)
total_count = self.total_count.expand(batch_shape)
logits = self.logits.expand(batch_shape)
multiplicative_noise_scale = self.multiplicative_noise_scale.expand(batch_shape)
LogNormalNegativeBinomial.__init__(
new,
total_count,
logits,
multiplicative_noise_scale,
num_quad_points=self.num_quad_points,
validate_args=False,
)
new._validate_args = self._validate_args
return new

@lazy_property
def mean(self):
return torch.exp(
self.logits
+ self.total_count.log()
+ 0.5 * self.multiplicative_noise_scale.pow(2.0)
)

@lazy_property
def variance(self):
kappa = (
torch.exp(self.multiplicative_noise_scale.pow(2.0))
* (1 + 1 / self.total_count)
- 1
)
return self.mean + kappa * self.mean.pow(2.0)
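
For reference, here is a minimal usage sketch of the new distribution (not part of the diff; the parameter values and the truncation at 200 counts are illustrative). It checks numerically that the quadrature-based `log_prob` integrates to roughly one and reproduces the closed-form `mean` given in the docstring:

import torch
import pyro.distributions as dist

lnnb = dist.LogNormalNegativeBinomial(
    total_count=torch.tensor(10.0),
    logits=torch.tensor(0.5),
    multiplicative_noise_scale=torch.tensor(0.3),
    num_quad_points=16,
)
counts = torch.arange(200.0)
probs = lnnb.log_prob(counts).exp()  # quadrature-based approximation of the pmf
print(probs.sum())             # close to 1.0 (mass beyond 200 counts is negligible here)
print((probs * counts).sum())  # close to lnnb.mean
print(lnnb.mean, lnnb.variance)

Since `log_prob` is computed by Gauss-Hermite quadrature, accuracy improves with larger `num_quad_points`, at a cost that grows linearly in the number of points.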
28 changes: 28 additions & 0 deletions pyro/ops/special.py
@@ -5,7 +5,9 @@
import math
import operator

import numpy as np
import torch
from numpy.polynomial.hermite import hermgauss


class _SafeLog(torch.autograd.Function):
@@ -151,3 +153,29 @@ def log_I1(orders: int, value: torch.Tensor, terms=250):
i1s = lvalues[..., :orders].T + seqs
assert i1s.shape == (orders, vshape.numel())
return i1s.view(-1, *vshape)


def get_quad_rule(num_quad, prototype_tensor):
r"""
Get quadrature points and corresponding log weights for a Gauss Hermite quadrature rule
with the specified number of quadrature points.

Example usage::

quad_points, log_weights = get_quad_rule(32, prototype_tensor)
# transform to N(0, 4.0) Normal distribution
quad_points *= 4.0
# compute variance integral in log-space using logsumexp and exponentiate
variance = torch.logsumexp(quad_points.pow(2.0).log() + log_weights, axis=0).exp()
assert (variance - 16.0).abs().item() < 1.0e-6

:param int num_quad: number of quadrature points.
:param torch.Tensor prototype_tensor: used to determine `dtype` and `device` of returned tensors.
:return: tuple of `torch.Tensor`s of the form `(quad_points, log_weights)`
"""
quad_rule = hermgauss(num_quad)
quad_points = quad_rule[0] * np.sqrt(2.0)
log_weights = np.log(quad_rule[1]) - 0.5 * np.log(np.pi)
return torch.from_numpy(quad_points).type_as(prototype_tensor), torch.from_numpy(
log_weights
).type_as(prototype_tensor)
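
As a further illustration (not part of the diff), the same quadrature rule can be used to check a log-normal moment, E[exp(σ·ε)] = exp(σ²/2) for ε ~ N(0, 1), which is essentially the integral that `LogNormalNegativeBinomial.log_prob` approximates:

import math
import torch
from pyro.ops.special import get_quad_rule

sigma = 0.7
quad_points, log_weights = get_quad_rule(32, torch.zeros(1))
# integrate exp(sigma * eps) against the standard normal density, in log space
moment = torch.logsumexp(sigma * quad_points + log_weights, dim=0).exp()
assert abs(moment.item() - math.exp(sigma ** 2 / 2)) < 1e-4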
18 changes: 18 additions & 0 deletions tests/distributions/conftest.py
@@ -999,6 +999,24 @@ def __init__(self, von_loc, von_conc, skewness):
prec=0.08,
is_discrete=True,
),
Fixture(
pyro_dist=dist.LogNormalNegativeBinomial,
examples=[
{
"logits": [0.6],
"total_count": 8,
"multiplicative_noise_scale": [0.1],
"test_data": [4.0],
},
{
"logits": [0.2, 0.4],
"multiplicative_noise_scale": [0.1, 0.2],
"total_count": [[8.0, 7.0], [5.0, 9.0]],
"test_data": [[6.0, 3.0], [2.0, 8.0]],
},
],
is_discrete=True,
),
]


9 changes: 8 additions & 1 deletion tests/distributions/test_distributions.py
@@ -284,10 +284,17 @@ def test_distribution_validate_args(dist_class, args, validate_args):
def check_sample_shapes(small, large):
dist_instance = small
if isinstance(
dist_instance, (dist.LogNormal, dist.LowRankMultivariateNormal, dist.VonMises)
dist_instance,
(
dist.LogNormal,
dist.LowRankMultivariateNormal,
dist.VonMises,
dist.LogNormalNegativeBinomial,
),
):
# Ignore broadcasting bug in LogNormal:
# https://github.com/pytorch/pytorch/pull/7269
# LogNormalNegativeBinomial has no sample method
return
x = small.sample()
assert_equal(small.log_prob(x).expand(large.batch_shape), large.log_prob(x))
47 changes: 47 additions & 0 deletions tests/distributions/test_log_normal_negative_binomial.py
@@ -0,0 +1,47 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pyro.distributions import LogNormalNegativeBinomial
from tests.common import assert_close


@pytest.mark.parametrize("num_quad_points", [2, 4])
@pytest.mark.parametrize("shape", [(2,), (4, 3)])
def test_lnnb_shapes(num_quad_points, shape):
logits = torch.randn(shape)
total_count = 5.0
multiplicative_noise_scale = torch.rand(shape)

d = LogNormalNegativeBinomial(
total_count, logits, multiplicative_noise_scale, num_quad_points=num_quad_points
)

assert d.batch_shape == shape
assert d.log_prob(torch.ones(shape)).shape == shape

assert d.expand(shape + shape).batch_shape == shape + shape
assert d.expand(shape + shape).log_prob(torch.ones(shape)).shape == shape + shape


@pytest.mark.parametrize("total_count", [0.5, 4.0])
@pytest.mark.parametrize("multiplicative_noise_scale", [0.01, 0.25])
def test_lnnb_mean_variance(
total_count, multiplicative_noise_scale, num_quad_points=128, N=512
):
logits = torch.tensor(2.0)
d = LogNormalNegativeBinomial(
total_count, logits, multiplicative_noise_scale, num_quad_points=num_quad_points
)

values = torch.arange(N)
probs = d.log_prob(values).exp()
assert_close(1.0, probs.sum().item(), atol=1.0e-6)

expected_mean = (probs * values).sum()
assert_close(expected_mean, d.mean, atol=1.0e-6, rtol=1.0e-5)

expected_var = (probs * (values - d.mean).pow(2.0)).sum()
assert_close(expected_var, d.variance, atol=1.0e-6, rtol=1.0e-5)
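
Since `sample` is left unimplemented, one way to sanity-check the closed-form `mean` is to draw from the generative process directly (Normal noise on the logits, then a Negative Binomial). This is a hedged sketch, not part of the PR; the sample size and tolerance are arbitrary:

import torch
from pyro.distributions import LogNormalNegativeBinomial, NegativeBinomial, Normal

total_count = torch.tensor(8.0)
logits = torch.tensor(0.4)
sigma = torch.tensor(0.2)
d = LogNormalNegativeBinomial(total_count, logits, sigma)

# eps ~ N(0, sigma), then y ~ NB(total_count, logits + eps), matching the mixture definition
eps = Normal(0.0, sigma).sample((200_000,))
samples = NegativeBinomial(total_count=total_count, logits=logits + eps).sample()
assert (samples.mean() - d.mean).abs() / d.mean < 0.05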
10 changes: 9 additions & 1 deletion tests/ops/test_special.py
@@ -7,7 +7,7 @@
from torch import tensor
from torch.autograd import grad

from pyro.ops.special import log_beta, log_binomial, log_I1, safe_log
from pyro.ops.special import get_quad_rule, log_beta, log_binomial, log_I1, safe_log
from tests.common import assert_equal


@@ -93,3 +93,11 @@ def test_log_I1_shapes():
assert_equal(log_I1(10, tensor([[0.6]])).shape, torch.Size([11, 1, 1]))
assert_equal(log_I1(10, tensor([0.6, 0.2])).shape, torch.Size([11, 2]))
assert_equal(log_I1(0, tensor(0.6)).shape, torch.Size((1, 1)))


@pytest.mark.parametrize("sigma", [0.5, 1.25])
def test_get_quad_rule(sigma):
quad_points, log_weights = get_quad_rule(32, torch.zeros(1))
quad_points *= sigma # transform to N(0, sigma) gaussian
variance = torch.logsumexp(quad_points.pow(2.0).log() + log_weights, axis=0).exp()
assert_equal(sigma ** 2, variance.item())