add log normal negative binomial distributions (#3010)
martinjankowiak committed Jan 27, 2022
1 parent 5ab7da2 commit 319c515
Showing 8 changed files with 271 additions and 2 deletions.
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -232,6 +232,13 @@ LKJCorrCholesky
    :undoc-members:
    :show-inheritance:

LogNormalNegativeBinomial
-------------------------
.. autoclass:: pyro.distributions.LogNormalNegativeBinomial
    :members:
    :undoc-members:
    :show-inheritance:

Logistic
--------
.. autoclass:: pyro.distributions.Logistic
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
@@ -49,6 +49,7 @@
from pyro.distributions.improper_uniform import ImproperUniform
from pyro.distributions.inverse_gamma import InverseGamma
from pyro.distributions.lkj import LKJ, LKJCorrCholesky
from pyro.distributions.log_normal_negative_binomial import LogNormalNegativeBinomial
from pyro.distributions.logistic import Logistic, SkewLogistic
from pyro.distributions.mixture import MaskedMixture
from pyro.distributions.multivariate_studentt import MultivariateStudentT
@@ -124,6 +125,7 @@
"LKJCorrCholesky",
"LinearHMM",
"Logistic",
"LogNormalNegativeBinomial",
"MaskedDistribution",
"MaskedMixture",
"MixtureOfDiagNormals",
152 changes: 152 additions & 0 deletions pyro/distributions/log_normal_negative_binomial.py
@@ -0,0 +1,152 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, lazy_property

from pyro.distributions.torch import NegativeBinomial
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape
from pyro.ops.special import get_quad_rule


class LogNormalNegativeBinomial(TorchDistribution):
    r"""
    A three-parameter generalization of the Negative Binomial distribution [1].
    It can be understood as a continuous mixture of Negative Binomial distributions
    in which we inject Normally-distributed noise into the logits of the Negative
    Binomial distribution:

    .. math::

        \begin{eqnarray}
        &\rm{LNNB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell, \rm{multiplicative\_noise\_scale}=\sigma) = \\
        &\int d\epsilon \mathcal{N}(\epsilon | 0, \sigma)
        \rm{NB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell + \epsilon)
        \end{eqnarray}

    where :math:`y \ge 0` is a non-negative integer. Thus while a Negative Binomial distribution
    can be formulated as a Poisson distribution with a Gamma-distributed rate, this distribution
    adds an additional level of variability by also modulating the rate by Log Normally-distributed
    multiplicative noise.

    This distribution has a mean given by

    .. math::

        \mathbb{E}[y] = \nu e^{\ell + \tfrac{1}{2}\sigma^2} = e^{\ell + \log \nu + \tfrac{1}{2}\sigma^2}

    and a variance given by

    .. math::

        \rm{Var}[y] = \mathbb{E}[y] + \left( e^{\sigma^2} (1 + 1/\nu) - 1 \right) \left( \mathbb{E}[y] \right)^2

    Thus while a given mean and variance together uniquely characterize a Negative Binomial distribution, there is a
    one-dimensional family of Log Normal Negative Binomial distributions with a given mean and variance.

    Note that in some applications it may be useful to parameterize the logits as

    .. math::

        \ell = \ell^\prime - \log \nu - \tfrac{1}{2}\sigma^2

    so that the mean is given by :math:`\mathbb{E}[y] = e^{\ell^\prime}` and does not depend on :math:`\nu`
    and :math:`\sigma`, which serve to determine the higher moments.

    References:

    [1] "Lognormal and Gamma Mixed Negative Binomial Regression,"
    Mingyuan Zhou, Lingbo Li, David Dunson, and Lawrence Carin.

    :param total_count: non-negative number of negative Bernoulli trials. The variance decreases
        as `total_count` increases.
    :type total_count: float or torch.Tensor
    :param torch.Tensor logits: Event log-odds for probabilities of success for underlying
        Negative Binomial distribution.
    :param torch.Tensor multiplicative_noise_scale: Controls the level of the injected Normal logit noise.
    :param int num_quad_points: Number of quadrature points used to compute the (approximate) `log_prob`.
        Defaults to 8.
    """
    arg_constraints = {
        "total_count": constraints.greater_than_eq(0),
        "logits": constraints.real,
        "multiplicative_noise_scale": constraints.positive,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count,
        logits,
        multiplicative_noise_scale,
        *,
        num_quad_points=8,
        validate_args=None,
    ):
        if num_quad_points < 1:
            raise ValueError("num_quad_points must be positive.")

        total_count, logits, multiplicative_noise_scale = broadcast_all(
            total_count, logits, multiplicative_noise_scale
        )

        self.quad_points, self.log_weights = get_quad_rule(num_quad_points, logits)
        quad_logits = (
            logits.unsqueeze(-1)
            + multiplicative_noise_scale.unsqueeze(-1) * self.quad_points
        )
        self.nb_dist = NegativeBinomial(
            total_count=total_count.unsqueeze(-1), logits=quad_logits
        )

        self.multiplicative_noise_scale = multiplicative_noise_scale
        self.total_count = total_count
        self.logits = logits
        self.num_quad_points = num_quad_points

        batch_shape = broadcast_shape(
            multiplicative_noise_scale.shape, self.nb_dist.batch_shape[:-1]
        )
        event_shape = torch.Size()

        super().__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        nb_log_prob = self.nb_dist.log_prob(value.unsqueeze(-1))
        return torch.logsumexp(self.log_weights + nb_log_prob, axis=-1)

    def sample(self, sample_shape=torch.Size()):
        raise NotImplementedError

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(type(self), _instance)
        batch_shape = torch.Size(batch_shape)
        total_count = self.total_count.expand(batch_shape)
        logits = self.logits.expand(batch_shape)
        multiplicative_noise_scale = self.multiplicative_noise_scale.expand(batch_shape)
        LogNormalNegativeBinomial.__init__(
            new,
            total_count,
            logits,
            multiplicative_noise_scale,
            num_quad_points=self.num_quad_points,
            validate_args=False,
        )
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def mean(self):
        return torch.exp(
            self.logits
            + self.total_count.log()
            + 0.5 * self.multiplicative_noise_scale.pow(2.0)
        )

    @lazy_property
    def variance(self):
        kappa = (
            torch.exp(self.multiplicative_noise_scale.pow(2.0))
            * (1 + 1 / self.total_count)
            - 1
        )
        return self.mean + kappa * self.mean.pow(2.0)
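
A minimal usage sketch (not part of the commit; parameter values and variable names are illustrative assumptions) showing how the new distribution can be constructed, how the quadrature-based log_prob relates to the closed-form moments, and how to apply the reparameterized logits described in the docstring so the mean does not depend on total_count or the noise scale:

    import torch
    import pyro.distributions as dist

    nu = torch.tensor(10.0)      # total_count
    sigma = torch.tensor(0.3)    # multiplicative_noise_scale
    target_mean = torch.tensor(5.0)
    # reparameterize the logits so that E[y] = target_mean regardless of nu and sigma
    logits = target_mean.log() - nu.log() - 0.5 * sigma.pow(2.0)

    d = dist.LogNormalNegativeBinomial(nu, logits, sigma, num_quad_points=16)
    counts = torch.arange(200.0)
    probs = d.log_prob(counts).exp()        # Gauss-Hermite approximation, sums to ~1
    print(d.mean, (probs * counts).sum())   # both approximately equal to target_mean
    print(d.variance)
    # Note: d.sample() raises NotImplementedError; one can instead draw
    # eps ~ Normal(0, sigma) and then y ~ NegativeBinomial(nu, logits=logits + eps).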
28 changes: 28 additions & 0 deletions pyro/ops/special.py
@@ -5,7 +5,9 @@
import math
import operator

import numpy as np
import torch
from numpy.polynomial.hermite import hermgauss


class _SafeLog(torch.autograd.Function):
@@ -151,3 +153,29 @@ def log_I1(orders: int, value: torch.Tensor, terms=250):
    i1s = lvalues[..., :orders].T + seqs
    assert i1s.shape == (orders, vshape.numel())
    return i1s.view(-1, *vshape)


def get_quad_rule(num_quad, prototype_tensor):
    r"""
    Get quadrature points and corresponding log weights for a Gauss Hermite quadrature rule
    with the specified number of quadrature points.

    Example usage::

        quad_points, log_weights = get_quad_rule(32, prototype_tensor)
        # transform to N(0, 4.0) Normal distribution
        quad_points *= 4.0
        # compute variance integral in log-space using logsumexp and exponentiate
        variance = torch.logsumexp(quad_points.pow(2.0).log() + log_weights, axis=0).exp()
        assert (variance - 16.0).abs().item() < 1.0e-6

    :param int num_quad: number of quadrature points.
    :param torch.Tensor prototype_tensor: used to determine `dtype` and `device` of returned tensors.
    :return: tuple of `torch.Tensor`s of the form `(quad_points, log_weights)`
    """
    quad_rule = hermgauss(num_quad)
    quad_points = quad_rule[0] * np.sqrt(2.0)
    log_weights = np.log(quad_rule[1]) - 0.5 * np.log(np.pi)
    return torch.from_numpy(quad_points).type_as(prototype_tensor), torch.from_numpy(
        log_weights
    ).type_as(prototype_tensor)
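
Since the returned nodes are Gauss-Hermite points rescaled for a standard Normal, sum_k exp(log_weights[k]) * f(quad_points[k]) approximates E[f(eps)] for eps ~ N(0, 1). A small sketch (not part of the commit; the value of sigma is an arbitrary illustrative assumption) checking the log-normal identity E[exp(sigma * eps)] = exp(sigma^2 / 2), which is the same mixture integral that underlies LogNormalNegativeBinomial's mean:

    import torch
    from pyro.ops.special import get_quad_rule

    sigma = 0.7  # assumed noise scale, for illustration only
    quad_points, log_weights = get_quad_rule(32, torch.zeros(1))
    # E[exp(sigma * eps)] for eps ~ N(0, 1), evaluated in log-space for stability
    log_mgf = torch.logsumexp(sigma * quad_points + log_weights, dim=0)
    assert torch.allclose(log_mgf, torch.tensor(0.5 * sigma ** 2), atol=1.0e-5)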
18 changes: 18 additions & 0 deletions tests/distributions/conftest.py
@@ -1004,6 +1004,24 @@ def __init__(self, von_loc, von_conc, skewness):
        prec=0.08,
        is_discrete=True,
    ),
    Fixture(
        pyro_dist=dist.LogNormalNegativeBinomial,
        examples=[
            {
                "logits": [0.6],
                "total_count": 8,
                "multiplicative_noise_scale": [0.1],
                "test_data": [4.0],
            },
            {
                "logits": [0.2, 0.4],
                "multiplicative_noise_scale": [0.1, 0.2],
                "total_count": [[8.0, 7.0], [5.0, 9.0]],
                "test_data": [[6.0, 3.0], [2.0, 8.0]],
            },
        ],
        is_discrete=True,
    ),
]


9 changes: 8 additions & 1 deletion tests/distributions/test_distributions.py
@@ -284,10 +284,17 @@ def test_distribution_validate_args(dist_class, args, validate_args):
def check_sample_shapes(small, large):
    dist_instance = small
    if isinstance(
        dist_instance, (dist.LogNormal, dist.LowRankMultivariateNormal, dist.VonMises)
        dist_instance,
        (
            dist.LogNormal,
            dist.LowRankMultivariateNormal,
            dist.VonMises,
            dist.LogNormalNegativeBinomial,
        ),
    ):
        # Ignore broadcasting bug in LogNormal:
        # https://github.com/pytorch/pytorch/pull/7269
        # LogNormalNegativeBinomial has no sample method
        return
    x = small.sample()
    assert_equal(small.log_prob(x).expand(large.batch_shape), large.log_prob(x))
47 changes: 47 additions & 0 deletions tests/distributions/test_log_normal_negative_binomial.py
@@ -0,0 +1,47 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pyro.distributions import LogNormalNegativeBinomial
from tests.common import assert_close


@pytest.mark.parametrize("num_quad_points", [2, 4])
@pytest.mark.parametrize("shape", [(2,), (4, 3)])
def test_lnnb_shapes(num_quad_points, shape):
    logits = torch.randn(shape)
    total_count = 5.0
    multiplicative_noise_scale = torch.rand(shape)

    d = LogNormalNegativeBinomial(
        total_count, logits, multiplicative_noise_scale, num_quad_points=num_quad_points
    )

    assert d.batch_shape == shape
    assert d.log_prob(torch.ones(shape)).shape == shape

    assert d.expand(shape + shape).batch_shape == shape + shape
    assert d.expand(shape + shape).log_prob(torch.ones(shape)).shape == shape + shape


@pytest.mark.parametrize("total_count", [0.5, 4.0])
@pytest.mark.parametrize("multiplicative_noise_scale", [0.01, 0.25])
def test_lnnb_mean_variance(
    total_count, multiplicative_noise_scale, num_quad_points=128, N=512
):
    logits = torch.tensor(2.0)
    d = LogNormalNegativeBinomial(
        total_count, logits, multiplicative_noise_scale, num_quad_points=num_quad_points
    )

    values = torch.arange(N)
    probs = d.log_prob(values).exp()
    assert_close(1.0, probs.sum().item(), atol=1.0e-6)

    expected_mean = (probs * values).sum()
    assert_close(expected_mean, d.mean, atol=1.0e-6, rtol=1.0e-5)

    expected_var = (probs * (values - d.mean).pow(2.0)).sum()
    assert_close(expected_var, d.variance, atol=1.0e-6, rtol=1.0e-5)
10 changes: 9 additions & 1 deletion tests/ops/test_special.py
@@ -7,7 +7,7 @@
from torch import tensor
from torch.autograd import grad

from pyro.ops.special import log_beta, log_binomial, log_I1, safe_log
from pyro.ops.special import get_quad_rule, log_beta, log_binomial, log_I1, safe_log
from tests.common import assert_equal


@@ -93,3 +93,11 @@ def test_log_I1_shapes():
    assert_equal(log_I1(10, tensor([[0.6]])).shape, torch.Size([11, 1, 1]))
    assert_equal(log_I1(10, tensor([0.6, 0.2])).shape, torch.Size([11, 2]))
    assert_equal(log_I1(0, tensor(0.6)).shape, torch.Size((1, 1)))


@pytest.mark.parametrize("sigma", [0.5, 1.25])
def test_get_quad_rule(sigma):
    quad_points, log_weights = get_quad_rule(32, torch.zeros(1))
    quad_points *= sigma  # transform to N(0, sigma) gaussian
    variance = torch.logsumexp(quad_points.pow(2.0).log() + log_weights, axis=0).exp()
    assert_equal(sigma ** 2, variance.item())
