Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an AutoStructured guide and StructuredReparam #2812

Merged
merged 17 commits into from
Apr 25, 2021
Merged

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Apr 21, 2021

Addresses #2813

This adds a flexible AutoStructured guide that allows a variety of distributions modeling each latent site (Delta, Normal, or MultivariateNormal), together with a mechanism to declare (link-)linear dependencies between latent variables. As discussed with @fehiepsi this aims to (1) generalize guides with arrowhead covariance structure while (2) learning parameters that can be cheaply used to precondition NUTS via a reparameterizer StructuredReparam.

This also adds a simple StructuredReparam that uses a trained AutoStructured guide to precondition a model for use in HMC. This new (guide,reparam) pair can be seen as a structured version of the monolithic (AutoContinuous,NeuTraReparam) pair in the same sense that AutoNormal is a structured version of the monolithic AutoDiagonalNormal guide.

My main motivation is to use this for high-dimensional models (e.g. 300000 latent variables) with a structured precision matrix, and then use that structured precision matrix as a preconditioner for NUTS.

(Note this does not implement Automatic structured variational inference, a variational family whose stricture is severely limited to dependencies in the model. Nor does this first PR implement automatic suggestion of the guide structure as in Faithful inversion of generative models for effective amortized inference.)

Tested

  • unit tests for AutoStructured
  • unit tests for StructuredReparam
  • test on a real world problem (private repo from which this was abstracted)

@fritzo fritzo changed the title Add an AutoStructured guide Add an AutoStructured guide and StructuredReparam Apr 22, 2021
@fritzo fritzo requested a review from fehiepsi April 24, 2021 18:43
@fritzo fritzo marked this pull request as ready for review April 24, 2021 18:43
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks clean to me! I can confirm that this is equivalent to if we learn arrowhead matrix directly:

A = [[L @ L.t + w @ D @ w.T, w @ D], [D @ w.T, D]]

where L is scale_tril of x_aux, D is the variance of y_aux, w is dep.weight.

pyro/infer/autoguide/guides.py Show resolved Hide resolved
pyro/infer/autoguide/guides.py Outdated Show resolved Hide resolved
scale_tril = scale[..., None] * scale_tril
aux_value = pyro.sample(
name + "_aux",
dist.MultivariateNormal(zero, scale_tril=scale_tril),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we factor this out to scale_tril @ Normal(0, 1), I guess HMC will be a bit happier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point, I guess that is equivalent to reparametrizing. I've also changed Normal(0,scale) to Normal(0,1) * scale in the "normal" case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you will need to add logdet of those affine transforms. How about using dist.TransformedDistribution(dist.Normal(...), LowerCholeskyAffine(...)) so that we can use TransformReparam in the reparam?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've added the logdet terms by hand here since it's simpler. Does it look right now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it looks correct to me.

Copy link
Member Author

@fritzo fritzo Apr 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm seeing very different results with the two versions, and this change seems to have broken my SVI inference. I've been staring at these two versions and I can't seem to see the difference:

# Version 1. This works.
aux_value = pyro.sample(..., Normal(zero, scale).to_event(1))

# Version 2. This is in pyro dev, but no longer works.
aux_value = pyro.sample(..., Normal(zero, 1).to_event(1))
aux_value = aux_value * scale
log_density = log_density - scale.log().sum(-1)

Any ideas @fehiepsi?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe two versions are equivalent... Not sure what's going on. Let me play with some tests to see if elbo is the same for the two versions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I'll do the same, at least to create a unit test I can run locally (not on some huge model on a GPU cloud machine)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I find the issue. Here log_density is calculated as the sum over all dimensions of the site. However, the ldj term, which is used to calculate the logdet of unconstrained->constrained values, maintains the batch dimension. So sum of them will give wrong result if this site is under some plate. I guess we should use pyro.factor for those log_density terms. What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi thanks, yes I now see the error. I'll think about this and submit a fix ASAP.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! I'll try to find some time next week to port this to NumPyro (hopefully it will be straightforward).

@fritzo
Copy link
Member Author

fritzo commented Apr 25, 2021

Thanks for your careful review @fehiepsi! I'll try to add automated dependency structure with @eb8680 next week.

@fehiepsi fehiepsi merged commit 3a776f9 into dev Apr 25, 2021
@fritzo fritzo mentioned this pull request Apr 29, 2021
4 tasks
@fritzo fritzo deleted the auto-structured branch September 27, 2021 14:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants