Respect unnormalized logits input in Categorical/Multinomial dist (#145)
fehiepsi authored and neerajprad committed May 12, 2019
1 parent dde3ebf commit 6fc0b1f
Showing 3 changed files with 10 additions and 8 deletions.
6 changes: 4 additions & 2 deletions numpyro/contrib/distributions/multivariate.py
@@ -10,7 +10,7 @@
 from jax import lax
 from jax.experimental.stax import softmax
 from jax.numpy.lax_numpy import _promote_dtypes
-from jax.scipy.special import digamma, gammaln
+from jax.scipy.special import digamma, gammaln, logsumexp

 from numpyro.contrib.distributions.discrete import binom
 from numpyro.contrib.distributions.distribution import jax_continuous, jax_discrete, jax_multivariate
@@ -54,6 +54,8 @@ def logpmf(self, x, p):
         x = np.broadcast_to(x, batch_shape + (1,))
         p = np.broadcast_to(p, batch_shape + p.shape[-1:])
         if self.is_logits:
+            # normalize log prob
+            p = p - logsumexp(p, axis=-1, keepdims=True)
             # gather and remove the trailing dimension
             return np.take_along_axis(p, x, axis=-1)[..., 0]
         else:
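The two added lines above apply a log-softmax so that unnormalized logits become proper log probabilities before indexing. A minimal sketch of the idea (example values are assumed, not taken from the commit):

```python
import jax.numpy as np
from jax.scipy.special import logsumexp

logits = np.array([2.0, 0.5, -1.0])                       # unnormalized logits
log_probs = logits - logsumexp(logits, axis=-1, keepdims=True)
print(np.exp(log_probs).sum())                            # ~1.0: a proper probability vector
shifted = (logits + 10.) - logsumexp(logits + 10., axis=-1, keepdims=True)
print(np.allclose(log_probs, shifted))                    # True: invariant to constant shifts
```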
@@ -127,7 +129,7 @@ def _event_shape(self, n, p):
     def logpmf(self, x, n, p):
         x, n, p = _promote_dtypes(x, n, p)
         if self.is_logits:
-            return gammaln(n + 1) + np.sum(x * p - gammaln(x + 1), axis=-1)
+            return gammaln(n + 1) + np.sum(x * p - gammaln(x + 1), axis=-1) - n * logsumexp(p, axis=-1)
         else:
             return gammaln(n + 1) + np.sum(xlogy(x, p) - gammaln(x + 1), axis=-1)

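The multinomial correction works because the counts sum to the total count n, so subtracting `n * logsumexp(p, axis=-1)` is equivalent to normalizing each logit before the weighted sum. A quick numerical sketch of that identity (values chosen for illustration only):

```python
import jax.numpy as np
from jax.scipy.special import logsumexp

p = np.array([0.3, 1.7, -0.4])          # unnormalized logits
x = np.array([2., 5., 3.])              # counts; n = x.sum() = 10
n = x.sum()
lhs = np.sum(x * (p - logsumexp(p)))    # weighted sum of normalized log probabilities
rhs = np.sum(x * p) - n * logsumexp(p)  # form used in the new logpmf
print(np.allclose(lhs, rhs))            # True
```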
8 changes: 4 additions & 4 deletions numpyro/distributions/discrete.py
@@ -260,7 +260,6 @@ class CategoricalLogits(Distribution):
     def __init__(self, logits, validate_args=None):
         if np.ndim(logits) < 1:
             raise ValueError("`logits` parameter must be at least one-dimensional.")
-        logits = logits - logsumexp(logits)
         self.logits = logits
         super(CategoricalLogits, self).__init__(batch_shape=np.shape(logits)[:-1],
                                                  validate_args=validate_args)
@@ -272,7 +271,8 @@ def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
         value = np.expand_dims(value, -1)
-        value, log_pmf = promote_shapes(value, self.logits)
+        log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
+        value, log_pmf = promote_shapes(value, log_pmf)
         value = value[..., :1]
         return np.take_along_axis(log_pmf, value, -1)[..., 0]

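With the normalization moved from `__init__` into `log_prob`, `CategoricalLogits` keeps the raw logits it was given and still returns correct log probabilities, so a constant shift of the logits no longer changes the result. A minimal usage sketch (example values assumed, not part of the commit):

```python
import jax.numpy as np
from numpyro.distributions.discrete import CategoricalLogits

logits = np.array([1.0, -0.5, 2.0])           # unnormalized logits
value = np.array(2)
d = CategoricalLogits(logits)
d_shifted = CategoricalLogits(logits + 100.)  # same distribution, logits shifted by a constant
print(np.allclose(d.log_prob(value), d_shifted.log_prob(value)))  # expected: True
```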
@@ -348,7 +348,6 @@ def __init__(self, logits, total_count=1, validate_args=None):
         if np.ndim(logits) < 1:
             raise ValueError("`logits` parameter must be at least one-dimensional.")
         batch_shape = lax.broadcast_shapes(np.shape(logits)[:-1], np.shape(total_count))
-        logits = logits - logsumexp(logits)
         self.logits = promote_shapes(logits, shape=batch_shape + np.shape(logits)[-1:])[0]
         self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
         super(MultinomialLogits, self).__init__(batch_shape=batch_shape,
@@ -364,7 +363,8 @@ def log_prob(self, value):
         dtype = get_dtypes(self.logits)[0]
         value = lax.convert_element_type(value, dtype)
         total_count = lax.convert_element_type(self.total_count, dtype)
-        return gammaln(total_count + 1) + np.sum(value * self.logits - gammaln(value + 1), axis=-1)
+        normalize_term = total_count * logsumexp(self.logits, axis=-1) - gammaln(total_count + 1)
+        return np.sum(value * self.logits - gammaln(value + 1), axis=-1) - normalize_term

     @lazy_property
     def probs(self):
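The same shift invariance now holds for `MultinomialLogits`, since `normalize_term` absorbs any constant added to the logits. A short sketch with assumed example values (not from the commit):

```python
import jax.numpy as np
from numpyro.distributions.discrete import MultinomialLogits

logits = np.array([1.0, -0.5, 2.0])        # unnormalized logits
counts = np.array([1., 3., 2.])            # observed counts, total_count = 6
d = MultinomialLogits(logits, total_count=6)
d_shifted = MultinomialLogits(logits + 100., total_count=6)
print(np.allclose(d.log_prob(counts), d_shifted.log_prob(counts)))  # expected: True
```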
4 changes: 2 additions & 2 deletions test/test_distributions.py
@@ -362,7 +362,7 @@ def fn(args, value):
     actual_grad = jax.grad(fn)(params, value)
     assert len(actual_grad) == len(params)

-    eps = 1e-4
+    eps = 1e-3
     for i in range(len(params)):
         if np.result_type(params[i]) in (np.int32, np.int64):
             continue
@@ -373,7 +373,7 @@ def fn(args, value):
         # finite diff approximation
         expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
         assert np.shape(actual_grad[i]) == np.shape(params[i])
-        assert_allclose(np.sum(actual_grad[i]), expected_grad, rtol=0.10, atol=1e-3)
+        assert_allclose(np.sum(actual_grad[i]), expected_grad, rtol=0.01, atol=1e-3)


 @pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
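The test compares JAX autodiff gradients against a central finite-difference estimate; the commit uses a slightly larger step (`eps = 1e-3`) and, in exchange, tightens the relative tolerance to 1%. A standalone sketch of that style of check on a toy function (not the numpyro test harness itself):

```python
import jax
import jax.numpy as np

def fn(x):
    return np.sum(np.sin(x) * x)

x = np.array([0.3, -1.2, 2.0])
eps = 1e-3
actual_grad = jax.grad(fn)(x)
for i in range(x.shape[0]):
    delta = eps * (np.arange(x.shape[0]) == i)               # perturb one coordinate
    expected = (fn(x + delta) - fn(x - delta)) / (2. * eps)  # central finite difference
    assert np.allclose(actual_grad[i], expected, rtol=0.01, atol=1e-3)
```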
