Commit

Unifying wrappers for distributions with logits/probs parametrization (#141)

* Unifying wrappers for distributions with logits and probs parametrization

* add multinomial to init

* fix categorical wrapper
neerajprad authored and fehiepsi committed May 9, 2019
1 parent fed0bc6 commit 7da9574
Showing 7 changed files with 106 additions and 62 deletions.
24 changes: 16 additions & 8 deletions numpyro/distributions/__init__.py
@@ -16,25 +16,32 @@
 )
 from numpyro.distributions.discrete import (
     Bernoulli,
-    BernoulliWithLogits,
+    BernoulliLogits,
+    BernoulliProbs,
     Binomial,
-    BinomialWithLogits,
+    BinomialLogits,
+    BinomialProbs,
     Categorical,
-    CategoricalWithLogits,
+    CategoricalLogits,
+    CategoricalProbs,
     Multinomial,
-    MultinomialWithLogits,
+    MultinomialLogits,
+    MultinomialProbs,
     Poisson
 )
 from numpyro.distributions.distribution import Distribution, TransformedDistribution
 
 __all__ = [
     'Bernoulli',
-    'BernoulliWithLogits',
+    'BernoulliLogits',
+    'BernoulliProbs',
     'Beta',
     'Binomial',
-    'BinomialWithLogits',
+    'BinomialLogits',
+    'BinomialProbs',
     'Categorical',
-    'CategoricalWithLogits',
+    'CategoricalLogits',
+    'CategoricalProbs',
     'Cauchy',
     'Chi2',
     'Dirichlet',
@@ -46,7 +53,8 @@
     'LogNormal',
     'LKJCholesky',
     'Multinomial',
-    'MultinomialWithLogits',
+    'MultinomialLogits',
+    'MultinomialProbs',
     'Normal',
     'Pareto',
     'Poisson',
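
For downstream code, the net effect of the re-exports above is that both the unified wrappers (plain factory functions, defined in discrete.py below) and the concrete classes are importable. A minimal sketch, with illustrative values:

import numpyro.distributions as dist

d1 = dist.Bernoulli(probs=0.3)    # dispatches to BernoulliProbs
d2 = dist.Bernoulli(logits=-1.)   # dispatches to BernoulliLogits

# The concrete classes stay available, e.g. for the updated test suite.
from numpyro.distributions import BernoulliLogits, BernoulliProbs
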
80 changes: 58 additions & 22 deletions numpyro/distributions/discrete.py
@@ -64,13 +64,13 @@ def _to_logits_multinom(probs):
     return np.clip(np.log(probs), a_min=minval)
 
 
-class Bernoulli(Distribution):
+class BernoulliProbs(Distribution):
     arg_constraints = {'probs': constraints.unit_interval}
     support = constraints.boolean
 
     def __init__(self, probs, validate_args=None):
         self.probs = probs
-        super(Bernoulli, self).__init__(batch_shape=np.shape(self.probs), validate_args=validate_args)
+        super(BernoulliProbs, self).__init__(batch_shape=np.shape(self.probs), validate_args=validate_args)
 
     def sample(self, key, size=()):
         return random.bernoulli(key, self.probs, shape=size + self.batch_shape)
@@ -89,13 +89,13 @@ def variance(self):
         return self.probs * (1 - self.probs)
 
 
-class BernoulliWithLogits(Distribution):
+class BernoulliLogits(Distribution):
     arg_constraints = {'logits': constraints.real}
     support = constraints.boolean
 
     def __init__(self, logits=None, validate_args=None):
         self.logits = logits
-        super(BernoulliWithLogits, self).__init__(batch_shape=np.shape(self.logits), validate_args=validate_args)
+        super(BernoulliLogits, self).__init__(batch_shape=np.shape(self.logits), validate_args=validate_args)
 
     def sample(self, key, size=()):
         return random.bernoulli(key, self.probs, shape=size + self.batch_shape)
@@ -120,14 +120,23 @@ def variance(self):
         return self.probs * (1 - self.probs)
 
 
-class Binomial(Distribution):
+def Bernoulli(probs=None, logits=None, validate_args=None):
+    if probs is not None:
+        return BernoulliProbs(probs, validate_args=validate_args)
+    elif logits is not None:
+        return BernoulliLogits(logits, validate_args=validate_args)
+    else:
+        raise ValueError('One of `probs` or `logits` must be specified.')
+
+
+class BinomialProbs(Distribution):
     arg_constraints = {'total_count': constraints.nonnegative_integer,
                        'probs': constraints.unit_interval}
 
     def __init__(self, probs, total_count=1, validate_args=None):
         self.probs, self.total_count = promote_shapes(probs, total_count)
         batch_shape = lax.broadcast_shapes(np.shape(probs), np.shape(total_count))
-        super(Binomial, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
+        super(BinomialProbs, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
 
     def sample(self, key, size=()):
         return binomial(key, self.probs, n=self.total_count, shape=size + self.batch_shape)
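
A quick sketch of how the new `Bernoulli` factory above behaves; `log_prob` and `sample` are the usual NumPyro `Distribution` methods, and the two parametrizations agree because sigmoid(logit(p)) = p:

import jax.numpy as np
from jax import random
import numpyro.distributions as dist

p = np.array([0.2, 0.7])
logits = np.log(p) - np.log(1 - p)      # logit(p), written out to avoid extra imports

d_probs = dist.Bernoulli(probs=p)       # dispatches to BernoulliProbs
d_logits = dist.Bernoulli(logits=logits)  # dispatches to BernoulliLogits

x = d_probs.sample(random.PRNGKey(0))
assert np.allclose(d_probs.log_prob(x), d_logits.log_prob(x), atol=1e-6)

# dist.Bernoulli() with neither argument raises the ValueError defined above.
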
@@ -157,14 +166,14 @@ def support(self):
         return constraints.integer_interval(0, self.total_count)
 
 
-class BinomialWithLogits(Distribution):
+class BinomialLogits(Distribution):
     arg_constraints = {'total_count': constraints.nonnegative_integer,
                        'logits': constraints.real}
 
     def __init__(self, logits, total_count=1, validate_args=None):
         self.logits, self.total_count = promote_shapes(logits, total_count)
         batch_shape = lax.broadcast_shapes(np.shape(logits), np.shape(total_count))
-        super(BinomialWithLogits, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
+        super(BinomialLogits, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
 
     def sample(self, key, size=()):
         return binomial(key, self.probs, n=self.total_count, shape=size + self.batch_shape)
@@ -200,15 +209,24 @@ def support(self):
         return constraints.integer_interval(0, self.total_count)
 
 
-class Categorical(Distribution):
+def Binomial(total_count=1, probs=None, logits=None, validate_args=None):
+    if probs is not None:
+        return BinomialProbs(probs, total_count, validate_args=validate_args)
+    elif logits is not None:
+        return BinomialLogits(logits, total_count, validate_args=validate_args)
+    else:
+        raise ValueError('One of `probs` or `logits` must be specified.')
+
+
+class CategoricalProbs(Distribution):
     arg_constraints = {'probs': constraints.simplex}
 
     def __init__(self, probs, validate_args=None):
         if np.ndim(probs) < 1:
             raise ValueError("`probs` parameter must be at least one-dimensional.")
         self.probs = probs
-        super(Categorical, self).__init__(batch_shape=np.shape(self.probs)[:-1],
-                                          validate_args=validate_args)
+        super(CategoricalProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1],
+                                               validate_args=validate_args)
 
     def sample(self, key, size=()):
         return categorical(key, self.probs, shape=size + self.batch_shape)
@@ -236,16 +254,16 @@ def support(self):
         return constraints.integer_interval(0, np.shape(self.probs)[-1])
 
 
-class CategoricalWithLogits(Distribution):
+class CategoricalLogits(Distribution):
     arg_constraints = {'logits': constraints.real}
 
     def __init__(self, logits, validate_args=None):
         if np.ndim(logits) < 1:
             raise ValueError("`logits` parameter must be at least one-dimensional.")
         logits = logits - logsumexp(logits)
         self.logits = logits
-        super(CategoricalWithLogits, self).__init__(batch_shape=np.shape(logits)[:-1],
-                                                    validate_args=validate_args)
+        super(CategoricalLogits, self).__init__(batch_shape=np.shape(logits)[:-1],
+                                                validate_args=validate_args)
 
     def sample(self, key, size=()):
         return categorical(key, self.probs, shape=size + self.batch_shape)
@@ -275,7 +293,16 @@ def support(self):
         return constraints.integer_interval(0, np.shape(self.logits)[-1])
 
 
-class Multinomial(Distribution):
+def Categorical(probs=None, logits=None, validate_args=None):
+    if probs is not None:
+        return CategoricalProbs(probs, validate_args=validate_args)
+    elif logits is not None:
+        return CategoricalLogits(logits, validate_args=validate_args)
+    else:
+        raise ValueError('One of `probs` or `logits` must be specified.')
+
+
+class MultinomialProbs(Distribution):
     arg_constraints = {'total_count': constraints.nonnegative_integer,
                        'probs': constraints.simplex}
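
One detail worth flagging in `CategoricalLogits` above (and in `MultinomialLogits` below): the constructor shifts the logits by `logsumexp(logits)`, so the stored logits are normalized log-probabilities. A small illustrative check, assuming the same `logsumexp` import the module already uses:

import jax.numpy as np
from jax.scipy.special import logsumexp

logits = np.array([1., 2., -2.])
normalized = logits - logsumexp(logits)
assert np.allclose(np.sum(np.exp(normalized)), 1.)   # exp of normalized logits lies on the simplex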

@@ -285,9 +312,9 @@ def __init__(self, probs, total_count=1, validate_args=None):
         batch_shape = lax.broadcast_shapes(np.shape(probs)[:-1], np.shape(total_count))
         self.probs = promote_shapes(probs, shape=batch_shape + np.shape(probs)[-1:])[0]
         self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
-        super(Multinomial, self).__init__(batch_shape=batch_shape,
-                                          event_shape=np.shape(self.probs)[-1:],
-                                          validate_args=validate_args)
+        super(MultinomialProbs, self).__init__(batch_shape=batch_shape,
+                                               event_shape=np.shape(self.probs)[-1:],
+                                               validate_args=validate_args)
 
     def sample(self, key, size=()):
         return multinomial(key, self.probs, self.total_count, shape=size + self.batch_shape)
@@ -313,7 +340,7 @@ def support(self):
         return constraints.multinomial(self.total_count)
 
 
-class MultinomialWithLogits(Distribution):
+class MultinomialLogits(Distribution):
     arg_constraints = {'total_count': constraints.nonnegative_integer,
                        'logits': constraints.real}
 
@@ -324,9 +351,9 @@ def __init__(self, logits, total_count=1, validate_args=None):
         logits = logits - logsumexp(logits)
         self.logits = promote_shapes(logits, shape=batch_shape + np.shape(logits)[-1:])[0]
         self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
-        super(MultinomialWithLogits, self).__init__(batch_shape=batch_shape,
-                                                    event_shape=np.shape(self.logits)[-1:],
-                                                    validate_args=validate_args)
+        super(MultinomialLogits, self).__init__(batch_shape=batch_shape,
+                                                event_shape=np.shape(self.logits)[-1:],
+                                                validate_args=validate_args)
 
     def sample(self, key, size=()):
         return multinomial(key, self.probs, self.total_count, shape=size + self.batch_shape)
@@ -356,6 +383,15 @@ def support(self):
         return constraints.multinomial(self.total_count)
 
 
+def Multinomial(total_count=1, probs=None, logits=None, validate_args=None):
+    if probs is not None:
+        return MultinomialProbs(probs, total_count, validate_args=validate_args)
+    elif logits is not None:
+        return MultinomialLogits(logits, total_count, validate_args=validate_args)
+    else:
+        raise ValueError('One of `probs` or `logits` must be specified.')
+
+
 class Poisson(Distribution):
     arg_constraints = {'rate': constraints.positive}
     support = constraints.nonnegative_integer
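
Note the argument order of the new `Binomial` and `Multinomial` wrappers: `total_count` comes first and `probs`/`logits` are passed by keyword, which is why the call sites in the example and test files below change from `dist.Binomial(phi, at_bats)` to `dist.Binomial(at_bats, probs=phi)`. A brief sketch with made-up values:

import jax.numpy as np
import numpyro.distributions as dist

d1 = dist.Binomial(10, probs=0.3)      # equivalent to BinomialProbs(0.3, 10)
d2 = dist.Binomial(10, logits=-0.85)   # equivalent to BinomialLogits(-0.85, 10)
d3 = dist.Multinomial(10, probs=np.array([0.2, 0.7, 0.1]))   # MultinomialProbs
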
8 changes: 4 additions & 4 deletions numpyro/examples/baseball.py
@@ -75,7 +75,7 @@ def fully_pooled(at_bats, hits=None):
     """
     phi_prior = dist.Uniform(np.array([0.]), np.array([1.]))
     phi = sample("phi", phi_prior)
-    return sample("obs", dist.Binomial(phi, at_bats), obs=hits)
+    return sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
 
 
 def not_pooled(at_bats, hits=None):
@@ -91,7 +91,7 @@ def not_pooled(at_bats, hits=None):
     phi_prior = dist.Uniform(np.zeros((num_players,)),
                              np.ones((num_players,)))
     phi = sample("phi", phi_prior)
-    return sample("obs", dist.Binomial(phi, at_bats), obs=hits)
+    return sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
 
 
 def partially_pooled(at_bats, hits=None):
@@ -113,7 +113,7 @@ def partially_pooled(at_bats, hits=None):
     phi_prior = dist.Beta(np.broadcast_to(m * kappa, shape),
                           np.broadcast_to((1 - m) * kappa, shape))
     phi = sample("phi", phi_prior)
-    return sample("obs", dist.Binomial(phi, at_bats), obs=hits)
+    return sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)
 
 
 def partially_pooled_with_logit(at_bats, hits=None):
@@ -132,7 +132,7 @@ def partially_pooled_with_logit(at_bats, hits=None):
     shape = np.shape(loc)[:np.ndim(loc) - 1] + (num_players,)
     alpha = sample("alpha", dist.Normal(np.broadcast_to(loc, shape),
                                         np.broadcast_to(scale, shape)))
-    return sample("obs", dist.BinomialWithLogits(alpha, at_bats), obs=hits)
+    return sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
 
 
 def run_inference(model, at_bats, hits, rng, args):
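
All four baseball models now go through the single `dist.Binomial` entry point. A condensed sketch of the shared pattern, assuming the `sample` primitive and data shapes used elsewhere in this file:

import jax.numpy as np
import numpyro.distributions as dist
from numpyro.handlers import sample   # assumed import; matches these examples' use of `sample`

def probs_model(at_bats, hits=None):
    # A probability-space prior pairs with probs=...
    phi = sample("phi", dist.Uniform(np.array([0.]), np.array([1.])))
    return sample("obs", dist.Binomial(at_bats, probs=phi), obs=hits)

def logit_model(at_bats, hits=None):
    # An unconstrained prior pairs with logits=...
    alpha = sample("alpha", dist.Normal(0., 1.))
    return sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
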
2 changes: 1 addition & 1 deletion numpyro/examples/covtype.py
@@ -58,7 +58,7 @@ def model(data, labels):
     N, dim = data.shape
     coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
     logits = np.dot(data, coefs)
-    return sample('obs', dist.BernoulliWithLogits(logits), obs=labels)
+    return sample('obs', dist.Bernoulli(logits=logits), obs=labels)
 
 
 def benchmark_hmc(args, features, labels):
2 changes: 1 addition & 1 deletion numpyro/examples/ucbadmit.py
@@ -66,7 +66,7 @@ def glmm(dept, male, applications, admit):
                   axis=-1)
 
     logits = v_mu[..., :1] + v[..., dept, 0] + (v_mu[..., 1:] + v[..., dept, 1]) * male
-    sample('admit', dist.BinomialWithLogits(logits, applications), obs=admit)
+    sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
 
 
 def run_inference(dept, male, applications, admit, rng, args):
44 changes: 22 additions & 22 deletions test/test_distributions.py
@@ -36,11 +36,11 @@ def __new__(cls, jax_dist, *params):
 
 
 _DIST_MAP = {
-    dist.Bernoulli: lambda probs: osp.bernoulli(p=probs),
-    dist.BernoulliWithLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
+    dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
+    dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
     dist.Beta: lambda con1, con0: osp.beta(con1, con0),
-    dist.Binomial: lambda probs, total_count: osp.binom(n=total_count, p=probs),
-    dist.BinomialWithLogits: lambda logits, total_count: osp.binom(n=total_count, p=_to_probs_bernoulli(logits)),
+    dist.BinomialProbs: lambda probs, total_count: osp.binom(n=total_count, p=probs),
+    dist.BinomialLogits: lambda logits, total_count: osp.binom(n=total_count, p=_to_probs_bernoulli(logits)),
     dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
     dist.Chi2: lambda df: osp.chi2(df),
     dist.Dirichlet: lambda conc: osp.dirichlet(conc),
@@ -49,9 +49,9 @@ def __new__(cls, jax_dist, *params):
     dist.HalfCauchy: lambda scale: osp.halfcauchy(scale=scale),
     dist.HalfNormal: lambda scale: osp.halfnorm(scale=scale),
     dist.LogNormal: lambda loc, scale: osp.lognorm(s=scale, scale=np.exp(loc)),
-    dist.Multinomial: lambda probs, total_count: osp.multinomial(n=total_count, p=probs),
-    dist.MultinomialWithLogits: lambda logits, total_count: osp.multinomial(n=total_count,
-                                                                            p=_to_probs_multinom(logits)),
+    dist.MultinomialProbs: lambda probs, total_count: osp.multinomial(n=total_count, p=probs),
+    dist.MultinomialLogits: lambda logits, total_count: osp.multinomial(n=total_count,
+                                                                        p=_to_probs_multinom(logits)),
     dist.Normal: lambda loc, scale: osp.norm(loc=loc, scale=scale),
     dist.Pareto: lambda scale, alpha: osp.pareto(alpha, scale=scale),
     dist.Poisson: lambda rate: osp.poisson(rate),
@@ -105,21 +105,21 @@ def __new__(cls, jax_dist, *params):
 
 
 DISCRETE = [
-    T(dist.Bernoulli, 0.2),
-    T(dist.Bernoulli, np.array([0.2, 0.7])),
-    T(dist.BernoulliWithLogits, np.array([-1., 3.])),
-    T(dist.Binomial, np.array([0.2, 0.7]), np.array([10, 2])),
-    T(dist.Binomial, np.array([0.2, 0.7]), np.array([5, 8])),
-    T(dist.BinomialWithLogits, np.array([-1., 3.]), np.array([5, 8])),
-    T(dist.Categorical, np.array([1.])),
-    T(dist.Categorical, np.array([0.1, 0.5, 0.4])),
-    T(dist.Categorical, np.array([[0.1, 0.5, 0.4], [0.4, 0.4, 0.2]])),
-    T(dist.CategoricalWithLogits, np.array([-5.])),
-    T(dist.CategoricalWithLogits, np.array([1., 2., -2.])),
-    T(dist.CategoricalWithLogits, np.array([[-1, 2., 3.], [3., -4., -2.]])),
-    T(dist.Multinomial, np.array([0.2, 0.7, 0.1]), 10),
-    T(dist.Multinomial, np.array([0.2, 0.7, 0.1]), np.array([5, 8])),
-    T(dist.MultinomialWithLogits, np.array([-1., 3.]), np.array([[5], [8]])),
+    T(dist.BernoulliProbs, 0.2),
+    T(dist.BernoulliProbs, np.array([0.2, 0.7])),
+    T(dist.BernoulliLogits, np.array([-1., 3.])),
+    T(dist.BinomialProbs, np.array([0.2, 0.7]), np.array([10, 2])),
+    T(dist.BinomialProbs, np.array([0.2, 0.7]), np.array([5, 8])),
+    T(dist.BinomialLogits, np.array([-1., 3.]), np.array([5, 8])),
+    T(dist.CategoricalProbs, np.array([1.])),
+    T(dist.CategoricalProbs, np.array([0.1, 0.5, 0.4])),
+    T(dist.CategoricalProbs, np.array([[0.1, 0.5, 0.4], [0.4, 0.4, 0.2]])),
+    T(dist.CategoricalLogits, np.array([-5.])),
+    T(dist.CategoricalLogits, np.array([1., 2., -2.])),
+    T(dist.CategoricalLogits, np.array([[-1, 2., 3.], [3., -4., -2.]])),
+    T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), 10),
+    T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), np.array([5, 8])),
+    T(dist.MultinomialLogits, np.array([-1., 3.]), np.array([[5], [8]])),
     T(dist.Poisson, 2.),
     T(dist.Poisson, np.array([2., 3., 5.])),
 ]
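
For readers new to this harness: `_DIST_MAP` pairs each concrete NumPyro class with a frozen scipy distribution so the parametrized tests can compare log-densities, which is why the map and the `DISCRETE` cases now key on the concrete classes rather than the wrappers. A stripped-down sketch of that check (the tolerance is illustrative):

import numpy as onp
import scipy.stats as osp
from jax import random
import numpyro.distributions as dist

jax_dist = dist.BernoulliProbs(0.2)
sp_dist = osp.bernoulli(p=0.2)      # the frozen scipy counterpart from _DIST_MAP

samples = onp.asarray(jax_dist.sample(random.PRNGKey(0), size=(1000,)))
onp.testing.assert_allclose(onp.asarray(jax_dist.log_prob(samples)),
                            sp_dist.logpmf(samples), rtol=1e-5)
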
8 changes: 4 additions & 4 deletions test/test_mcmc.py
@@ -40,12 +40,12 @@ def test_logistic_regression(algo):
     data = random.normal(random.PRNGKey(0), (N, dim))
     true_coefs = np.arange(1., dim + 1.)
     logits = np.sum(true_coefs * data, axis=-1)
-    labels = dist.BernoulliWithLogits(logits).sample(random.PRNGKey(1))
+    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
 
     def model(labels):
         coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
         logits = np.sum(coefs * data, axis=-1)
-        return sample('obs', dist.BernoulliWithLogits(logits), obs=labels)
+        return sample('obs', dist.Bernoulli(logits=logits), obs=labels)
 
     init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})
     init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
@@ -149,9 +149,9 @@ def model(data):
         p = sample('p', dist.Beta(1., 1.))
         if with_logits:
             logits = logit(p)
-            sample('obs', dist.BinomialWithLogits(logits, data['n']), obs=data['x'])
+            sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
         else:
-            sample('obs', dist.Binomial(p, data['n']), obs=data['x'])
+            sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])
 
     data = {'n': 5000000, 'x': 3849}
     init_params, potential_fn, transform_fn = initialize_model(random.PRNGKey(2), model, (data,), {})
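
The `test_binomial_stable` change above keeps exercising both parametrizations through the single `Binomial` wrapper. As a side note on why both branches exist: with a very large `total_count` and `p` close to 0 or 1, working in logit space can avoid precision loss in the probability. An illustrative sketch using that test's data:

import jax.numpy as np
import numpyro.distributions as dist

n, x = 5000000, 3849
p = x / n                              # a tiny success probability
logits = np.log(p) - np.log(1 - p)     # logit(p)

d_probs = dist.Binomial(n, probs=p)
d_logits = dist.Binomial(n, logits=logits)
# Both parametrizations should assign (nearly) the same log-probability
# to the observed count.
print(d_probs.log_prob(x), d_logits.log_prob(x))
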
