Skip to content

Commit

Permalink
Kumaraswamy distribution bug fixes (#1675)
Browse files Browse the repository at this point in the history
* Minor Kumaraswamy dist bug fixes

* Removing intermediates from Kuma log_prob again because no longer necessary

---------

Co-authored-by: Eike Petersen <ewipe@dtu.dk>
  • Loading branch information
e-pet and e-pet committed Nov 10, 2023
1 parent 24c21b8 commit f5bd186
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def icdf(self, q):
return self.loc - self.scale * jnp.log(-jnp.log(q))


class Kumaraswamy(TransformedDistribution):
class Kumaraswamy(Distribution):
arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
Expand All @@ -786,13 +786,7 @@ def __init__(self, concentration1, concentration0, *, validate_args=None):
batch_shape = lax.broadcast_shapes(
jnp.shape(concentration1), jnp.shape(concentration0)
)
base_dist = Uniform(0, 1).expand(batch_shape)
transforms = [
PowerTransform(1 / concentration0),
AffineTransform(1, -1),
PowerTransform(1 / concentration1),
]
super().__init__(base_dist, transforms, validate_args=validate_args)
super().__init__(batch_shape=batch_shape, validate_args=validate_args)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
Expand Down

0 comments on commit f5bd186

Please sign in to comment.