Implement GuideMessenger, AutoNormalMessenger, AutoRegressiveMessenger #2953
Conversation
@vitkl I think this should be ready to try out. |
@fritzo this looks quite exciting! |
So a guide with hierarchical dependencies for all sites should simply do this? See the line marked "# Here ->":

```python
class HierarchicalGuideMessenger(AutoRegressiveMessenger):
    def get_posterior(self, name, prior, upstream_values):
        # Use a distribution at every site whose value depends on upstream_values.
        with helpful_support_errors({"name": name, "fn": prior}):
            transform = biject_to(prior.support)
            loc, scale = self._get_params(name, prior)
            affine = dist.transforms.AffineTransform(
                loc + transform.inv(prior.mean),  # Here ->
                scale,
                event_dim=transform.domain.event_dim,
                cache_size=1,
            )
            posterior = dist.TransformedDistribution(
                prior, [transform.inv.with_cache(), affine, transform.with_cache()]
            )
            return posterior
```

Or should it be something more complex like below, where users need to provide a dictionary specifying which sites have which parents and how to transform them into each other?

```python
class HierarchicalGuideMessenger(AutoRegressiveMessenger):
    def __init__(self, *args, hierarchical_sites=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.hierarchical_sites = hierarchical_sites or {}

    def get_posterior(self, name, prior, upstream_values):
        if name in self.hierarchical_sites:
            # Use a custom distribution at this site whose value depends on upstream_values.
            with helpful_support_errors({"name": name, "fn": prior}):
                transform = biject_to(prior.support)
                # Get values of parent sites.
                parent_names = self.hierarchical_sites[name]["parent_nodes"]
                parent_upstream_values = {k: upstream_values[k] for k in parent_names}
                hierarchical_loc = self.hierarchical_sites[name]["fn"](**parent_upstream_values)
                hierarchical_loc_untransformed = transform.inv(hierarchical_loc)
                loc, scale = self._get_params(name, prior)
                affine = dist.transforms.AffineTransform(
                    loc + hierarchical_loc_untransformed,
                    scale,
                    event_dim=transform.domain.event_dim,
                    cache_size=1,
                )
                posterior = dist.TransformedDistribution(
                    prior, [transform.inv.with_cache(), affine, transform.with_cache()]
                )
                return posterior
        # Fall back to autoregressive.
        return super().get_posterior(name, prior, upstream_values)
```

where

```python
hierarchical_sites = {"x": {"parent_nodes": ["y", "z"], "fn": lambda y, z: y @ z}}
``` |
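For concreteness, hypothetical usage of the second, dictionary-based variant might look like this (the model, site names, and shapes are invented for illustration, matching the `hierarchical_sites` example above):

```python
import pyro
import pyro.distributions as dist

def model():
    y = pyro.sample("y", dist.Normal(0.0, 1.0).expand([2, 3]).to_event(2))
    z = pyro.sample("z", dist.Normal(0.0, 1.0).expand([3, 4]).to_event(2))
    # "x" depends on its parents "y" and "z" through a matrix product.
    pyro.sample("x", dist.Normal(y @ z, 1.0).to_event(2))

guide = HierarchicalGuideMessenger(
    model,
    hierarchical_sites={"x": {"parent_nodes": ["y", "z"], "fn": lambda y, z: y @ z}},
)
``` |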
I've added an |
BTW I think your idea is similar to ASVI, which samples from a posterior with the same dependency structure as the prior. Note that in general the posterior may have a more complex dependency structure, as described in Webb et al. (2017). |
Very interesting, thanks for adding this! @yozhikoff @la-sekretar @bv2 are probably interested in this too. |
```python
    ) -> Union[TorchDistribution, torch.Tensor]:
        with helpful_support_errors({"name": name, "fn": prior}):
            transform = biject_to(prior.support)
            loc, scale = self._get_params(name, prior)
```
The guide will become fully hierarchical if you do this, but it is not fully hierarchical by default, right?

```python
loc, scale = self._get_params(name, prior)
loc = loc + prior.loc
```

Ideally one could add some kind of test of whether this site has dependency sites. You also mention that it could be useful to encode a more complex dependency:

```python
loc, scale, weight = self._get_params(name, prior)
loc = loc + prior.loc * weight
```
Correct, the intention of this simple guide is to be mean field. Do you want to try contributing an `AutoHierarchicalNormalMessenger` guide as a follow-up to this PR? I tried to do something similar with `AutoRegressiveMessenger` below by sampling from the prior and then shifting in unconstrained space. I was unsure how to implement a general `AutoHierarchicalNormalMessenger` because not all `prior` distributions have a `.mean` property, and even then it is the mean in unconstrained space that we care about. E.g. how do we deal with `Gamma` or `LogNormal` or `Beta` or `Dirichlet`?
I understand your point about distributions that don't have a mean. What are those distributions, by the way?

I am thinking about this solution:

```python
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.loc) * weight
```

Does it make sense for all distributions that have a mean?
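Spelled out inside the messenger, that proposal would look something like the sketch below. This is hypothetical: it assumes `_get_params` is extended to also return a learnable `weight` tensor (the current version returns only `loc, scale`), it only works for priors that expose a `.loc`, and whether to use `prior.loc` or `prior.mean` is exactly what the next comment discusses. Import paths are per the pyro version under review and may differ.

```python
from torch.distributions import biject_to

import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormalMessenger
from pyro.infer.autoguide.utils import helpful_support_errors

class WeightedHierarchicalMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior, upstream_values):
        with helpful_support_errors({"name": name, "fn": prior}):
            transform = biject_to(prior.support)
            # Hypothetical: assumes _get_params also returns a learnable
            # per-site weight, which the current implementation does not.
            loc, scale, weight = self._get_params(name, prior)
            # Shift the learned location towards the prior's location mapped
            # to unconstrained space, scaled by the learnable weight.
            loc = loc + transform.inv(prior.loc) * weight
            posterior = dist.TransformedDistribution(
                dist.Normal(loc, scale).to_event(transform.domain.event_dim),
                transform.with_cache(),
            )
            return posterior
``` |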
Yes, I will do the `AutoHierarchicalNormalMessenger` PR. Should I wait until this PR is merged?
> What [distributions do not have a `.mean` property], by the way?

- Heavy-tailed distributions may not have a mean, e.g. `Cauchy` and `Stable` have infinite variance and no defined mean.
- Non-Euclidean distributions such as `VonMises3D` and `ProjectedNormal` have no defined mean.
- Some complex distributions have no computable mean, e.g. `TransformedDistribution(Normal(...), MyNormalizingFlow)`.

> Does `prior.loc` make sense for all distributions that have a mean?

First, I would opt for `prior.mean` rather than `prior.loc`, since e.g. `LogNormal(...).loc` isn't a mean; rather it is the mean of the pre-transformed normal. Second, note that the transform of the constrained mean is not the same as the unconstrained mean or unconstrained median, e.g. for LogNormal, `mean = exp(loc + scale**2 / 2)` whereas `median = exp(loc)`.

I think your `.mean` idea is good enough in most cases, and for the cases where it fails, users can subclass and define their own custom `.get_posterior()` methods.
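As a quick numeric check of the LogNormal point (plain torch; values noted in the comments):

```python
import torch
import torch.distributions as dist

d = dist.LogNormal(torch.tensor(0.0), torch.tensor(1.0))

print(d.mean)                     # exp(0 + 1**2 / 2) ≈ 1.6487, the constrained mean
print(d.mean.log())               # ≈ 0.5, the log of the constrained mean
print(d.icdf(torch.tensor(0.5)))  # median = exp(0) = 1.0, so log(median) = 0.0
# The unconstrained (base Normal) mean is loc = 0.0: it matches log(median),
# not log(mean), so transforming the constrained mean does not recover it.
```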
Yes, Webb et al. (2017) was the idea behind |
```python
        if self.num_particles == 1:
            return fn
        return pyro.plate(
            "num_particles_vectorized",
            self.num_particles,
            dim=-self.max_plate_nesting,
        )(fn)
```
> this ensures serializability

How do these guides know about the plates? |
These guides directly use the model's plates. @vitkl I'm unsure how subsampling should behave in these models, what do you think? Should we always assume that subsampled plates are amortized (share a single parameter value)? Should we additionally provide an explicit option? EDIT: I've just pushed a fix to support subsampling, and an amortization option. |
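A minimal sketch of the subsampling pattern under discussion (model, site names, and sizes are invented for illustration): because the guide runs as an effect handler inside the model, it reuses the model's plate and its subsample indices directly.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormalMessenger

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    # The guide sees this plate (and its subsample indices) directly.
    with pyro.plate("data", len(data), subsample_size=10):
        batch = pyro.subsample(data, event_dim=0)
        # A local latent: the guide allocates parameters for the full plate
        # and slices them by the subsample indices at each step.
        scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
        pyro.sample("obs", dist.Normal(loc, scale), obs=batch)

guide = AutoNormalMessenger(model)
``` |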
@martinjankowiak thanks for the helpful review, the new design is much simpler! @vitkl you'll need to update:

```diff
- def get_posterior(self, name, prior, upstream_values):
+ def get_posterior(self, name, prior):
``` |
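Under the revised signature, a site-specific override would fetch upstream samples by name rather than receiving them as an argument, roughly as below. The site names and the form of the dependence are made up, and this sketch assumes the `GuideMessenger.upstream_value()` accessor and a real-supported site `"x"`.

```python
import pyro.distributions as dist
from pyro.infer.autoguide import AutoRegressiveMessenger

class MyGuideMessenger(AutoRegressiveMessenger):
    def get_posterior(self, name, prior):
        if name == "x":  # hypothetical site with an upstream dependency on "y"
            y = self.upstream_value("y")
            loc, scale = self._get_params(name, prior)
            # Center the learned Normal on the sampled upstream value.
            return dist.Normal(loc + y, scale)
        # Fall back to the default autoregressive posterior.
        return super().get_posterior(name, prior)
``` |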
@martinjankowiak looks like tests now pass |
@fritzo I think that, in the simplest case, it would be good to support subsampling as it is currently done in AutoNormal (define all parameters on initialisation, use a subset of them according to plate indices). I assume this behaviour was already supported, right? I don't think I understand what this statement means:
The code seems to suggest that, for a parameter |
@vitkl feel free to clarify the docstrings in your follow-up PR. Indeed your language of "global" and "local" seems clearer.
Yes, your interpretation is correct. Can you explain what kind of amortization and minibatching strategies you would find useful, in cell2location or elsewhere? |
Thanks for explaining! The setting where the full parameter is initialized at the first invocation and subsets of it are extracted at each minibatch can be used in cell2location (although we find that it leads to reduced accuracy). It is also used in MOFA (https://biofam.github.io/MOFA2/) and in a few models related to cell2location (Stereoscope and DestVI in scvi-tools), so this approach is used by a few methods, although these particular models are implemented using pyro. Good to know that this setting works with the messenger guides. The amortization strategy I am thinking about will hopefully become clear when I add the amortised class as a PR. |
Addresses #2950. Replaces #2951, #2928.

This introduces a new pattern for writing variational posteriors: as effect handlers. Whereas ELBOs used to always first run the guide and then run the model, they now check whether the guide is a `GuideMessenger`, and if so run the guide as a messenger that intercepts model sites as they arise. This has two advantages.

The interface is backwards compatible with all ELBOs (except `TraceEnum_ELBO`, which is out of scope).

To demonstrate the new interface I've added two example autoguides subclassing `GuideMessenger`:

- `AutoNormalMessenger` is an analog of `AutoNormal` and is a useful base class for custom guides.
- `AutoRegressiveMessenger` is a simple autoregressive guide whose posteriors are learned affine transforms of the priors at each site (recursively conditioned on upstream posterior samples).
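Standard SVI usage is unchanged; for example, a minimal smoke-test-style script (model and data invented for illustration):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormalMessenger
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

guide = AutoNormalMessenger(model)  # used exactly like any other autoguide
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
data = torch.randn(100) + 2.0
for step in range(501):
    loss = svi.step(data)
```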
Tested

Added `AutoNormalMessenger` and `AutoRegressiveMessenger` to a bunch of `test_autoguide.py` tests:

- `test_factor()`
- `test_shapes()`
- `test_subsample_model()`
- `test_subsample_model_amortized()` (this is a new test)
- `test_serialization()`
- `test_init_loc_fn()`
- `test_median()`
- `test_median_module()`
- `test_linear_regression_smoke()`
- `test_*_helpful_support_error()`
- `test_exact*()`