Porting TruncatedNormal, TruncatedCauchy to new api (#150)
fehiepsi authored and neerajprad committed May 13, 2019
1 parent 6005d54 · commit 4781c5d
Showing 3 changed files with 111 additions and 3 deletions.
4 changes: 4 additions & 0 deletions numpyro/distributions/__init__.py
@@ -13,6 +13,8 @@
Normal,
Pareto,
StudentT,
TruncatedCauchy,
TruncatedNormal,
Uniform
)
from numpyro.distributions.discrete import (
@@ -62,5 +64,7 @@
'Poisson',
'StudentT',
'TransformedDistribution',
'TruncatedCauchy',
'TruncatedNormal',
'Uniform',
]
92 changes: 91 additions & 1 deletion numpyro/distributions/continuous.py
@@ -26,7 +26,7 @@
import jax.numpy as np
import jax.random as random
from jax import lax, ops
from jax.scipy.special import gammaln
from jax.scipy.special import gammaln, log_ndtr, ndtr, ndtri

from numpyro.distributions import constraints
from numpyro.distributions.constraints import AbsTransform, AffineTransform, ExpTransform
@@ -568,6 +568,96 @@ def variance(self):
return np.broadcast_to(var, self.batch_shape)


class TruncatedCauchy(Distribution):
arg_constraints = {'low': constraints.real, 'loc': constraints.real,
'scale': constraints.positive}
reparametrized_params = ['low', 'loc', 'scale']

def __init__(self, low=0., loc=0., scale=1., validate_args=None):
self.low, self.loc, self.scale = promote_shapes(low, loc, scale)
batch_shape = lax.broadcast_shapes(np.shape(low), np.shape(loc), np.shape(scale))
super(TruncatedCauchy, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, size=()):
# We use inverse transform method:
# z ~ inv_cdf(U), where U ~ Uniform(cdf(low), cdf(high)).
# ~ Uniform(arctan(low), arctan(high)) / pi + 1/2
size = size + self.batch_shape
low = (self.low - self.loc) / self.scale
minval = np.arctan(low)
maxval = np.pi / 2
u = minval + random.uniform(key, shape=size) * (maxval - minval)
return self.loc + np.tan(u) * self.scale

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
low = (self.low - self.loc) / self.scale
# pi / 2 is arctan of self.high when that arg is supported
normalize_term = np.log(np.pi / 2 - np.arctan(low)) + np.log(self.scale)
return - np.log1p(((value - self.loc) / self.scale) ** 2) - normalize_term

# NB: these stats do not apply when arg `high` is supported
@property
def mean(self):
return np.full(self.batch_shape, np.nan)

@property
def variance(self):
return np.full(self.batch_shape, np.nan)

@property
def support(self):
return constraints.greater_than(self.low)


class TruncatedNormal(Distribution):
arg_constraints = {'low': constraints.real, 'loc': constraints.real,
'scale': constraints.positive}
reparametrized_params = ['low', 'loc', 'scale']

# TODO: support `high` arg
def __init__(self, low=0., loc=0., scale=1., validate_args=None):
self.low, self.loc, self.scale = promote_shapes(low, loc, scale)
batch_shape = lax.broadcast_shapes(np.shape(low), np.shape(loc), np.shape(scale))
self._normal = Normal(self.loc, self.scale)
super(TruncatedNormal, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, size=()):
size = size + self.batch_shape
# We use inverse transform method:
# z ~ icdf(U), where U ~ Uniform(0, 1).
u = random.uniform(key, shape=size)
low = (self.low - self.loc) / self.scale
# Ref: https://en.wikipedia.org/wiki/Truncated_normal_distribution#Simulating
# icdf[cdf_a + u * (1 - cdf_a)] = icdf[1 - (1 - cdf_a)(1 - u)]
# = - icdf[(1 - cdf_a)(1 - u)]
return self.loc - ndtri(ndtr(-low) * (1 - u)) * self.scale

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
# log(cdf(high) - cdf(low)) = log(1 - cdf(low)) = log(cdf(-low))
low = (self.low - self.loc) / self.scale
return self._normal.log_prob(value) - log_ndtr(-low)

@property
def mean(self):
low = (self.low - self.loc) / self.scale
low_prob_scaled = np.exp(self._normal.log_prob(self.low)) * self.scale / ndtr(-low)
return self.loc + low_prob_scaled * self.scale

@property
def variance(self):
low = (self.low - self.loc) / self.scale
low_prob_scaled = np.exp(self._normal.log_prob(self.low)) * self.scale / ndtr(-low)
return self._normal.variance * (1 + low * low_prob_scaled - low_prob_scaled ** 2)

@property
def support(self):
return constraints.greater_than(self.low)


class Uniform(Distribution):
arg_constraints = {'low': constraints.dependent, 'high': constraints.dependent}
reparametrized_params = ['low', 'high']
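
Both new samplers rely on the inverse-CDF identities noted in the comments above. The following is a minimal standalone sketch (plain NumPy/SciPy, independent of numpyro) of the same two formulas, checked against closed-form quantities; the helper names, tolerances, and sample counts are illustrative and not part of the commit.

import numpy as np
from scipy.special import ndtr, ndtri
from scipy.stats import truncnorm

def sample_truncated_normal(rng, low, loc, scale, n):
    # icdf[cdf(a) + u * (1 - cdf(a))] = -icdf[(1 - cdf(a)) * (1 - u)], with a = (low - loc) / scale
    a = (low - loc) / scale
    u = rng.uniform(size=n)
    return loc - ndtri(ndtr(-a) * (1 - u)) * scale

def sample_truncated_cauchy(rng, low, loc, scale, n):
    # draw an angle uniformly on (arctan(a), pi / 2), then map through tan
    a = (low - loc) / scale
    angle = rng.uniform(np.arctan(a), np.pi / 2, size=n)
    return loc + np.tan(angle) * scale

rng = np.random.default_rng(0)
low, loc, scale, n = -1.0, 0.5, 2.0, 200_000
a = (low - loc) / scale

x = sample_truncated_normal(rng, low, loc, scale, n)
assert abs(x.mean() - truncnorm(a, np.inf, loc, scale).mean()) < 0.03

y = sample_truncated_cauchy(rng, low, loc, scale, n)
# median of a Cauchy truncated to (low, inf): loc + scale * tan(arctan(a) / 2 + pi / 4)
assert abs(np.median(y) - (loc + scale * np.tan(np.arctan(a) / 2 + np.pi / 4))) < 0.03
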
18 changes: 16 additions & 2 deletions test/test_distributions.py
@@ -100,6 +100,12 @@ def __new__(cls, jax_dist, *params):
T(dist.StudentT, 1., 1., 0.5),
T(dist.StudentT, 2., np.array([1., 2.]), 2.),
T(dist.StudentT, np.array([3, 5]), np.array([[1.], [2.]]), 2.),
T(dist.TruncatedCauchy, -1., 0., 1.),
T(dist.TruncatedCauchy, 1., 0., np.array([1., 2.])),
T(dist.TruncatedCauchy, np.array([-2., 2.]), np.array([0., 1.]), np.array([[1.], [2.]])),
T(dist.TruncatedNormal, -1., 0., 1.),
T(dist.TruncatedNormal, 1., -1., np.array([1., 2.])),
T(dist.TruncatedNormal, np.array([-2., 2.]), np.array([0., 1.]), np.array([[1.], [2.]])),
T(dist.Uniform, 0., 2.),
T(dist.Uniform, 1., np.array([2., 3.])),
T(dist.Uniform, np.array([0., 0.]), np.array([[2.], [3.]])),
@@ -231,7 +237,7 @@ def fn(args):
actual_grad = jax.grad(fn)(repara_params)
assert len(actual_grad) == len(repara_params)

eps = 1e-5
eps = 1e-3
for i in range(len(repara_params)):
args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
@@ -240,7 +246,7 @@ def fn(args):
# finite diff approximation
expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
assert np.shape(actual_grad[i]) == np.shape(repara_params[i])
assert_allclose(np.sum(actual_grad[i]), expected_grad, rtol=0.10)
assert_allclose(np.sum(actual_grad[i]), expected_grad, rtol=0.02)


@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
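
The looser eps together with the tighter rtol above govern a central finite-difference check of reparameterized gradients. The sketch below illustrates the same pattern outside the test harness, differentiating through the truncated-normal inverse-transform sampler with jax.grad and comparing against a finite difference taken with the same uniform draws; the function name mean_of_samples and the constants are illustrative, not part of the test suite.

import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri

def mean_of_samples(loc, key, low=-1.0, scale=1.0, n=10_000):
    # the draw is a differentiable (reparameterized) function of loc
    u = jax.random.uniform(key, (n,))
    a = (low - loc) / scale
    return jnp.mean(loc - ndtri(ndtr(-a) * (1 - u)) * scale)

key = jax.random.PRNGKey(0)
loc, eps = 0.3, 1e-3
autodiff_grad = jax.grad(mean_of_samples)(loc, key)
# central finite difference with common random numbers
numeric_grad = (mean_of_samples(loc + eps, key) - mean_of_samples(loc - eps, key)) / (2 * eps)
print(autodiff_grad, numeric_grad)  # should agree to a few decimal places
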
@@ -257,6 +263,14 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
samples = jax_dist.sample(key=rng, size=prepend_shape)
assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
if not sp_dist:
if isinstance(jax_dist, dist.TruncatedCauchy) or isinstance(jax_dist, dist.TruncatedNormal):
low, loc, scale = params
high = np.inf
sp_dist = osp.cauchy if isinstance(jax_dist, dist.TruncatedCauchy) else osp.norm
sp_dist = sp_dist(loc, scale)
expected = sp_dist.logpdf(samples) - np.log(sp_dist.cdf(high) - sp_dist.cdf(low))
assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
return
pytest.skip('no corresponding scipy distn.')
if _is_batched_multivariate(jax_dist):
pytest.skip('batching not allowed in multivariate distns.')
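
The new branch in test_log_prob builds the reference value as the base log-density minus the log of the retained probability mass, i.e. log p(x) = logpdf(x) - log(cdf(high) - cdf(low)), with high = inf here. Below is a small standalone sketch of that computation (SciPy only; the helper name truncated_logpdf is illustrative), cross-checked against scipy's own truncnorm.

import numpy as np
from scipy import stats as osp

def truncated_logpdf(base, x, low, high=np.inf):
    # subtract the log-probability mass retained on [low, high]
    return base.logpdf(x) - np.log(base.cdf(high) - base.cdf(low))

low, loc, scale = -1.0, 0.0, 1.0
x = np.linspace(low + 0.1, 4.0, 9)
expected = osp.truncnorm((low - loc) / scale, np.inf, loc, scale).logpdf(x)
np.testing.assert_allclose(truncated_logpdf(osp.norm(loc, scale), x, low), expected, atol=1e-6)
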
