Skip to content

Commit

Permalink
Avoid Kumaraswamy numerical issues (#1681)
Browse files Browse the repository at this point in the history
* fix intersphinx issues

* fix kumaraswamy numerical issues
  • Loading branch information
fehiepsi committed Nov 19, 2023
1 parent 432c9e7 commit a5b3bab
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,18 +790,22 @@ def __init__(self, concentration1, concentration0, *, validate_args=None):

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
minval = jnp.finfo(jnp.result_type(float)).tiny
u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
log_sample = jnp.log1p(-(u ** (1 / self.concentration0))) / self.concentration1
finfo = jnp.finfo(u)
finfo = jnp.finfo(jnp.result_type(float))
u = random.uniform(
key, shape=sample_shape + self.batch_shape, minval=finfo.tiny
)
u_con0 = jnp.clip(u ** (1 / self.concentration0), a_max=1 - finfo.eps)
log_sample = jnp.log1p(-u_con0) / self.concentration1
return jnp.clip(jnp.exp(log_sample), a_min=finfo.tiny, a_max=1 - finfo.eps)

@validate_sample
def log_prob(self, value):
finfo = jnp.finfo(jnp.result_type(float))
normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1)
value_con1 = jnp.clip(value**self.concentration1, a_max=1 - finfo.eps)
return (
xlogy(self.concentration1 - 1, value)
+ xlog1py(self.concentration0 - 1, -(value**self.concentration1))
+ xlog1py(self.concentration0 - 1, -value_con1)
+ normalize_term
)

Expand Down

0 comments on commit a5b3bab

Please sign in to comment.