
Implement lognormal distribution #12

Merged: neerajprad merged 6 commits into pyro-ppl:master from fehiepsi:lognormal on Mar 3, 2019

Conversation

@fehiepsi (Member) commented Mar 2, 2019:

This PR also addresses shape parameters in the jax_continuous class. With this change, implementing other distributions that include shape parameters, such as beta/gamma, will be easier.

```python
    shapes = [np.shape(arg) for arg in args] + [np.shape(loc), np.shape(scale)]
    size = lax.broadcast_shapes(*shapes)
else:
    args = [np.reshape(arg, size) for arg in args]
```
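For context, lax.broadcast_shapes applies NumPy-style broadcasting to the collected parameter shapes to get the overall sample shape. A minimal sketch, assuming one extra shape parameter plus loc/scale (the shapes here are illustrative, not taken from the PR):

```python
import numpy as np
from jax import lax

# Hypothetical parameter shapes: one shape parameter, a vector loc, a scalar scale.
args = [np.ones((3, 1))]
loc, scale = np.zeros(4), np.ones(())

# Same pattern as the snippet above: gather all shapes, then broadcast them.
shapes = [np.shape(arg) for arg in args] + [np.shape(loc), np.shape(scale)]
size = lax.broadcast_shapes(*shapes)
print(size)  # (3, 4)
```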
@fehiepsi (Member, Author):

@neerajprad I deleted this statement because an arg cannot be reshaped into size when it is a scalar. Please correct me if this is not the right way to do it.

@neerajprad (Member):

I think we will need this for distributions with args other than loc and scale, but this is wrong 😄. It should have been np.broadcast_to(arg, size) instead, which works with scalars.
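The distinction matters because reshape requires the element count to match, while broadcast_to only requires broadcast compatibility. A small illustration in plain NumPy (the values are made up):

```python
import numpy as np

size = (2, 3)
scalar_arg = np.float32(0.5)

# np.reshape fails on a scalar: it has 1 element and cannot fill shape (2, 3).
try:
    np.reshape(scalar_arg, size)
except ValueError as e:
    print('reshape failed:', e)

# np.broadcast_to handles scalars and broadcastable arrays alike.
out = np.broadcast_to(scalar_arg, size)
print(out.shape)  # (2, 3)
```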

@fehiepsi (Member, Author):

Yeah, we already promote args, so broadcasting will work out of the box. (We could defer the broadcasting logic to the rvs method in specific distributions; however, I guess we won't have to.)

@neerajprad (Member):

That promotion is based on the args' mutual shapes; we don't promote them to the shape denoted by the size arg. But yes, it's better to deal with this issue when we actually have it. :)

```python
# or it will take default value (which is None).
# Note: self.numargs is the number of shape parameters.
size = kwargs.pop('size', args.pop() if len(args) > (self.numargs + 2) else None)
args, loc, scale = self._parse_args(*args, **kwargs)
```
@neerajprad (Member):

+1. Nice to rely on the default implementation as much as possible.

@fehiepsi (Member, Author):

yeah, this _parse_args method is so convenient! :D

```python
loc = kwargs.get('loc', args.popleft() if len(args) > 0 else 0)
scale = kwargs.get('scale', args.popleft() if len(args) > 0 else 1)
size = kwargs.get('size', args.popleft() if len(args) > 0 else None)
assert _is_prng_key(rng)
```
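To show the convention in the snippet above in isolation, here is a self-contained sketch (parse_loc_scale_size is a hypothetical helper, not the PR's code): positional args are consumed left to right from a deque, with keyword args taking precedence and scipy-style defaults as fallbacks.

```python
from collections import deque

def parse_loc_scale_size(*pos, **kwargs):
    # Mirrors the pattern above: pop positional args in order loc, scale, size;
    # a keyword arg for a slot wins over the popped default.
    args = deque(pos)
    loc = kwargs.get('loc', args.popleft() if len(args) > 0 else 0)
    scale = kwargs.get('scale', args.popleft() if len(args) > 0 else 1)
    size = kwargs.get('size', args.popleft() if len(args) > 0 else None)
    return loc, scale, size

print(parse_loc_scale_size(2.0, 3.0))   # (2.0, 3.0, None)
print(parse_loc_scale_size(size=(4,)))  # (0, 1, (4,))
```

Note that the popleft default expression is evaluated eagerly even when the keyword is present, so this pattern assumes callers pass each parameter either positionally or by keyword, not a mix for the same slots.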
@neerajprad (Member):

Super useful!



```python
class expon_gen(jax_continuous):
    r"""An exponential continuous random variable.
```
@neerajprad (Member):

Can we remove this docstring? We can just add a single one in distributions directing people to scipy.stats instead. Otherwise we might need to inherit scipy's documentation settings, which would be best avoided.

@neerajprad (Member):

Btw, no need to add any docstrings to distributions just yet; we can do that later.

@fehiepsi (Member, Author):

yeah, I'll remove them. They are unnecessary.

```python
X = x * x
V = v * v * v
U = random.uniform(key, ())
U = 1 - random.uniform(key, ())
```
@neerajprad (Member):

What's the rationale for changing these invocations to 1 - u?

@fehiepsi (Member, Author) commented Mar 2, 2019:

Thanks for asking! It seems we don't need to take 1 - u here, so I'll revert the change. I was worried that np.log(U) would misbehave, but it actually gives -inf when U = 0, so the condition simply fails. And although the boost might come out as 0, we don't need to handle that because we clamp the sample at the end. For exponential sampling, 1 - u is used to give samples in the range [0, inf) instead of (0, inf].

```python
# Theoretically:
# Using 1 - u
-np.log(1 - np.array(1 - np.finfo(np.float32).eps, dtype=np.float32))  # max ~ 16
-np.log(1)                                                             # min = 0, with positive probability
-np.log(1 - np.array(np.finfo(np.float32).eps, dtype=np.float32))      # ~1e-7: smallest positive sample
# Using u
-np.log(np.array(np.finfo(np.float32).tiny, dtype=np.float32))         # ~87: largest finite sample
-np.log(0)                                                             # max = inf, with positive probability
-np.log(1 - np.array(np.finfo(np.float32).eps, dtype=np.float32))      # min ~ 1e-7
```

The tricky part here is:

  • numpy.random uses double precision, so u is non-zero in almost all cases; using either u or 1 - u would be fine in most cases.
  • jax.random supports single precision. In single precision, jax.random.uniform generates samples in the set {0} ∪ [finfo.eps, 1 - finfo.eps] instead of {0} ∪ [finfo.tiny, 1 - finfo.eps]. So both ways have a minimum non-zero value of ~1e-7 and a maximum finite value of ~16 (tested). However, using 1 - u additionally gives 0 samples with probability 1/10000000, and using u additionally gives inf samples with probability 1/10000000. So we have to clamp the result.

Neither way is clearly better, so I'll stick with np.log(u) because it is faster.
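Putting the pieces together, the approach settled on above can be sketched as follows. This is illustrative, not the PR's exact code: sample_exponential is a hypothetical name, and the clamp uses np.maximum to keep u away from 0 so that the rare u == 0 draw cannot produce inf.

```python
import jax.numpy as np
from jax import random

def sample_exponential(rng, shape=()):
    # Use -log(u) directly (the faster option discussed above), clamping u
    # away from 0 so a zero draw cannot yield an infinite sample.
    u = random.uniform(rng, shape)
    u = np.maximum(u, np.finfo(np.float32).tiny)
    return -np.log(u)

samples = sample_exponential(random.PRNGKey(0), (1000,))
# All samples are finite and strictly positive.
```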

@neerajprad (Member):

Thanks for explaining this issue!

@neerajprad neerajprad merged commit c64d8ea into pyro-ppl:master Mar 3, 2019
