Skip to content

Commit

Permalink
swap pareto alpha and scale (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed May 16, 2019
1 parent bb1aab7 commit 80937d0
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def partially_pooled(at_bats, hits=None):
"""
num_players = at_bats.shape[0]
m = sample("m", dist.Uniform(np.array([0.]), np.array([1.])))
kappa = sample("kappa", dist.Pareto(np.array([1.]), np.array([1.5])))
kappa = sample("kappa", dist.Pareto(np.array([1.5])))
shape = np.shape(kappa)[:np.ndim(kappa) - 1] + (num_players,)
phi_prior = dist.Beta(np.broadcast_to(m * kappa, shape),
np.broadcast_to((1 - m) * kappa, shape))
Expand Down
4 changes: 1 addition & 3 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,7 @@ class Pareto(TransformedDistribution):
arg_constraints = {'alpha': constraints.positive, 'scale': constraints.positive}
support = constraints.real

# FIXME: should we use `concentration` and `scale` for consistence with other distributions
# tensorflow uses `concentration` and `scale` though
def __init__(self, scale, alpha, validate_args=None):
def __init__(self, alpha, scale=1., validate_args=None):
batch_shape = lax.broadcast_shapes(np.shape(scale), np.shape(alpha))
self.scale, self.alpha = np.broadcast_to(scale, batch_shape), np.broadcast_to(alpha, batch_shape)
base_dist = Exponential(self.alpha)
Expand Down
2 changes: 1 addition & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __new__(cls, jax_dist, *params):
dist.MultinomialLogits: lambda logits, total_count: osp.multinomial(n=total_count,
p=_to_probs_multinom(logits)),
dist.Normal: lambda loc, scale: osp.norm(loc=loc, scale=scale),
dist.Pareto: lambda scale, alpha: osp.pareto(alpha, scale=scale),
dist.Pareto: lambda alpha, scale: osp.pareto(alpha, scale=scale),
dist.Poisson: lambda rate: osp.poisson(rate),
dist.StudentT: lambda df, loc, scale: osp.t(df=df, loc=loc, scale=scale),
dist.Uniform: lambda a, b: osp.uniform(a, b - a),
Expand Down

0 comments on commit 80937d0

Please sign in to comment.