Implement MVN distribution (#186)
fehiepsi authored and neerajprad committed Jun 4, 2019
1 parent 99535dd commit 4d86b46
Showing 9 changed files with 259 additions and 48 deletions.
4 changes: 3 additions & 1 deletion numpyro/distributions/__init__.py
@@ -10,6 +10,7 @@
HalfNormal,
LKJCholesky,
LogNormal,
MultivariateNormal,
Normal,
Pareto,
StudentT,
@@ -54,11 +55,12 @@
'GaussianRandomWalk',
'HalfCauchy',
'HalfNormal',
'LogNormal',
'LKJCholesky',
'LogNormal',
'Multinomial',
'MultinomialLogits',
'MultinomialProbs',
'MultivariateNormal',
'Normal',
'Pareto',
'Poisson',
68 changes: 59 additions & 9 deletions numpyro/distributions/constraints.py
@@ -22,10 +22,19 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import math

import jax.numpy as np
from jax.scipy.special import expit, logit

from numpyro.distributions.util import cumprod, cumsum, matrix_to_tril_vec, signed_stick_breaking_tril, sum_rightmost
from numpyro.distributions.util import (
cumprod,
cumsum,
matrix_to_tril_vec,
signed_stick_breaking_tril,
sum_rightmost,
vec_to_tril_matrix
)

##########################################################
# CONSTRAINTS
@@ -38,8 +47,8 @@ def __call__(self, x):


class _Boolean(Constraint):
def __call__(self, value):
return (value == 0) | (value == 1)
def __call__(self, x):
return (x == 0) | (x == 1)


class _CorrCholesky(Constraint):
@@ -82,25 +91,39 @@ class _IntegerGreaterThan(Constraint):
def __init__(self, lower_bound):
self.lower_bound = lower_bound

def __call__(self, value):
return (value % 1 == 0) & (value >= self.lower_bound)
def __call__(self, x):
return (x % 1 == 0) & (x >= self.lower_bound)


class _Interval(Constraint):
def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound

def __call__(self, value):
return (value > self.lower_bound) & (value < self.upper_bound)
def __call__(self, x):
return (x > self.lower_bound) & (x < self.upper_bound)


class _LowerCholesky(Constraint):
def __call__(self, x):
tril = np.tril(x)
lower_triangular = np.all(np.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
positive_diagonal = np.all(np.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
return lower_triangular & positive_diagonal


class _Multinomial(Constraint):
def __init__(self, upper_bound):
self.upper_bound = upper_bound

def __call__(self, value):
return np.all(value >= 0, axis=-1) & (np.sum(value, -1) == self.upper_bound)
def __call__(self, x):
return np.all(x >= 0, axis=-1) & (np.sum(x, -1) == self.upper_bound)


class _PositiveDefinite(Constraint):
def __call__(self, x):
# check that the smallest eigenvalue is positive
return np.linalg.eigh(x)[0][..., 0] > 0


class _Real(Constraint):
Expand All @@ -123,10 +146,12 @@ def __call__(self, x):
integer_interval = _IntegerInterval
integer_greater_than = _IntegerGreaterThan
interval = _Interval
lower_cholesky = _LowerCholesky()
multinomial = _Multinomial
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
positive = _GreaterThan(0.)
positive_definite = _PositiveDefinite()
real = _Real()
simplex = _Simplex()
unit_interval = _Interval(0., 1.)
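
For orientation, a minimal sketch of how the new lower_cholesky and positive_definite constraints behave, assuming only jax.numpy and the constraints module from this package:

import jax.numpy as np
from numpyro.distributions import constraints

# A lower-triangular matrix with a positive diagonal passes lower_cholesky.
L = np.array([[1.0, 0.0], [0.5, 2.0]])
print(constraints.lower_cholesky(L))                   # True
print(constraints.positive_definite(np.dot(L, L.T)))   # True: L @ L.T is SPD

# An indefinite symmetric matrix (eigenvalues 3 and -1) fails positive_definite.
A = np.array([[1.0, 2.0], [2.0, 1.0]])
print(constraints.positive_definite(A))                # False
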
@@ -336,6 +361,26 @@ def log_abs_det_jacobian(self, x, y):
return np.full(np.shape(x), 0.)


class LowerCholeskyTransform(Transform):
codomain = lower_cholesky
event_dim = 1

def __call__(self, x):
n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
diag = np.exp(x[..., -n:])
return z + np.expand_dims(diag, axis=-1) * np.identity(n)

def inv(self, y):
z = matrix_to_tril_vec(y, diagonal=-1)
return np.concatenate([z, np.log(np.diagonal(y, axis1=-2, axis2=-1))], axis=-1)

def log_abs_det_jacobian(self, x, y):
# the Jacobian is diagonal, so logdet is the sum of the log-derivatives of the diagonal `exp` transforms
n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
return x[..., -n:].sum(-1)
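# (Illustrative note.) The unconstrained input has m = n * (n + 1) / 2 entries:
# n * (n - 1) / 2 strictly-lower entries followed by n log-diagonals, hence
# n = (sqrt(1 + 8 * m) - 1) / 2 above.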


class SigmoidTransform(Transform):
codomain = unit_interval

@@ -433,6 +478,11 @@ def _transform_to_interval(constraint):
domain=unit_interval)])


@biject_to.register(lower_cholesky)
def _transform_to_lower_cholesky(constraint):
return LowerCholeskyTransform()


@biject_to.register(real)
def _transform_to_real(constraint):
return IdentityTransform()
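
A minimal round-trip sketch of LowerCholeskyTransform through the biject_to registry defined in this module, assuming an event size of n = 3 so the unconstrained vector has n * (n + 1) / 2 = 6 entries:

import jax.numpy as np
from numpyro.distributions import constraints
from numpyro.distributions.constraints import biject_to

transform = biject_to(constraints.lower_cholesky)   # LowerCholeskyTransform()

# 3 strictly-lower entries followed by 3 log-diagonal entries.
x = np.array([0.1, -0.2, 0.3, 0.0, 0.5, -1.0])
L = transform(x)            # (3, 3) lower triangular with positive diagonal
assert constraints.lower_cholesky(L)
x_back = transform.inv(L)   # recovers the unconstrained vector
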
141 changes: 124 additions & 17 deletions numpyro/distributions/continuous.py
@@ -26,13 +26,16 @@
import jax.numpy as np
import jax.random as random
from jax import lax, ops
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import gammaln, log_ndtr, ndtr, ndtri

from numpyro.distributions import constraints
from numpyro.distributions.constraints import AbsTransform, AffineTransform, ExpTransform
from numpyro.distributions.distribution import Distribution, TransformedDistribution
from numpyro.distributions.util import (
cholesky_inverse,
cumsum,
lazy_property,
matrix_to_tril_vec,
multigammaln,
promote_shapes,
@@ -462,52 +465,156 @@ def log_prob(self, value):


@copy_docs_from(Distribution)
class Normal(Distribution):
class LogNormal(TransformedDistribution):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
reparametrized_params = ['loc', 'scale']

def __init__(self, loc=0., scale=1., validate_args=None):
self.loc, self.scale = promote_shapes(loc, scale)
batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
super(Normal, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
base_dist = Normal(loc, scale)
self.loc, self.scale = base_dist.loc, base_dist.scale
super(LogNormal, self).__init__(base_dist, ExpTransform(), validate_args=validate_args)

@property
def mean(self):
return np.exp(self.loc + self.scale ** 2 / 2)

@property
def variance(self):
return (np.exp(self.scale ** 2) - 1) * np.exp(2 * self.loc + self.scale ** 2)


def _batch_mahalanobis(bL, bx):
# NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
# because we don't want to broadcast bL to the shape (i, j, n, n).

# Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
# we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve
sample_ndim = bx.ndim - bL.ndim + 1 # size of sample_shape
out_shape = np.shape(bx)[:-1] # shape of output
# Reshape bx with the shape (..., 1, i, j, 1, n)
bx_new_shape = out_shape[:sample_ndim]
for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
bx_new_shape += (sx // sL, sL)
bx_new_shape += (-1,)
bx = np.reshape(bx, bx_new_shape)
# Permute bx to make it have shape (..., 1, j, i, 1, n)
permute_dims = (tuple(range(sample_ndim))
+ tuple(range(sample_ndim, bx.ndim - 1, 2))
+ tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
+ (bx.ndim - 1,))
bx = np.transpose(bx, permute_dims)

# reshape to (-1, i, 1, n)
xt = np.reshape(bx, (-1,) + bL.shape[:-1])
# permute to (i, 1, n, -1)
xt = np.moveaxis(xt, 0, -1)
solve_bL_bx = solve_triangular(bL, xt, lower=True) # shape: (i, 1, n, -1)
M = np.sum(solve_bL_bx ** 2, axis=-2) # shape: (i, 1, -1)
# permute back to (-1, i, 1)
M = np.moveaxis(M, -1, 0)
# reshape back to (..., 1, j, i, 1)
M = np.reshape(M, bx.shape[:-1])
# permute back to (..., 1, i, j, 1)
permute_inv_dims = tuple(range(sample_ndim))
for i in range(bL.ndim - 2):
permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
M = np.transpose(M, permute_inv_dims)
return np.reshape(M, out_shape)
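# (Illustrative note.) For a single unbatched pair, the quantity returned is the
# squared Mahalanobis distance M = x^T (L L^T)^{-1} x = ||solve(L, x)||^2, e.g.
# L = [[1., 0.], [0.5, 2.]], x = [1., 1.] gives solve(L, x) = [1., 0.25] and
# M = 1.0 ** 2 + 0.25 ** 2 = 1.0625.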


@copy_docs_from(Distribution)
class MultivariateNormal(Distribution):
arg_constraints = {'loc': constraints.real,
'covariance_matrix': constraints.positive_definite,
'precision_matrix': constraints.positive_definite,
'scale_tril': constraints.lower_cholesky}
support = constraints.real
reparametrized_params = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']

def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
validate_args=None):
if np.isscalar(loc):
loc = np.expand_dims(loc, axis=-1)
# temporarily append a new axis to loc
loc = loc[..., np.newaxis]
if covariance_matrix is not None:
loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
self.scale_tril = np.linalg.cholesky(self.covariance_matrix)
elif precision_matrix is not None:
loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
self.scale_tril = cholesky_inverse(self.precision_matrix)
elif scale_tril is not None:
loc, self.scale_tril = promote_shapes(loc, scale_tril)
else:
raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
' must be specified.')
batch_shape = lax.broadcast_shapes(np.shape(loc)[:-2], np.shape(self.scale_tril)[:-2])
event_shape = np.shape(self.scale_tril)[-1:]
self.loc = np.broadcast_to(np.squeeze(loc, axis=-1), batch_shape + event_shape)
super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args)

def sample(self, key, sample_shape=()):
eps = random.normal(key, shape=sample_shape + self.batch_shape)
return self.loc + eps * self.scale
eps = random.normal(key, shape=sample_shape + self.batch_shape + self.event_shape)
return self.loc + np.squeeze(np.matmul(self.scale_tril, eps[..., np.newaxis]), axis=-1)
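# (Illustrative note.) Since eps ~ N(0, I), x = loc + scale_tril @ eps has
# Cov[x] = scale_tril @ scale_tril.T, i.e. the usual reparameterized sampler
# for a multivariate normal.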

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
normalize_term = np.log(self.scale) + np.log(np.sqrt(2 * np.pi))
return -((value - self.loc) ** 2) / (2.0 * self.scale ** 2) - normalize_term
M = _batch_mahalanobis(self.scale_tril, value - self.loc)
half_log_det = np.log(np.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1)
normalize_term = half_log_det + 0.5 * self.scale_tril.shape[-1] * np.log(2 * np.pi)
return - 0.5 * M - normalize_term
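# (Illustrative note.) With Sigma = scale_tril @ scale_tril.T this matches
# log N(x; loc, Sigma) = -0.5 * M - 0.5 * log|Sigma| - 0.5 * n * log(2 * pi),
# since log|Sigma| = 2 * sum(log(diag(scale_tril))) = 2 * half_log_det.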

@lazy_property
def covariance_matrix(self):
return np.dot(self.scale_tril, self.scale_tril.T)

@lazy_property
def precision_matrix(self):
scale_tril_inv = np.linalg.inv(self.scale_tril)
return np.dot(scale_tril_inv.T, scale_tril_inv)

@property
def mean(self):
return np.broadcast_to(self.loc, self.batch_shape)
return self.loc

@property
def variance(self):
return np.broadcast_to(self.scale ** 2, self.batch_shape)
return np.broadcast_to(np.sum(self.scale_tril ** 2, axis=-1),
self.batch_shape + self.event_shape)


@copy_docs_from(Distribution)
class LogNormal(TransformedDistribution):
class Normal(Distribution):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
reparametrized_params = ['loc', 'scale']

def __init__(self, loc=0., scale=1., validate_args=None):
base_dist = Normal(loc, scale)
self.loc, self.scale = base_dist.loc, base_dist.scale
super(LogNormal, self).__init__(base_dist, ExpTransform(), validate_args=validate_args)
self.loc, self.scale = promote_shapes(loc, scale)
batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
super(Normal, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
eps = random.normal(key, shape=sample_shape + self.batch_shape)
return self.loc + eps * self.scale

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
normalize_term = np.log(np.sqrt(2 * np.pi) * self.scale)
value_scaled = (value - self.loc) / self.scale
return -0.5 * value_scaled ** 2 - normalize_term

@property
def mean(self):
return np.exp(self.loc + self.scale ** 2 / 2)
return np.broadcast_to(self.loc, self.batch_shape)

@property
def variance(self):
return (np.exp(self.scale ** 2) - 1) * np.exp(2 * self.loc + self.scale ** 2)
return np.broadcast_to(self.scale ** 2, self.batch_shape)


@copy_docs_from(Distribution)
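
A short usage sketch of the new MultivariateNormal distribution, constructing it from a covariance matrix and checking the alternative parameterizations against each other; the numeric values are illustrative assumptions:

import jax.numpy as np
import jax.random as random
from numpyro.distributions import MultivariateNormal

cov = np.array([[2.0, 0.3], [0.3, 1.0]])
d = MultivariateNormal(loc=np.zeros(2), covariance_matrix=cov)

key = random.PRNGKey(0)
x = d.sample(key, sample_shape=(5,))   # shape (5, 2)
lp = d.log_prob(x)                     # shape (5,)

# scale_tril and precision_matrix are consistent with the covariance.
assert np.allclose(np.dot(d.scale_tril, d.scale_tril.T), cov, atol=1e-6)
assert np.allclose(d.precision_matrix, np.linalg.inv(cov), atol=1e-5)
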
4 changes: 3 additions & 1 deletion numpyro/distributions/distribution.py
@@ -25,7 +25,7 @@
import jax.numpy as np

from numpyro.distributions.constraints import Transform, is_dependent
from numpyro.distributions.util import sum_rightmost
from numpyro.distributions.util import lazy_property, sum_rightmost


class Distribution(object):
@@ -71,6 +71,8 @@ def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
self._validate_args = validate_args
if self._validate_args:
for param, constraint in self.arg_constraints.items():
if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):
continue
if is_dependent(constraint):
continue # skip constraints that cannot be checked
if not np.all(constraint(getattr(self, param))):
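
A hedged sketch of what this guard enables: with validate_args=True, parameters that exist only as lazy_property attributes (for example covariance_matrix when a MultivariateNormal is built from scale_tril alone) are skipped instead of being forced during constraint validation:

import jax.numpy as np
from numpyro.distributions import MultivariateNormal

L = np.array([[1.0, 0.0], [0.5, 2.0]])
# Only scale_tril is provided; covariance_matrix and precision_matrix remain
# lazy_property attributes, so constraint validation does not compute them.
d = MultivariateNormal(loc=np.zeros(2), scale_tril=L, validate_args=True)
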
8 changes: 8 additions & 0 deletions numpyro/distributions/util.py
@@ -386,6 +386,14 @@ def xlog1py(x, y):
batching.primitive_batchers[xlog1py.primitive] = _xlog1py_batching_rule


def cholesky_inverse(matrix):
# This formulation only takes the inverse of a triangular matrix
# which is more numerically stable.
# Refer to:
# https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
return np.swapaxes(np.linalg.inv(np.linalg.cholesky(matrix[..., ::-1, ::-1])[..., ::-1, ::-1]), -2, -1)


def entr(p):
return np.where(p < 0, -np.inf, -xlogy(p))

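A small sketch of what cholesky_inverse computes: given a precision matrix P, it returns a lower-triangular L with L @ L.T equal to inv(P), i.e. the scale_tril of the corresponding covariance; the example matrix is an illustrative assumption:

import jax.numpy as np
from numpyro.distributions.util import cholesky_inverse

P = np.array([[2.0, 0.3], [0.3, 1.0]])   # a precision matrix
L = cholesky_inverse(P)                  # lower triangular
assert np.allclose(np.dot(L, L.T), np.linalg.inv(P), atol=1e-6)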
