Implement GuideMessenger, AutoNormalMessenger, AutoRegressiveMessenger #2953

Merged (22 commits), Nov 1, 2021

Conversation

@fritzo (Member) commented Oct 28, 2021

Addresses #2950
Replaces #2951, #2928

This introduces a new pattern for writing variational posteriors: as effect handlers. Whereas ELBOs used to always run the guide first and then run the model, they now check whether the guide is a GuideMessenger and, if so, run the guide as a messenger that intercepts model sites as they arise. This has two advantages:

  1. Guides can be more dynamic, sampling only those nodes in models that actually arise.
  2. Guides can leverage model-side computations at no extra cost. This framework produces a (model, guide) pair of traces while running model computations only once (they are shared between the model and guide). This is the main feature requested by @vitkl.

The interface is backwards compatible with all ELBOs (except TraceEnum_ELBO, which is out of scope):

model = ...  # standard model syntax
guide = AutoRegressiveMessenger()  # note this currently doesn't take a model
elbo = Trace_ELBO()
optim = Adam({"lr": 1e-3})
svi = SVI(model, guide, optim, elbo)
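Training then proceeds exactly as with any other guide. A minimal sketch (the data argument is a hypothetical stand-in for whatever arguments the model takes):

for step in range(1000):
    loss = svi.step(data)  # runs the model once; the guide messenger intercepts its sample sites
    if step % 100 == 0:
        print(f"step {step}  loss = {loss:.4g}")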

To demonstrate the new interface I've added two example autoguides subclassing GuideMessenger:

  • AutoNormalMessenger is an analog of AutoNormal and is a useful base class for custom guides (see the sketch after this list).
  • AutoRegressiveMessenger is a simple autoregressive guide whose posteriors are learned affine transforms of the priors at each site (recursively conditioned on upstream posterior samples).
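As a rough illustration of the intended extension point, here is a minimal sketch of a custom guide built on AutoNormalMessenger. This is not code from the PR: the site name "x" is hypothetical, real-valued support is assumed at that site, and it uses the final get_posterior(name, prior) signature settled on later in this thread:

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer.autoguide import AutoNormalMessenger

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "x":
            # Use a hand-rolled Normal posterior at the hypothetical site "x".
            loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape()))
            scale = pyro.param(
                "x_scale",
                lambda: torch.ones(prior.shape()),
                constraint=constraints.positive,
            )
            return dist.Normal(loc, scale).to_event(len(prior.event_shape))
        # Fall back to the default mean-field normal at all other sites.
        return super().get_posterior(name, prior)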

Tested

Added AutoNormalMessenger and AutoRegressiveMessenger to a number of test_autoguide.py tests:

  • test_factor()
  • test_shapes()
  • test_subsample_model()
  • test_subsample_model_amortized() (this is a new test)
  • test_serialization()
  • test_init_loc_fn()
  • test_median()
  • test_median_module()
  • test_linear_regression_smoke()
  • test_*_helpful_support_error()
  • test_exact*()

@fritzo marked this pull request as draft October 28, 2021 20:22
@fritzo (Member Author) commented Oct 28, 2021

@vitkl I think this should be ready to try out.

@vitkl (Contributor) commented Oct 29, 2021

@fritzo this looks quite exciting!

  1. So AutoRegressiveMessenger acts like the AutoNormal guide?
  2. Can you please give an example of how to use upstream_values=... for encoding the hierarchy/dependencies?

@vitkl (Contributor) commented Oct 29, 2021

I see that upstream_values are automatically collected and AutoRegressiveMessenger currently doesn't support the specification of hierarchies.

So a guide with hierarchical dependencies at all sites should simply do this? See the line marked "# Here ->".

import pyro.distributions as dist
from torch.distributions import biject_to
from pyro.infer.autoguide import AutoRegressiveMessenger
from pyro.infer.autoguide.utils import helpful_support_errors

class HierarchicalGuideMessenger(AutoRegressiveMessenger):
    def get_posterior(self, name, prior, upstream_values):
        # Use a distribution at every site whose value depends on upstream_values.
        with helpful_support_errors({"name": name, "fn": prior}):
            transform = biject_to(prior.support)
        loc, scale = self._get_params(name, prior)
        affine = dist.transforms.AffineTransform(
            loc + transform.inv(prior.mean),  # Here -> shift by the prior mean in unconstrained space
            scale,
            event_dim=transform.domain.event_dim,
            cache_size=1,
        )
        posterior = dist.TransformedDistribution(
            prior, [transform.inv.with_cache(), affine, transform.with_cache()]
        )
        return posterior

Or should it be something more complex, like below, where users provide a dictionary specifying which parent sites each site has and how to combine their values:

class HierarchicalGuideMessenger(AutoRegressiveMessenger):
    def __init__(self, *args, hierarchical_sites=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.hierarchical_sites = hierarchical_sites or {}

    def get_posterior(self, name, prior, upstream_values):
        if name in self.hierarchical_sites:
            # Use a custom distribution at this site whose value depends on upstream_values.
            with helpful_support_errors({"name": name, "fn": prior}):
                transform = biject_to(prior.support)
            # Get values of parent sites.
            parent_names = self.hierarchical_sites[name]["parent_nodes"]
            parent_upstream_values = {k: upstream_values[k] for k in parent_names}
            hierarchical_loc = self.hierarchical_sites[name]["fn"](**parent_upstream_values)
            hierarchical_loc_untransformed = transform.inv(hierarchical_loc)
            loc, scale = self._get_params(name, prior)
            affine = dist.transforms.AffineTransform(
                loc + hierarchical_loc_untransformed,
                scale,
                event_dim=transform.domain.event_dim,
                cache_size=1,
            )
            posterior = dist.TransformedDistribution(
                prior, [transform.inv.with_cache(), affine, transform.with_cache()]
            )
            return posterior
        # Fall back to autoregressive.
        return super().get_posterior(name, prior, upstream_values)

where hierarchical_sites specifies the parent sites and a function combining their values:

hierarchical_sites = {"x": {"parent_nodes": ["y", "z"], "fn": lambda y, z: y @ z}}

@fritzo (Member Author) commented Oct 29, 2021

So GuideMessenger acts like the AutoNormal guide?
Can you please give an example of how to use GuideMessenger(upstream_values=...)?

I've just added an AutoNormalMessenger guide; that's probably easier for discussion. In the AutoNormalMessenger docstring I've added an example of how to use upstream_values.
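For readers without the docstring at hand, here is a hedged sketch of the kind of dependency such an example can encode. The site names "a" and "b" are hypothetical, site "b" is assumed real-valued, and the get_posterior(name, prior, upstream_values) signature is the one in use at this point in the thread (it is simplified to get_posterior(name, prior) further below):

import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormalMessenger

class ShiftedNormalMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior, upstream_values):
        if name == "b":
            # Shift the learned location of "b" by the posterior sample
            # already drawn for its upstream site "a".
            loc, scale = self._get_params(name, prior)
            loc = loc + upstream_values["a"]
            return dist.Normal(loc, scale).to_event(len(prior.event_shape))
        return super().get_posterior(name, prior, upstream_values)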

@fritzo (Member Author) commented Oct 29, 2021

So a guide with hierarchical dependencies at all sites should simply do this? See the line marked "# Here ->".

I think AutoRegressiveMessenger is already hierarchical; there's no need to extract a mean.

BTW I think your idea is similar to ASVI, which samples from a posterior with the same dependency structure as the prior. Note that in general the posterior may have a more complex dependency structure, as described in Webb et al. (2017).

@vitkl (Contributor) commented Oct 29, 2021

Very interesting, thanks for adding AutoNormalMessenger and sharing the papers. Is Webb et al. (2017) the motivation for how you do things in AutoStructured?

@yozhikoff @la-sekretar @bv2 are probably interested in this too

Inline review comment on pyro/infer/effect_elbo.py:

    ) -> Union[TorchDistribution, torch.Tensor]:
        with helpful_support_errors({"name": name, "fn": prior}):
            transform = biject_to(prior.support)
        loc, scale = self._get_params(name, prior)
@vitkl (Contributor) commented Oct 29, 2021:

The guide will become fully hierarchical if you do this, but it is not fully hierarchical by default, right?

loc, scale = self._get_params(name, prior)
loc = loc + prior.loc

Ideally one could add some kind of check for whether this site has upstream dependency sites.

You also mention that it could be useful to encode a more complex dependency:

loc, scale, weight = self._get_params(name, prior)
loc = loc + prior.loc * weight

@fritzo (Member Author) replied:

Correct, the intention of this simple guide is to be mean field.

Do you want to try contributing an AutoHierarchicalNormalMessenger guide as a follow-up to this PR? I tried to do something similar with AutoRegressiveMessenger below by sampling from the prior and then shifting in unconstrained space. I was unsure how to implement a general AutoHierarchicalNormalMessenger because not all prior distributions have a .mean method, and even then it is the mean in unconstrained space that we care about. E.g. how do we deal with Gamma or LogNormal or Beta or Dirichlet?

@vitkl (Contributor) replied:

I understand your point about the distributions that don't have a mean. What are those distributions, by the way?

I am thinking about this solution:

loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.loc) * weight

Does it make sense for all distributions that have a mean?

@vitkl (Contributor) replied:

Yes, I will do the AutoHierarchicalNormalMessenger PR. Should I wait until this PR is merged?

@fritzo (Member Author) replied:

What [distributions do not have a .mean method] by the way?

  • Heavy tailed distributions may not have a mean, e.g. Cauchy and Stable have infinite variance and no defined mean
  • Non-Euclidean distributions such as VonMises3D and ProjectedNormal have no defined mean.
  • Some complex distributions have no computable mean, e.g. TransformedDistribution(Normal(...), MyNormalizingFlow).

Does prior.loc make sense for all distributions that have a mean?

First, I would opt for prior.mean rather than prior.loc, since e.g. LogNormal(...).loc isn't a mean; rather, it is the mean of the pre-transformed normal. Second, note that the transform of the constrained mean is not the same as the unconstrained mean or unconstrained median: e.g. for LogNormal, mean = exp(loc + scale**2 / 2) whereas median = exp(loc).

I think your .mean idea is good enough in most cases, and for cases where it fails, users can subclass and define their own custom .get_posterior() methods.
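A quick numeric check of the LogNormal point above (plain PyTorch, nothing PR-specific):

import torch
import torch.distributions as td

loc, scale = torch.tensor(1.0), torch.tensor(0.5)
d = td.LogNormal(loc, scale)
print(d.mean)             # tensor(3.0802) = exp(loc + scale**2 / 2)
print(torch.exp(loc))     # tensor(2.7183) = median, not the mean
# Transforming the constrained mean back does not recover loc:
print(torch.log(d.mean))  # tensor(1.1250) != loc = 1.0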

@fritzo marked this pull request as ready for review October 29, 2021 19:15
@fritzo marked this pull request as draft October 29, 2021 19:16
@fritzo (Member Author) commented Oct 29, 2021

Is Webb et al. (2017) the motivation for how you do things in AutoStructured?

Yes, Webb et al. (2017) was the idea behind AutoStructured with automatic dependency detection via pyro.infer.inspect.get_dependencies(); however, @eb8680 convinced me AutoGaussian is a more natural solution, so I've been putting more work into that.

Comment on lines +142 to +148:

    if self.num_particles == 1:
        return fn
    return pyro.plate(
        "num_particles_vectorized",
        self.num_particles,
        dim=-self.max_plate_nesting,
    )(fn)
@fritzo (Member Author) replied:

This ensures serializability.

@vitkl (Contributor) commented Oct 31, 2021

How do these guides know about the plates? AutoNormal has a create_plates argument that tells it which plates exist. It might be good to mention this in the docs.

@fritzo (Member Author) commented Oct 31, 2021

How do these guides know about the plates?

These guides directly use the pyro.plate statements in the model, so no create_plates logic is needed.

@vitkl I'm unsure how subsampling should behave in these models; what do you think? Should we always assume that subsampled plates are amortized (share a single parameter value)? Should we additionally provide an amortized_plates kwarg to AutoMessenger.__init__()? I'd guess yes to both, but will that work for your use cases?

EDIT: I've just pushed a fix to support subsampling, plus an amortized_plates kwarg with tests.
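For concreteness, a sketch of the subsampled setting under discussion. The model, plate name, and sizes are hypothetical, and this assumes the merged interface in which the guide wraps the model (unlike the earlier draft noted above):

import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormalMessenger

def model(data):
    with pyro.plate("data", len(data), subsample_size=10) as idx:
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))  # one local latent per datum
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data[idx])

# Share a single guide parameter across the subsampled plate:
guide = AutoNormalMessenger(model, amortized_plates=("data",))

# Default (amortized_plates=()): behaves like AutoNormal, initializing the
# full per-datum parameter at the first invocation and indexing a minibatch
# subset at each step:
# guide = AutoNormalMessenger(model)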

@fritzo requested review from martinjankowiak and removed request for eb8680 November 1, 2021 15:11
@fritzo changed the title from "Implement Effect_ELBO, AutoNormalMessenger, AutoRegressiveMessenger" to "Implement GuideMessenger, AutoNormalMessenger, AutoRegressiveMessenger" Nov 1, 2021
@fritzo (Member Author) commented Nov 1, 2021

@martinjankowiak thanks for the helpful review, the new design is much simpler!

@vitkl you'll need to update:

- def get_posterior(self, name, prior, upstream_values):
+ def get_posterior(self, name, prior):

@martinjankowiak previously approved these changes Nov 1, 2021
@fritzo (Member Author) commented Nov 1, 2021

@martinjankowiak looks like tests now pass

@martinjankowiak merged commit 1dc2d73 into dev Nov 1, 2021
@vitkl (Contributor) commented Nov 1, 2021

@fritzo I think that, in the simplest case, it would be good to support subsampling as it is currently done in AutoNormal (define all parameters on initialisation, use a subset of them according to plate indices). I assume this behaviour was already supported, right?
I see amortisation as a second optional layer.

I don't think I understand what this statement means:

A tuple of names of plates over which guide parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates.

The code seems to suggest that, for a parameter w_cf with both a cell index c and a cell-type index f, where subsampling is done across cells c, only w_f parameters are learned (see test_subsample_model_amortized). Is this correct? If yes, I don't understand why this is useful: you are essentially converting a local parameter to a global one, whereas users of the model are interested in values specific to each cell c (e.g. cell abundance in the cell2location model).

@fritzo (Member Author) commented Nov 2, 2021

@vitkl feel free to clarify the docstrings in your follow-up PR. Indeed your language of "global" and "local" seems clearer.

Is this correct? If yes, I don't understand why this is useful.

Yes, your interpretation of amortized_plates is correct. Note that if you specify amortized_plates=() (the default), the behavior will be the same as in AutoNormal, where the full parameter is initialized at the first invocation and subsets of it are extracted at each minibatch. Do you use that version in cell2location?

Can you explain what kind of amortization and minibatching strategies you would find useful, in cell2location or elsewhere?

@vitkl (Contributor) commented Nov 2, 2021

Thanks for explaining!

The setting where the full parameter is initialized at the first invocation and subsets of it are extracted at each minibatch can be used in cell2location (although we find that it leads to reduced accuracy). It is also used in MOFA (https://biofam.github.io/MOFA2/) and in a few models related to cell2location (Stereoscope and DestVI in scvi-tools), so this approach is used by a few methods, although these particular models are implemented using pyro. Good to know that this setting works with the messenger guides.

The amortization strategy I am thinking about will hopefully be clear when I add the amortised class in a PR.
