Stein mixtures v2 #1590

Closed · wants to merge 35 commits

Conversation

OlaRonning (Member) commented May 19, 2023

While merging the original Stein mixture paper with ours, we found an inconsistency in the original article. We correct it here, and the corrected form can (probably) be connected to optimizing the ELBO (and hence a lower bound on the marginal likelihood). Feel free to PM me for details if you are interested.
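For reference, the connection presumably runs through the standard argument that any guide's ELBO lower-bounds the log marginal likelihood, here applied to the uniform mixture guide over the N particles (a sketch of the bound, not the corrected objective itself):

$$
q(z) = \frac{1}{N}\sum_{i=1}^{N} q_{\theta_i}(z),
\qquad
\log p(x) \ \ge\ \mathbb{E}_{q(z)}\big[\log p(x, z) - \log q(z)\big].
$$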

The current version uses a quick-and-dirty approach to separate the model_density from the guide_density. I'm sure this can be done more elegantly.
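A minimal sketch of what that separation might look like, scoring the same latents under the model and the guide with numpyro's log_density utility; `stein_loss_sketch` is a hypothetical name, not the PR's actual SteinLoss implementation:

```python
# Hypothetical sketch: score the same guide-drawn latents under the model
# and the guide separately, so the two densities can enter the objective
# independently. Not the PR's actual SteinLoss.
from numpyro import handlers
from numpyro.infer.util import log_density


def stein_loss_sketch(rng_key, model, guide, params, args, kwargs):
    # Draw latents from the guide, seeded for reproducibility.
    guide_tr = handlers.trace(handlers.seed(guide, rng_key)).get_trace(*args, **kwargs)
    latents = {
        name: site["value"]
        for name, site in guide_tr.items()
        if site["type"] == "sample" and not site["is_observed"]
    }
    # Substitute the latents and evaluate each log density on its own.
    model_lp, _ = log_density(handlers.substitute(model, latents), args, kwargs, params)
    guide_lp, _ = log_density(handlers.substitute(guide, latents), args, kwargs, params)
    return guide_lp - model_lp  # single-draw estimate of the negative ELBO
```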

TODO

  • Update documentation.
  • Update tests.
  • Refactor/rework the Stein loss.
  • Add a test for SteinLoss.
  • Refactor/rework the Predictive broadcasting solution for models with multiple parameters.

@OlaRonning OlaRonning added the WIP label May 19, 2023
@@ -234,35 +232,6 @@ def model(obs):
assert_array_approx_equal(init_value, np.full(expected_shape, 0.0))


def test_svgd_loss_and_grads():
OlaRonning (Member, Author) commented on the diff:
Stein reports the 2-norm of the Stein force instead of the ELBO loss; the system has converged when the norm is zero. Note that the norm may not reach exactly zero due to stochasticity in the gradient.
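For illustration, a convergence check on the reported norms might look like the following (a hypothetical helper, assuming the per-step losses returned by the run loop are these norms):

```python
# Hypothetical helper: treat the reported per-step 2-norms of the Stein
# force as a convergence signal, averaging a trailing window to smooth
# the stochastic-gradient noise noted above.
import jax.numpy as jnp


def stein_force_converged(norms, tol=1e-3, window=100):
    recent = jnp.asarray(norms[-window:])
    return bool(jnp.mean(recent) < tol)
```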

fehiepsi (Member) previously approved these changes Jun 2, 2023 and left a comment:

LGTM. Do you want me to look at any particular parts? @OlaRonning

@@ -1008,11 +1008,26 @@ def __call__(self, rng_key, *args, **kwargs):
elif self.batch_ndims == 1: # batch over parameters
batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]
rng_keys = random.split(rng_key, batch_size)
# TODO: better way to broadcast the model across particles?
fehiepsi (Member) commented on the diff:
This seems dangerous to me: we define the batch size as the first dimension of one latent variable. If we allow different first dimensions, it becomes ambiguous which one should be the batch size.
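To illustrate the concern with a toy pytree (names here are illustrative only):

```python
# Two latents with different leading dimensions make "batch size" ambiguous:
# the diffed code takes the first dimension of whichever leaf comes first
# after flattening, so the result silently depends on key order.
import jax.numpy as jnp
from jax.tree_util import tree_flatten

params = {"a": jnp.zeros((3, 2)), "b": jnp.zeros((5, 2))}
batch_size = jnp.shape(tree_flatten(params)[0][0])[0]  # 3, not 5
```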

OlaRonning (Member, Author):
Another solution is to sample the particle assignments first and then treat the predictive distribution as batch_ndims=0. I could put the (mixture component) assignment logic in a Predictive wrapper under Einstein. That way it doesn't interfere with the other inference methods.
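A rough sketch of that wrapper idea (hypothetical names; assumes particle parameters are stacked along a leading axis, as in the batch_ndims == 1 branch above):

```python
# Hypothetical wrapper: sample a mixture-component (particle) assignment
# per predictive draw, then defer to a plain Predictive with batch_ndims=0.
import jax
import jax.numpy as jnp
from jax import random
from numpyro.infer import Predictive


def mixture_predictive(rng_key, model, guide, particle_params, num_particles,
                       num_samples, *args, **kwargs):
    assign_key, pred_key = random.split(rng_key)
    # Uniformly assign each draw to one of the mixture components.
    assignments = random.randint(assign_key, (num_samples,), 0, num_particles)
    keys = random.split(pred_key, num_samples)

    def one_draw(key, i):
        # Index the i-th particle's parameters off the leading axis and
        # predict with an unbatched parameter set.
        params_i = jax.tree_util.tree_map(lambda p: p[i], particle_params)
        pred = Predictive(model, guide=guide, params=params_i,
                          num_samples=1, batch_ndims=0)
        return pred(key, *args, **kwargs)

    draws = [one_draw(k, int(i)) for k, i in zip(keys, assignments)]
    # Concatenate the length-1 sample axes into a single sample dimension.
    return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs), *draws)
```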

jakevdp and others added 11 commits June 2, 2023 08:31
* [WIP] jittable transforms

* add licence to new test file

* turn BijectorConstraint into pytree

* test flattening/unflattening of parametrized constraints

* cosmetic edits

* fix typo

* implement tree_flatten/unflatten for transforms

* attempt to avoid confusing black

* add (un)flattening meths for BijectorTransform

* fixup! implement tree_flatten/unflatten for transforms

* test vmapping over transforms/constraints

* Make constraints `__eq__` checks robust to arbitrary inputs

* make transforms equality check robust to arbitrary inputs

* test constraints and transforms equality checks
* Bump to version 0.12.0

* Fix docs of MCMC class which does not display sharding example

* add get_transforms and unconstrain_fn to docs
* Fix for jax 0.4.11

* Require jax, jaxlib version >= 0.4.7
* Update reparam.py

Refer pyro-ppl#1598

* Making suggested changes
review-notebook-app (bot): Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter notebooks.

OlaRonning (Member, Author):

It looks like the Files changed view is diffed against an old version of master.

@OlaRonning OlaRonning closed this Jun 6, 2023
@OlaRonning OlaRonning reopened this Jun 6, 2023
OlaRonning (Member, Author):

Can't be bothered to wrestle with git and rebase.

@OlaRonning OlaRonning closed this Jun 6, 2023
@OlaRonning OlaRonning mentioned this pull request Jun 6, 2023
@OlaRonning OlaRonning deleted the feature/sm branch June 6, 2023 11:07