Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add two ensemble sampling methods #1692

Merged
merged 25 commits into from
Jan 26, 2024
Merged

Add two ensemble sampling methods #1692

merged 25 commits into from
Jan 26, 2024

Conversation

amifalk
Copy link
Contributor

@amifalk amifalk commented Nov 30, 2023

As described in #1691

Currently, I've only implemented a subset of emcee and zeus moves, but it should be trivial to extend in the future. I also don't have support for potential_fn.

Should there be separate tests for these modules or should I try to work them in to existing ones? The pattern isn't terribly clean with existing tests because AIES and ESS can only be run with multiple chains.

@martinjankowiak
Copy link
Collaborator

@amifalk hello what is the status of this PR? are you waiting for feedback from us? i think we sort of lost track of this over the holiday break.

@amifalk
Copy link
Contributor Author

amifalk commented Jan 13, 2024

@martinjankowiak Would you mind looking through it and making comments about what you feel is missing/needs to change? I believe all of the code lints with the exception of the batch_ravel_pytree function in ensemble_utils.py. I modified that from JAX so I was reluctant to change the formatting, but I'm happy to do what's needed to make it lint - or if you have ideas of a better solution.

At the core, these are gradient free methods that update the state of each chain by looking at the current state of the other chains, which I've implemented by storing an (n_chains, n_params) array (thus the need to use write a version of ravel_pytree that anticipates a batch dimension).

I'm also happy to reduce the scope of the PR if that will make it easier to review (the affine invariant ensemble sampler, AIES, is much less complex than the ensemble slice sampler, ESS).

@martinjankowiak
Copy link
Collaborator

@amifalk thanks i'll try to review this weekend

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so i'm not familiar with these algorithms in detail so it's a bit hard for me to follow all the details but the code is generally very clean and readable : )

i think the main thing that is missing is tests:

  • you can put a test_ensemble_util.py in tests/infer
  • you should add at least one test that does inference on a conjugate model where you can e.g. compute analytic posterior means/variances; for example you can add your methods to a few of the following tests: test_mcmc.py::test_unnormalized_normal_x64, test_logistic_regression_x64, test_beta_bernoulli_x64
  • you have a fair number of different ways to initialize your kernels so you probably want to add some simple smoke tests or the like to test_ensemble_mcmc.py or similar. e.g. make sure that various combinations of init args do not error out. make simple checks about expected shapes of outputs and the like. you can also use the with pytest.raises(ValueError, match="..expected message...") context manager to check that some invalid initializations and the like are being caught as expected.

numpyro/infer/ensemble.py Show resolved Hide resolved
numpyro/infer/ensemble.py Show resolved Hide resolved
numpyro/infer/ensemble.py Show resolved Hide resolved
numpyro/infer/ensemble.py Show resolved Hide resolved

super().__init__(model, potential_fn, randomize_split, init_strategy)

# XXX: this doesn't show because state_method='vectorized' shuts off diagnostics_str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi any workarounds?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this chain_method @amifalk? I don't think progress_bar works with the vectorized method there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is, sorry for the typo. Can we adjust this line to allow the diagnostics_str for these ensemble methods? If I remove the prng_key check, it displays correctly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That line checks whether vectorized method is used. If rng_key.ndim == 1, parallel is used. Otherwise, vectorized is used. Looking at your code, it seems that rng_key.ndim == 2 in both cases, which is strange to me. Could you double check the logic for both methods?

I think you can add an attribute to MCMCKernel to indicate whether it is an ensemble kernel and skip is_prng_key check at that line if the kernel is an ensemble one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Parallel sampling does not work with pmap since the chains need to talk to each other. To make that work, we would probably need to add an internal utility that does parallel sharding on the chains, kind of like what is requested in #1192. In the interest of not overcomplicating the API, I'll opt not to support it for now and make a note in the code, but if there was sufficient interest I think we could add a shard_chains argument to EnsembleSampler.

numpyro/infer/ensemble.py Outdated Show resolved Hide resolved
numpyro/infer/ensemble.py Show resolved Hide resolved
numpyro/infer/ensemble.py Outdated Show resolved Hide resolved
numpyro/infer/ensemble.py Show resolved Hide resolved
@amifalk
Copy link
Contributor Author

amifalk commented Jan 22, 2024

Thanks for the review @martinjankowiak! Documentation has been updated and tests have been added (a distribution test seems to be failing for reasons unrelated to this PR). If there are any more comments and/or if @fehiepsi has a better solution for batch_ravel_pytree, I'm happy to make corresponding fixes.

numpyro/infer/ensemble_util.py Outdated Show resolved Hide resolved
numpyro/infer/ensemble_util.py Outdated Show resolved Hide resolved
numpyro/infer/ensemble_util.py Show resolved Hide resolved

super().__init__(model, potential_fn, randomize_split, init_strategy)

# XXX: this doesn't show because state_method='vectorized' shuts off diagnostics_str
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this chain_method @amifalk? I don't think progress_bar works with the vectorized method there.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me! Could you add the classes to docs/source/mcmc file?

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great thanks!

can you check how much additional time your tests are adding?
i don't know how slow the additions to test_beta_bernoulli_x64 and test_logistic_regression_x64 might be. if they're adding too much time we may want to add something like pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI")

test/infer/test_mcmc.py Outdated Show resolved Hide resolved
@amifalk
Copy link
Contributor Author

amifalk commented Jan 23, 2024

test_beta_bernoulli_x64 and test_logistic_regression_x64 were quite slow (around 20-30 seconds on my machine for each sampler) so I opted to skip them. Everything should be building correctly now!

@martinjankowiak
Copy link
Collaborator

@amifalk thanks! it might be good to keep one of the tests if it only adds ~50 seconds in total.... what do you think @fehiepsi ?

@fehiepsi
Copy link
Member

@amifalk Could you help fix the failing test by adding a catch at this line?

except AttributeError, ValueError:

@fehiepsi fehiepsi merged commit 01089cf into pyro-ppl:master Jan 26, 2024
4 checks passed
@martinjankowiak
Copy link
Collaborator

thanks @amifalk ! btw what's your interest in these algorithms? do you have a non-differentiable log density? because unless the problem is very multi-modal or otherwise difficult HMC should work pretty well if a gradient is available. also have you tried comparing these algorithms to SA?

@amifalk
Copy link
Contributor Author

amifalk commented Jan 26, 2024

thanks @amifalk ! btw what's your interest in these algorithms? do you have a non-differentiable log density? because unless the problem is very multi-modal or otherwise difficult HMC should work pretty well if a gradient is available. also have you tried comparing these algorithms to SA?

@martinjankowiak We work with some likelihood-free models like LCA that we perform approximate inference on through simulation. We actually did try SA first, but it didn't seem robust to the noisy likelihood.

@amifalk amifalk deleted the ensemble branch February 16, 2024 15:36
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* ensemble sampling draft

* rewrite for loop as fori_loop

* added efficiency comment for ESS GaussianMove

* fix typo

* fixed ravel for mixed dtype

* add defaults

* add support for potential_fn

* AIES tests, warnings for AIES

* AIES input validation

* better docs, more input validation

* ESS passing test cases

* add tests for other files

* linting

* refactor ensemble_util

* make test result less close to margin in CI, swap deprecated function

* rename get_nondiagonal_indices,  fix batch_ravel_pytree

* print ensemble kernel diagnostics, smoke test parallel arg

* fix docstring build

* documentation

* skip slow CI tests, unnest test if statements

* fix doctest

* doc rewrite

* fix distribution test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants