Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: flow matching methods #1049

Open
wants to merge 96 commits into
base: main
Choose a base branch
from
Open

feat: flow matching methods #1049

wants to merge 96 commits into from

Conversation

turnmanh
Copy link
Contributor

@turnmanh turnmanh commented Mar 20, 2024

What does this implement/fix? Explain your changes

Implements basic Flow Matchting Posterior Estimation as described by Dax et al.

Does this close any currently open issues?

Fixes #963

Any relevant code examples, logs, error output, etc?

...

Any other comments?

...

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)

@turnmanh
Copy link
Contributor Author

fyi @rdgao

@janfb
Copy link
Contributor

janfb commented Jul 1, 2024

Notes from discussion with Faried:

New plan: We merge this PR with a minimal working version of flow matching posterior estimation:

  • the flow matching estimator class still inherits from ConditionalDensityEstimator and behaves just like discrete time normalizing flow density estimator. via zuko it provides a direct interface to sample and log_prob methods
  • it is not theta and x agnostic, i.e., x will always be the conditioning variable for now.
  • time is sampled uniformly during training
  • z-scoring for theta and x is optional and happens in the estimator class for now because it is not the same as for discrete time flows.
  • the neural network for flow matching is built in the factory from a flowmatching_nn method similar to posterior_nn / likelihood_nn
  • there is an option for passing an embedding net for x, default is Identity. the embedded x is then concatenated with theta and t and passed to the vector field estimator net.

Features for a future PR:

  • refactoring into a ConditionalVectorFieldEstimator(ConditionalEstimator) class framework together with the score matching methods
  • SBI related flow matching improvements from Dax, Wildberger et al. paper
    • conditioning on theta and time happens through Gated Linear Units
    • time prior can be power-law distribution and not just uniform
  • z-scoring as part of embedding network (for x) and part of flow (for theta)
  • refactoring to become agnostic w.r.t. conditioning variable -> flow matching could be used for likelihood estimation as well.

fariedabuzaid and others added 2 commits July 1, 2024 18:01
Co-authored-by: Jan <janfb@users.noreply.github.com>
@janfb janfb marked this pull request as ready for review July 8, 2024 13:48
janfb added 8 commits July 8, 2024 15:57
- move flow matching build fns to flow_matcher.py
- move FlowMatchingEstimator to zuko_flow.py
- introduce embedding nets with z-scoring
- use same signatures like in ZukoFlow.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Continuous normalizing flows via flow matching
6 participants