
Improvements to the HMC interface #1816

Closed · neerajprad opened this issue on Apr 9, 2019 · 5 comments

@neerajprad (Member) commented Apr 9, 2019

For the next major release, we could incorporate some lessons from NumPyro, along with issues raised on our forum, to improve the HMC interface as follows:

  • Provide a potential_fn, as in NumPyro, so that users do not need to pass in their model and can instead generate samples from any target whose log density can be evaluated by a callable. This will be useful for integrating with funsors and, eventually, for getting rid of our TraceTreeEvaluator wrappers (pyro-ppl/funsor#123).
  • If an initial trace is provided, we should make sure that we never run the model to set up our initial state. This matters because we often want to sample from distributions that do not define a sample method, and we should still be able to run HMC on such models.
  • Related to the points above, it would be nice if (as in NumPyro), once given a potential_fn and an initial sample, we could generate subsequent samples without making any assumptions about the container data structure. Right now we assume that this container is a Pyro trace object, but that is only needed to interface with the TracePosterior class. Dropping the assumption would make it much simpler to integrate NUTS/HMC into other libraries. A middle ground might be to move the trace wrapping/unwrapping logic into the MCMC class, but this needs more discussion.
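As a rough illustration of the first and third points, here is a minimal pure-Python sketch of an HMC sampler whose only inputs are a potential_fn (the negative log density, up to a constant) and an initial sample stored in a plain dict. It never runs a model and never builds a trace. This is not Pyro's or NumPyro's implementation: `hmc_sample` and `_grad` are hypothetical names, and central finite differences stand in for the autodiff gradient a real implementation would use.

```python
import math
import random
from typing import Callable, Dict, List

# State is a plain dict of named scalar values; no Pyro trace is involved.
Params = Dict[str, float]


def _grad(potential_fn: Callable[[Params], float], params: Params,
          eps: float = 1e-5) -> Params:
    """Gradient of potential_fn via central finite differences.

    Assumption for this sketch: a real implementation would use autodiff.
    """
    grads = {}
    for name in params:
        up, down = dict(params), dict(params)
        up[name] += eps
        down[name] -= eps
        grads[name] = (potential_fn(up) - potential_fn(down)) / (2 * eps)
    return grads


def hmc_sample(potential_fn: Callable[[Params], float], initial_params: Params,
               num_samples: int, step_size: float = 0.1, num_steps: int = 10,
               seed: int = 0) -> List[Params]:
    """Hypothetical HMC sampler driven only by potential_fn and a starting point.

    potential_fn returns the negative log density (up to a constant). No model
    is executed: initial_params is used directly as the initial state.
    """
    rng = random.Random(seed)
    q = dict(initial_params)  # copy so the caller's dict is not mutated
    samples = []
    for _ in range(num_samples):
        # Resample momentum from a standard normal (identity mass matrix).
        p = {k: rng.gauss(0.0, 1.0) for k in q}
        q_new, p_new = dict(q), dict(p)

        # Leapfrog integration: half momentum step, then alternating full
        # position/momentum steps, ending with a half momentum step.
        g = _grad(potential_fn, q_new)
        for k in p_new:
            p_new[k] -= 0.5 * step_size * g[k]
        for step in range(num_steps):
            for k in q_new:
                q_new[k] += step_size * p_new[k]
            g = _grad(potential_fn, q_new)
            scale = step_size if step < num_steps - 1 else 0.5 * step_size
            for k in p_new:
                p_new[k] -= scale * g[k]

        # Metropolis correction using the Hamiltonian H(q, p) = U(q) + |p|^2 / 2.
        h_old = potential_fn(q) + 0.5 * sum(v * v for v in p.values())
        h_new = potential_fn(q_new) + 0.5 * sum(v * v for v in p_new.values())
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            q = q_new
        samples.append(dict(q))
    return samples


# Usage: two independent standard normals, U(q) = (x^2 + y^2) / 2.
def potential_fn(params: Params) -> float:
    return 0.5 * (params["x"] ** 2 + params["y"] ** 2)


samples = hmc_sample(potential_fn, {"x": 0.0, "y": 0.0}, num_samples=2000)
```

The sampler only needs to iterate over the keys of its state, so swapping the dict for another container changes the bookkeeping loops but not the algorithm; any trace wrapping/unwrapping could then live in an outer layer such as the MCMC class.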

Related issue: Refactoring the TracePosterior interface #1725

@neerajprad (Member, Author) commented Apr 9, 2019 (comment minimized)

@fehiepsi (Collaborator) commented Apr 10, 2019

I think following the NumPyro approach is a good idea. It would also help in the future, once control-flow primitives like lax.cond and lax.while_loop are available in PyTorch (disclaimer: I don't know the current state of the PyTorch JIT), because we could then JIT-compile the whole trajectory, as NumPyro does, to improve speed.

@fehiepsi (Collaborator) commented Apr 26, 2019

@neerajprad Can I work on this issue too? It is needed to support GPyTorch, where HMC should only require a potential_fn and params.

@neerajprad (Member, Author) commented Apr 26, 2019

Absolutely! Please feel free to create or mark off any sub-issue above that you are working on. I was planning to look at changes to the TracePosterior interface, so you could take up the potential_fn refactor, which we also need to interface with funsors.

@fehiepsi referenced this issue on May 2, 2019 in the merged pull request "Revise HMC to support general potential_fn" #1845 (4 of 4 tasks complete).

@fritzo added this to the 0.4 release milestone on May 6, 2019.

@fehiepsi (Collaborator) commented May 7, 2019

I guess we can close this issue now. Further work should go under the TracePosterior refactoring (#1725).

@fehiepsi closed this on May 7, 2019.
