-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Rejection Sampling Variational Inference #659
Conversation
Glad to see this incorporated into pyro, @fritzo and @martinjankowiak! One of the points discussed in the paper is that the score function term tends to be higher variance, and in practice it can often be safely ignored at the cost of a small bias. You can also ameliorate it to some extent by "shape augmentation," i.e. introducing auxiliary uniform r.v.'s in order to increase the shape parameter and thereby increase the acceptance probability. I did not implement this in the demo notebook, but we used it in our paper. The code for that is in this repo: https://github.com/blei-lab/ars-reparameterization and @naesseth can probably provide more details on its implementation. |
@martinjankowiak This is almost done, there's just one more test to fix. Could you please review the changes to |
cool awesome. i'll do so in the morning when i have a fresh brain |
pyro/distributions/distribution.py
Outdated
""" | ||
log_pdf = self.batch_log_pdf(x, *args, **kwargs) | ||
if self.reparameterized: | ||
return ScoreParts(log_pdf, 0, log_pdf) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe explicit names in namedtuple here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's investigate the xfail more
class RejectionStandardGamma(Rejector): | ||
""" | ||
Naive Marsaglia & Tsang rejection sampler for standard Gamma distibution. | ||
This assumes `alpha >= 1` and does not boost `alpha` boosting or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring typo
@martinjankowiak I believe I've fixed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. @fritzo merge?
Ready to merge, finally. |
@fritzo are there any examples of using the accept-reject dirichlet or gamma? |
@rmehta1987 i don't believe there are. but if you changed the |
Addresses #63
This implements Rejection Sampling Variational Inference (RSVI) of Naesseth et al. (2017).
This implements a new class
Rejector
as an abstract base class derived fromDistribution
. The crux to integrating these with Pyro's SVI inTraceGraph_ELBO
is a new methodDistribution.score_function_term()
that can be overridden to computeg_cor
in the paper. This method computes a partial score function for a partially reparameterized distribution such that: for an unreparameterized distribution it computes the entire score function, and for a fully reparameterized distribution it computes zero.Tasks
Trace_ELBO
RejectionGamma
example, e.g. see @slinderman's notebookThis was joint work by @martinjankowiak and @fritzo