toddpocuca/bayinx

Bayinx: Bayesian Inference with JAX

Bayinx is an embedded probabilistic programming language in Python, powered by JAX. It is heavily inspired by Stan and aims for feature parity with it, but it extends the types of objects you can work with and focuses on normalizing-flow variational inference for drawing posterior samples.
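
To make the idea concrete, here is a minimal sketch of normalizing-flow variational inference written directly in JAX. It does not use Bayinx's API. The target density, the single elementwise affine flow (the simplest possible normalizing flow), and every function name below are illustrative assumptions. The ELBO is estimated with the reparameterization trick, and the change-of-variables formula gives log q(z) = log N(eps; 0, I) - log|det J|.

```python
import jax
import jax.numpy as jnp

LOG_2PI = jnp.log(2.0 * jnp.pi)
LEARNING_RATE = 0.1


def log_target(x):
    # Assumed unnormalized target: independent N(3, 2^2) in each dimension.
    return -0.5 * jnp.sum(((x - 3.0) / 2.0) ** 2, axis=-1)


def affine_flow(params, eps):
    """Push base samples through z = mu + exp(log_s) * eps; return z and log|det J|."""
    mu, log_s = params
    z = mu + jnp.exp(log_s) * eps
    return z, jnp.sum(log_s)


def elbo(params, key, n_samples=256):
    """Monte Carlo ELBO estimate using the reparameterization trick."""
    eps = jax.random.normal(key, (n_samples, params[0].shape[0]))
    z, log_det = affine_flow(params, eps)
    # Change of variables: log q(z) = log N(eps; 0, I) - log|det J|.
    log_base = -0.5 * jnp.sum(eps**2 + LOG_2PI, axis=-1)
    log_q = log_base - log_det
    return jnp.mean(log_target(z) - log_q)


@jax.jit
def step(params, key):
    # Plain gradient *ascent* on the ELBO.
    grads = jax.grad(elbo)(params, key)
    return jax.tree_util.tree_map(lambda p, g: p + LEARNING_RATE * g, params, grads)


def fit(dim=2, n_steps=1000, seed=0):
    params = (jnp.zeros(dim), jnp.zeros(dim))  # (mu, log_s)
    key = jax.random.PRNGKey(seed)
    for _ in range(n_steps):
        key, subkey = jax.random.split(key)
        params = step(params, subkey)
    return params
```

Calling `fit()` should move `mu` toward 3 and `exp(log_s)` toward 2, the mean and standard deviation of the assumed target. Practical flows stack more expressive invertible layers; the affine layer is used here only because its Jacobian is trivial.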

Installation

Bayinx requires Python 3.12+, JAX, and a few additional libraries from the JAX ecosystem. The easiest way to get started is to install it from PyPI using your favourite Python package manager. For example, with uv:

# Ensure you are in your project environment
uv add bayinx

This installs the bare-bones version of Bayinx. If you need additional functionality such as GPU support, there are a couple of optional dependency groups:

# Ensure you are in your project environment
uv add 'bayinx[cuda]' # Installs Bayinx with CUDA support

Documentation

Documentation is available at: https://toddpocuca.github.io/bayinx.

Roadmap

  • Allow shape definitions to include expressions (e.g., shape = 'n_obs + 1' will evaluate to the correct specification).
  • Find a clean way to track the ELBO trajectory so that early stopping can be implemented (the tolerance currently has no effect).
  • Have nodes carry bounds on their support (i.e., node.obj ∈ [node._lb, node._ub]), which are used to check whether inputs to distributions are valid (e.g., a node passed as the scale of a normal distribution must have node._lb >= 0).
  • Refactor the normalizing flow (NF) implementation to support both forward and reverse flows (with the reverse direction defaulting to a differentiable root solver), enabling the sticking-the-landing (STL) gradient estimator.
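
The shape-expression item could work roughly as follows. This is a minimal sketch, assuming shapes are given as integers or strings and that named dimensions live in a dictionary; `eval_shape_expr` and `eval_shape` are hypothetical helpers, not part of Bayinx. Parsing with Python's `ast` module instead of calling `eval` keeps arbitrary code from running.

```python
import ast
import operator

# Whitelisted binary operators for shape arithmetic.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def eval_shape_expr(expr: str, dims: dict[str, int]) -> int:
    """Safely evaluate a shape expression such as 'n_obs + 1'.

    Only integer literals, names found in `dims`, and + - * // are allowed.
    """

    def _eval(node: ast.AST) -> int:
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.Name):
            if node.id not in dims:
                raise KeyError(f"unknown dimension {node.id!r}")
            return dims[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
        raise ValueError(f"unsupported syntax in shape expression: {expr!r}")

    return _eval(ast.parse(expr, mode="eval"))


def eval_shape(spec, dims: dict[str, int]) -> tuple[int, ...]:
    """Resolve a shape given as an int, a string expression, or a tuple of these."""
    if isinstance(spec, (int, str)):
        spec = (spec,)
    return tuple(s if isinstance(s, int) else eval_shape_expr(s, dims) for s in spec)
```

For example, `eval_shape(("n_obs + 1", 3), {"n_obs": 10})` resolves to `(11, 3)`, while anything beyond integer arithmetic on known names (function calls, attribute access) is rejected.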
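
One possible way to use the ELBO trajectory for early stopping, sketched under the assumption that the optimizer reports one ELBO estimate per step. Because those estimates are noisy Monte Carlo averages, the sketch compares the means of two consecutive windows and stops when their relative change drops below the tolerance. `ElboTracker`, `tol`, and `window` are hypothetical names, not Bayinx's API.

```python
class ElboTracker:
    """Record ELBO values and signal when optimization has plateaued.

    Training should stop when the relative change between the means of the
    two most recent windows of `window` ELBO values falls below `tol`.
    """

    def __init__(self, tol: float = 1e-4, window: int = 50):
        self.tol = tol
        self.window = window
        self.history: list[float] = []

    def update(self, elbo: float) -> bool:
        """Append a new ELBO value; return True if training should stop."""
        self.history.append(float(elbo))
        if len(self.history) < 2 * self.window:
            return False  # not enough values to compare two full windows yet
        prev = self.history[-2 * self.window : -self.window]
        curr = self.history[-self.window :]
        prev_mean = sum(prev) / self.window
        curr_mean = sum(curr) / self.window
        # Guard against division by zero when the ELBO is near 0.
        rel_change = abs(curr_mean - prev_mean) / max(abs(prev_mean), 1e-12)
        return rel_change < self.tol
```

Windowed means trade a little responsiveness for robustness: a single lucky or unlucky ELBO estimate cannot trigger a stop on its own.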
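
A sketch of how support bounds might be represented and checked, assuming a plain Python node object. `Node`, `node_exp`, `is_within`, and `validate_normal_args` are hypothetical names chosen for illustration. The point is that bounds can be propagated through monotone transforms (here `exp`) and checked before a distribution is constructed.

```python
import math
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical model node whose value is known to lie in [_lb, _ub]."""

    name: str
    _lb: float = -math.inf
    _ub: float = math.inf


def node_exp(node: Node) -> Node:
    # exp is monotone increasing, so the bounds map directly: [exp(lb), exp(ub)].
    return Node(f"exp({node.name})", math.exp(node._lb), math.exp(node._ub))


def is_within(node: Node, lb: float = -math.inf, ub: float = math.inf) -> bool:
    """True if the node's support [_lb, _ub] is contained in [lb, ub]."""
    return lb <= node._lb and node._ub <= ub


def validate_normal_args(loc: Node, scale: Node) -> None:
    # loc may be any real number, so it needs no check.
    # The scale of a normal distribution must be non-negative (scale._lb >= 0).
    if not is_within(scale, lb=0.0):
        raise ValueError(
            f"scale node {scale.name!r} has support [{scale._lb}, {scale._ub}]; "
            "expected a lower bound >= 0"
        )
```

With this scheme an unconstrained node such as `Node("x")` is rejected as a scale, whereas `node_exp(Node("log_sigma"))` is accepted because its derived support is `[0.0, inf]`.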
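
The STL estimator evaluates log q(z) with the variational parameters held fixed (here via `jax.lax.stop_gradient`), so gradients flow only through the sample path z. Evaluating the flow's density at a given z requires a reverse pass, which is why the refactor above is a prerequisite. Below is a minimal JAX sketch using an affine flow whose inverse is closed-form; for general flows the reverse pass may need a root solver. The target, the flow, and all function names are assumptions, not Bayinx's implementation.

```python
import jax
import jax.numpy as jnp

LOG_2PI = jnp.log(2.0 * jnp.pi)


def log_target(x):
    # Assumed unnormalized target: independent N(3, 2^2) in each dimension.
    return -0.5 * jnp.sum(((x - 3.0) / 2.0) ** 2, axis=-1)


def forward(params, eps):
    """Forward pass of an affine flow: z = mu + exp(log_s) * eps."""
    mu, log_s = params
    return mu + jnp.exp(log_s) * eps


def log_q(params, z):
    """Flow density at z, computed via the reverse pass eps = (z - mu) / exp(log_s)."""
    mu, log_s = params
    eps = (z - mu) * jnp.exp(-log_s)
    log_base = -0.5 * jnp.sum(eps**2 + LOG_2PI, axis=-1)
    return log_base - jnp.sum(log_s)  # subtract log|det J| of the forward map


def elbo_standard(params, eps):
    # Standard reparameterized estimator: gradients reach params through z
    # and through the explicit params argument of log_q.
    z = forward(params, eps)
    return jnp.mean(log_target(z) - log_q(params, z))


def elbo_stl(params, eps):
    # Sticking-the-landing: detach params inside log_q so only the
    # pathwise dependence through z contributes to the gradient.
    z = forward(params, eps)
    frozen = jax.lax.stop_gradient(params)
    return jnp.mean(log_target(z) - log_q(frozen, z))
```

When q matches the target exactly (here `mu = 3`, `exp(log_s) = 2`), every per-sample STL gradient is zero, while the standard reparameterized estimator still returns noisy nonzero gradients. This zero-variance-at-the-optimum property is what motivates STL.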
