Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change coordinatization of AutoMultivariateNormal #2963

Merged
merged 6 commits into from
Nov 11, 2021
Merged

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Nov 9, 2021

Addresses #2924

This adds a row-wise .scale parameter to AutoLowRankMultivariateNormal. The resulting overparametrized triple (loc,scale,scale_tril) has been observed to speed up learning in AutoLowRankMultivariateNormal and in AutoStructured.

This also fixes .get_base_dist() and adds tests of use in NeuTraReparam, and moves two tests from tests/contrib/autoguide to tests/infer/autoguide.

Tested

  • refactoring covered by existing tests
  • .get_base_dist() covered by a new regression test

@martinjankowiak
Copy link
Collaborator

is there any reason to think that this works better than the non-overparameterized version i implemented in pyro-ppl/numpyro#1146? i would be mildly surprised if that were the case

@fritzo fritzo added the WIP label Nov 9, 2021
@fritzo
Copy link
Member Author

fritzo commented Nov 9, 2021

is there any reason to think that this works better than the non-overparameterized version

I would expect this PR to work about the same as the non-overreparametrized version, but with only a fraction of the coding effort. You could consider this PR a stepping stone, in case you want to implement the fancier version in a subsequent PR.

)

def get_base_dist(self):
return dist.Normal(
torch.zeros_like(self.loc), torch.zeros_like(self.loc)
torch.zeros_like(self.loc), torch.ones_like(self.loc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, what a bug!? It is surprised to me that we didn't catch this earlier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, apparently nobody uses AutoMultivariateNormal with NeuTraReparam. My motivation for this PR is to write a tutorial about autoguides, so hopefully they'll get more exposure

@fritzo fritzo added this to the 1.8 release milestone Nov 10, 2021
@fritzo fritzo mentioned this pull request Nov 10, 2021
6 tasks

def _loc_scale(self, *args, **kwargs):
return self.loc, self.scale_tril.diag()
return self.loc, self.scale * self.scale_tril.diag()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does scale need an unsqueeze here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope, scale is already the correct shape

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see this is only for marginals

pyro/infer/autoguide/guides.py Show resolved Hide resolved
@martinjankowiak martinjankowiak merged commit 3c6a591 into dev Nov 11, 2021
@fritzo
Copy link
Member Author

fritzo commented Nov 11, 2021

Thanks for reviewing @martinjankowiak!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants