Score-based density estimators for SBI #1015

Open
wants to merge 98 commits into
base: main

Conversation

jsvetter
Contributor

@jsvetter jsvetter commented Mar 19, 2024

What does this implement/fix? Explain your changes

This PR implements score-based methods for SBI (and related methods like flow matching). The first goal is to have a running version of score-based NPE. Later on, NLE, sequential methods, and methods that can deal with multiple observations can be tackled.

This includes new base classes for the required vector field estimators and posteriors.

Does this close any currently open issues?

Fixes #962

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)

@jsvetter jsvetter changed the title from "added files for vector field-based estimators and posteriors" to "Score-based density estimators for SBI" Mar 19, 2024
@jsvetter jsvetter marked this pull request as draft March 19, 2024 08:46
@manuelgloeckler
Contributor

Related to #963 and #966.

@janfb
Contributor

janfb commented Mar 22, 2024

@pfuhr do you agree to your contributions being re-licensed under Apache 2.0 in the future?

@pfuhr

pfuhr commented May 17, 2024

Yes, I agree that my contributions are re-licensed under Apache 2.0 in the future.

@manuelgloeckler manuelgloeckler marked this pull request as ready for review July 3, 2024 08:57
@manuelgloeckler
Contributor

manuelgloeckler commented Jul 8, 2024

Alright, I integrated everything with the current main and fixed some problematic side effects/bugs we introduced. The following stuff still has to be implemented/improved:

  1. NSPE interface (MINOR)
  • Check docstrings (as this is user-exposed; there might be some copy-pasta)
  • build_posterior: we have two possibilities, i.e. sample_with="sde" will return a ScorePosterior and sample_with="ode" will return a DirectPosterior (not implemented; may be very similar to the flow matching implementation in feat: flow matching methods #1049, but in the end it should really just build a Zuko CNF). Some stuff has been commented out and needs to be checked.
  2. ScorePosterior (MEDIUM)
  • sample: should interface the score-based samplers (dependent on 3)). But we need to properly shape conditions … . Not sure if we need to use "accept_reject" for prior bounds; in the future this should be done within the diffusion process anyway.
  • log_prob: Not sure if needed; if one wants a log_prob, one should build a direct posterior. But maybe still good for map.
  • map: Should use the score_net directly as the gradient (no backprop on potentials), i.e. sample a few candidates, then run gradient ascent where the gradient is just the output of the score_net at t=T_min (optional; create an issue to do this in the future).
  3. score_based_potential (MINOR)
  • This is more of a conceptual/abstraction issue. It is not a potential but the gradient of a potential, and it behaves differently from the other potentials (might need another class?). However, like the other potentials, it should mediate IID stuff. (These will also change in @gmoss13's feat: batched sampling for MCMC #1176.) Any suggestions @janfb?
  4. Score-based sampler (MAJOR)
  • This needs to be rewritten as a general sampling interface (we only need to implement one sampler for now, but it must be extendable)
    • Should have a "predict/step" which steps backward in time (with Euler/DDIM/…)
    • Should have an optional "correct" which stays at the same time (Langevin/Gibbs/…)
  • Needs to be vectorized, i.e. should ideally support batched x (but this also loosely depends on feat: batched sampling for MCMC #1176)
  5. Score nets/estimator (MEDIUM): It looks quite good; I might make a few minor changes to improve the architecture.
  • The input z_score should be time-dependent (as we add noise to the input, its std/mean will change)
  • output_scores should be preconditioned by an analytic approximation.
  • Times should be preconditioned by how fast the mean/std changes (if they do not change a lot, the score should also not change much).
  • Add option to add covariate to loss to reduce variance.
  6. Tests: We need to extend and check tests. We currently have:
  • linearGaussian_nspe_test.py: Currently only tests sample.
  • score_estimator_tests.py: Tests shaping conventions.
  • score_tests.py: Does not test anything yet, but implements some analytic scores. We need to test the samplers, which we can also do with the analytic scores!
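
To make the "predict"/"correct" split and the idea of testing samplers against analytic scores more concrete, here is a minimal NumPy sketch. It is not the sbi implementation: the function names (`predict`, `correct`, `sample`, `analytic_score`), the constant-rate VP SDE, and the 1D Gaussian target are all assumptions chosen for illustration. The predictor is an Euler-Maruyama step of the reverse SDE, the optional corrector is one Langevin step at fixed time, and everything is batched over samples.

```python
import numpy as np

BETA = 8.0             # constant noise rate of a VP SDE (assumption for this sketch)
M0, S0 = 2.0, 0.5      # toy target N(M0, S0^2), standing in for a posterior
T_MAX, T_MIN = 1.0, 1e-3


def marginal_moments(t):
    """Mean and variance of x_t when x_0 ~ N(M0, S0^2) under the VP SDE."""
    decay = np.exp(-BETA * t)
    return M0 * np.sqrt(decay), S0**2 * decay + (1.0 - decay)


def analytic_score(x, t):
    """Exact score of the Gaussian marginal at time t (replaces a trained net)."""
    mean, var = marginal_moments(t)
    return -(x - mean) / var


def predict(x, t, dt, score_fn, rng):
    """Euler-Maruyama step of the reverse SDE, moving from time t to t - dt."""
    drift = 0.5 * BETA * x + BETA * score_fn(x, t)
    return x + drift * dt + np.sqrt(BETA * dt) * rng.standard_normal(x.shape)


def correct(x, t, score_fn, rng, step_size=1e-3):
    """One Langevin step that stays at the same time t."""
    noise = rng.standard_normal(x.shape)
    return x + step_size * score_fn(x, t) + np.sqrt(2.0 * step_size) * noise


def sample(score_fn, num_samples, num_steps=500, use_corrector=True, seed=0):
    """Run the predictor (and optionally the corrector) from T_MAX down to T_MIN."""
    rng = np.random.default_rng(seed)
    ts = np.linspace(T_MAX, T_MIN, num_steps + 1)
    x = rng.standard_normal(num_samples)  # reference distribution at T_MAX
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = predict(x, t, t - t_next, score_fn, rng)
        if use_corrector:
            x = correct(x, t_next, score_fn, rng)
    return x


samples = sample(analytic_score, num_samples=5000)
print(samples.mean(), samples.std())  # should land near M0 and S0
```

Because the score is exact here, any mismatch between the sample moments and the target comes from the sampler itself (discretization, starting distribution), which is what a sampler test would want to isolate.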

If you have time/want to do some of the points @rdgao @jsvetter, let me know. We can also chat about what should be done in more detail.
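
For the map item under ScorePosterior above, the idea of using the score output directly as the ascent direction could look roughly like the sketch below. The names `toy_score` and `map_by_score_ascent` are hypothetical, the toy score is an analytic Gaussian standing in for a trained score_net evaluated at t=T_min, and picking the candidate with the smallest final score norm is just one possible heuristic (it finds a stationary point and does not rank modes, which would need a log_prob or another criterion).

```python
import numpy as np

MU = np.array([1.0, -2.0])   # toy posterior mean (hypothetical)
VAR = np.array([0.5, 2.0])   # toy diagonal posterior variances (hypothetical)


def toy_score(theta, t):
    """Score of N(MU, diag(VAR)); t is accepted but unused in this stand-in."""
    return -(theta - MU) / VAR


def map_by_score_ascent(score_fn, candidates, t_min=1e-3, step_size=0.05, num_steps=500):
    """Gradient ascent on a batch of candidates, using the score as the gradient."""
    theta = np.array(candidates, dtype=float)  # copy, shape (num_candidates, dim)
    for _ in range(num_steps):
        # No backprop through a potential: the score output is the update direction.
        theta = theta + step_size * score_fn(theta, t_min)
    # Heuristic: return the candidate closest to a stationary point.
    score_norm = np.linalg.norm(score_fn(theta, t_min), axis=-1)
    return theta[np.argmin(score_norm)]


rng = np.random.default_rng(0)
candidates = 3.0 * rng.standard_normal((10, 2))  # e.g. a few posterior samples
theta_map = map_by_score_ascent(toy_score, candidates)
print(theta_map)
```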


codecov bot commented Jul 8, 2024

Codecov Report

Attention: Patch coverage is 59.12548% with 215 lines in your changes missing coverage. Please review.

Project coverage is 74.30%. Comparing base (6fd2a6b) to head (be093d3).

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1015       +/-   ##
===========================================
- Coverage   84.54%   74.30%   -10.25%     
===========================================
  Files          95      102        +7     
  Lines        7576     8036      +460     
===========================================
- Hits         6405     5971      -434     
- Misses       1171     2065      +894     
Flag Coverage Δ
unittests 74.30% <59.12%> (-10.25%) ⬇️


Files Coverage Δ
sbi/analysis/tensorboard_output.py 86.41% <100.00%> (ø)
sbi/inference/__init__.py 100.00% <100.00%> (ø)
sbi/inference/nspe/__init__.py 100.00% <100.00%> (ø)
sbi/inference/posteriors/direct_posterior.py 98.79% <100.00%> (ø)
...inference/potentials/likelihood_based_potential.py 100.00% <100.00%> (ø)
.../inference/potentials/posterior_based_potential.py 96.96% <100.00%> (ø)
sbi/inference/snle/snle_base.py 93.81% <100.00%> (ø)
sbi/inference/snpe/snpe_a.py 63.91% <100.00%> (-27.40%) ⬇️
sbi/inference/snpe/snpe_base.py 90.96% <100.00%> (ø)
sbi/inference/snpe/snpe_c.py 73.15% <100.00%> (-21.48%) ⬇️
... and 23 more

... and 22 files with indirect coverage changes

Development

Successfully merging this pull request may close these issues.

score-based diffusion models as density estimators
5 participants