Implement lognormal distribution #12
neerajprad merged 6 commits into pyro-ppl:master from fehiepsi:lognormal
Conversation
…our where (0, 1] is expected
```python
shapes = [np.shape(arg) for arg in args] + [np.shape(loc), np.shape(scale)]
size = lax.broadcast_shapes(*shapes)
else:
    args = [np.reshape(arg, size) for arg in args]
```
@neerajprad I deleted this statement because an arg cannot be reshaped into `size` when it is a scalar. Please correct me if this is not the right way to do it.
I think we will need this for distributions with args other than loc, scale, but this is wrong 😄. It should have been `np.broadcast_to(arg, size)` instead, which works with scalars.
yeah, we already promote args, so broadcasting will work out of the box (we could defer the broadcasting logic to the `rvs` method in some specific distributions; however, I guess we won't have to).
That is a promotion based on mutual shapes; we don't promote them to the shape denoted by the size arg. But yes, it's better to deal with this issue when we actually have it. :)
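For reference, here is a minimal NumPy sketch (mine, not the PR's code) of why `np.reshape` fails for a scalar while `np.broadcast_to` works:

```python
import numpy as np

scalar = np.float32(2.0)
size = (3, 2)

# np.reshape cannot turn a 1-element scalar into a (3, 2) array (6 elements):
try:
    np.reshape(scalar, size)
except ValueError as e:
    print("reshape failed:", e)

# np.broadcast_to repeats the scalar to fill the requested shape:
out = np.broadcast_to(scalar, size)
print(out.shape)  # (3, 2)
```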
```python
# or it will take default value (which is None).
# Note: self.numargs is the number of shape parameters.
size = kwargs.pop('size', args.pop() if len(args) > (self.numargs + 2) else None)
args, loc, scale = self._parse_args(*args, **kwargs)
```
+1. Nice to rely on the default implementation as much as possible.
yeah, this `_parse_args` method is so convenient! :D
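For readers unfamiliar with the scipy.stats calling convention, a hypothetical standalone sketch of the size-extraction line above (the helper name is mine):

```python
# For a distribution with `numargs` shape parameters, positional args follow
# (shape params..., loc, scale, size), so an extra trailing arg is `size`.
def split_size(numargs, *args, **kwargs):
    args = list(args)
    # Note: dict.pop evaluates its default eagerly, so `size` should be
    # passed either positionally or by keyword, not both.
    size = kwargs.pop('size', args.pop() if len(args) > numargs + 2 else None)
    return args, size

# gamma has one shape parameter (numargs = 1): (a, loc, scale, size)
print(split_size(1, 2.0, 0.0, 1.0, (5,)))  # → ([2.0, 0.0, 1.0], (5,))
print(split_size(1, 2.0, 0.0, 1.0))        # → ([2.0, 0.0, 1.0], None)
```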
```python
loc = kwargs.get('loc', args.popleft() if len(args) > 0 else 0)
scale = kwargs.get('scale', args.popleft() if len(args) > 0 else 1)
size = kwargs.get('size', args.popleft() if len(args) > 0 else None)
assert _is_prng_key(rng)
```
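The popleft-based parsing above can be sketched as a standalone helper (name is mine; keywords take priority, positionals are consumed left to right):

```python
from collections import deque

def parse_loc_scale_size(*pos, **kwargs):
    args = deque(pos)
    # Note: dict.get evaluates its default eagerly, so mixing positional and
    # keyword forms of the same argument would shift what gets consumed.
    loc = kwargs.get('loc', args.popleft() if len(args) > 0 else 0)
    scale = kwargs.get('scale', args.popleft() if len(args) > 0 else 1)
    size = kwargs.get('size', args.popleft() if len(args) > 0 else None)
    return loc, scale, size

print(parse_loc_scale_size(2.0, 3.0))   # → (2.0, 3.0, None)
print(parse_loc_scale_size(scale=3.0))  # → (0, 3.0, None)
```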
numpyro/distributions/expon.py
Outdated
```python
class expon_gen(jax_continuous):
    r"""An exponential continuous random variable.
```
Can we remove this docstring? We can just add a single one in distributions directing people to scipy.stats instead. Otherwise we might need to inherit scipy's documentation settings, which would be best avoided.
Btw..no need to add any docstring to distributions just yet, we can do that later.
yeah, I'll remove them. They are unnecessary.
numpyro/distributions/util.py
Outdated
```diff
 X = x * x
 V = v * v * v
-U = random.uniform(key, ())
+U = 1 - random.uniform(key, ())
```
What's the rationale for changing these invocations to 1 - u?
Thanks for asking! It seems that we don't need to take 1 - u here; I'll revert the change. I was worried that np.log(U) would return something wrong, but it actually gives -inf when U = 0, so the condition will fail. As for the boost, though it might give boost = 0, we don't need to deal with it because we clamp the sample at the end. For exponential sampling, 1 - u is used to give samples in the range [0, inf) instead of (0, inf].
```python
# Theoretically
# Use 1 - u
-np.log(1 - np.array(1 - np.finfo(np.float32).eps, dtype=np.float32))  # max = 16
-np.log(1)  # min = 0 with positive probability
-np.log(1 - np.array(np.finfo(np.float32).eps, dtype=np.float32))  # 1e-7: smallest positive sample
# Use u
-np.log(np.array(np.finfo(np.float32).tiny, dtype=np.float32))  # 87: largest finite sample
-np.log(0)  # max = inf with positive probability
-np.log(1 - np.array(np.finfo(np.float32).eps, dtype=np.float32))  # min = 1e-7
```
The tricky part here is:
- numpy.random uses double precision, so u is non-zero in most cases and either u or 1 - u would be fine.
- jax.random supports single precision. In single precision, `jax.random.uniform` generates samples belonging to the set `{0} ∪ [finfo.eps, 1 - finfo.eps]` instead of `{0} ∪ [finfo.tiny, 1 - finfo.eps]`. So both ways have a minimum non-zero sample of 1e-7 and a maximum finite sample of 16 (tested). However, using 1 - u will additionally give 0 samples with probability 1/10000000, and using u will additionally give inf samples with probability 1/10000000. So we have to clamp the result.
I don't find either way better, so I'll stick with `np.log(u)` because it is faster.
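A quick NumPy check (my own numbers) of the float32 endpoint values quoted in this thread:

```python
import numpy as np

eps = np.finfo(np.float32).eps    # ≈ 1.19e-07, smallest gap above 1
tiny = np.finfo(np.float32).tiny  # ≈ 1.18e-38, smallest positive normal

# Smallest positive sample either way: -log of the largest float32 below 1.
print(-np.log(np.float32(1) - eps))  # ≈ 1.19e-07
# Largest finite sample from -log(u) if u could reach `tiny`:
print(-np.log(tiny))                 # ≈ 87.3
# Largest sample when the smallest non-zero u is eps:
print(-np.log(eps))                  # ≈ 15.9
# Endpoint cases that motivate clamping:
with np.errstate(divide='ignore'):
    print(-np.log(np.float32(0)))    # inf (u = 0)
print(-np.log(np.float32(1)))        # -0.0 (1 - u = 1 gives a 0 sample)
```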
Thanks for explaining this issue!
This PR also addresses shape parameters in the `jax_continuous` class. With this change, implementing other distributions that include shape parameters, such as beta or gamma, will be easier.
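As background (a sketch using scipy.stats itself, not the PR's code): in the scipy.stats convention this PR follows, "shape parameters" are the distribution-specific args beyond loc and scale, and `numargs` counts them.

```python
import scipy.stats as sp

# gamma has one shape parameter (a); beta has two (a, b); expon has none.
print(sp.gamma.numargs, sp.beta.numargs, sp.expon.numargs)  # 1 2 0

samples = sp.gamma.rvs(a=2.0, loc=0.0, scale=1.0, size=(3,), random_state=0)
print(samples.shape)  # (3,)
```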