Skip to content

Add standard_gamma sampler#7

Merged
neerajprad merged 5 commits intopyro-ppl:masterfrom
fehiepsi:gamma
Feb 27, 2019
Merged

Add standard_gamma sampler#7
neerajprad merged 5 commits intopyro-ppl:masterfrom
fehiepsi:gamma

Conversation

@fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Feb 27, 2019

This PR adds standard_gamma sampler. It is slow to compile but after compiling, sampling will be fast.

I think this version can be temporarily used for time being. Currently, some drawbacks are:

Later, to improve performance, we should move to this to JAX to get feedback from its devs.

@neerajprad Do I need CUDA to run tests? Currently, all tests throw error that "platform CUDA not found". :(

@neerajprad
Copy link
Member

#8 should take care of pytest not throwing exceptions for the warning that you see.

def test_standard_gamma_stats(alpha):
rng = random.PRNGKey(0)
z = standard_gamma(rng, np.full((1000,), alpha))
assert np.abs((np.mean(z) - alpha) / alpha) < 0.06
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: You could just use np.allclose with rtol=0.06.

@neerajprad
Copy link
Member

I think it will be a good idea to move most of the distributions related functionality to jax finally, but that will also add overhead in terms of design discussions and review. For the time being, its nice to be able to hack away without that overhead, but let us keep that in mind for sure.

@neerajprad
Copy link
Member

Thanks for adding the gamma sampler, @fehiepsi. I'll take a look at issue 3 you mentioned above after merging this.

@neerajprad neerajprad merged commit 2761a05 into pyro-ppl:master Feb 27, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants