Bayinx is an embedded probabilistic programming language in Python, powered by JAX. It is heavily inspired by Stan and aims for feature parity with it, but it extends the types of objects you can work with and focuses on normalizing-flow variational inference for sampling.
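For reference, normalizing-flow variational inference maximizes the evidence lower bound (ELBO). If the flow is written as an invertible, differentiable map $T_\phi$ applied to base noise $\varepsilon \sim q_0$, the change-of-variables formula gives the variational density, and the ELBO can be estimated by Monte Carlo over $\varepsilon$. This is the standard textbook form. The notation ($T_\phi$, $q_0$, $\phi$) is our own and is not taken from Bayinx's internals.

```math
z = T_\phi(\varepsilon), \quad \varepsilon \sim q_0, \qquad
\log q_\phi(z) = \log q_0(\varepsilon) - \log \left| \det \frac{\partial T_\phi(\varepsilon)}{\partial \varepsilon} \right|

\mathcal{L}(\phi)
= \mathbb{E}_{q_\phi(z)}\big[\log p(x, z) - \log q_\phi(z)\big]
= \mathbb{E}_{\varepsilon \sim q_0}\left[\log p\big(x, T_\phi(\varepsilon)\big) - \log q_0(\varepsilon) + \log \left| \det \frac{\partial T_\phi(\varepsilon)}{\partial \varepsilon} \right|\right]
```

Because the expectation is taken over the fixed base distribution $q_0$, gradients with respect to $\phi$ can pass through $T_\phi$ (the reparameterization trick).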
Bayinx requires Python 3.12+, JAX, and a few extra libraries in the JAX ecosystem.
The easiest way to get started is to install Bayinx from PyPI using your favourite Python package manager. For example, with uv:
```bash
# Ensure you are in your project environment
uv add bayinx
```

This installs the bare-bones version of Bayinx. If you need additional functionality, such as GPU support, there are a couple of optional dependency groups:
```bash
# Ensure you are in your project environment
uv add 'bayinx[cuda]'  # Installs Bayinx with CUDA support
```

Documentation is available at: https://toddpocuca.github.io/bayinx.
- Allow shape definitions to include expressions (e.g., `shape = 'n_obs + 1'` will evaluate to the correct specification).
- Find a clean way to track the ELBO trajectory so that early stopping can be implemented (the tolerance currently has no effect).
- Nodes carry bounds for their support (i.e., `node.obj ∈ [node._lb, node._ub]`), which are used to check whether inputs to distributions are valid (e.g., a node passed as the scale of a normal distribution must have `node._lb >= 0`).
- Refactor the NF implementation to support forward and reverse flows (the reverse direction defaults to a differentiable root-solver) so that the STL estimator can be supported.
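The shape-expression item could be approached by parsing the string with Python's `ast` module and evaluating only a whitelisted subset of syntax, rather than calling `eval`. The sketch below assumes shapes are integer dimensions looked up by name; the helper `eval_shape` and its `dims` argument are hypothetical and not part of Bayinx.

```python
import ast
import operator

# Arithmetic operators permitted inside a shape expression (assumption).
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def eval_shape(expr: str, dims: dict[str, int]) -> int:
    """Evaluate a shape expression such as 'n_obs + 1' against known dimensions.

    Hypothetical helper: only integer literals, names present in `dims`, and the
    operators in _BIN_OPS are accepted; anything else raises ValueError.
    """

    def _eval(node: ast.AST) -> int:
        # Integer literals (bool is excluded even though it subclasses int).
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        # Named dimensions, e.g. n_obs.
        if isinstance(node, ast.Name):
            if node.id not in dims:
                raise ValueError(f"unknown dimension: {node.id!r}")
            return dims[node.id]
        # Whitelisted binary arithmetic, evaluated recursively.
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](_eval(node.left), _eval(node.right))
        # Calls, attribute access, etc. are rejected outright.
        raise ValueError(f"unsupported syntax in shape expression: {expr!r}")

    result = _eval(ast.parse(expr, mode="eval").body)
    if result < 0:
        raise ValueError(f"shape expression {expr!r} evaluated to {result}")
    return result
```

For example, `eval_shape("n_obs + 1", {"n_obs": 100})` returns `101`, while an expression such as `"__import__('os')"` is rejected because function calls are not on the whitelist.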
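For the early-stopping item, one option is to compare windowed averages of the ELBO instead of individual steps, since single-step ELBO estimates are noisy Monte Carlo values. The sketch below is a minimal, hypothetical tracker; the class name `ELBOTracker` and its `tolerance` and `window` parameters are assumptions, not Bayinx's API.

```python
from collections import deque


class ELBOTracker:
    """Track recent ELBO values and signal when optimization has plateaued.

    Hypothetical sketch: stop once the mean ELBO over the latest `window`
    steps improves on the mean of the preceding `window` steps by less than
    `tolerance`, measured relative to the magnitude of the earlier mean.
    """

    def __init__(self, tolerance: float = 1e-4, window: int = 50):
        self.tolerance = tolerance
        self.window = window
        # Keep only the 2 * window most recent values to bound memory.
        self.history: deque[float] = deque(maxlen=2 * window)

    def update(self, elbo: float) -> bool:
        """Record one ELBO estimate; return True if training should stop."""
        self.history.append(float(elbo))
        if len(self.history) < 2 * self.window:
            return False  # not enough history to compare two windows yet
        values = list(self.history)
        previous = sum(values[: self.window]) / self.window
        recent = sum(values[self.window :]) / self.window
        # Relative improvement; a decrease (negative gain) also triggers a stop.
        relative_gain = (recent - previous) / max(abs(previous), 1e-12)
        return relative_gain < self.tolerance
```

In a training loop, `update` would be called once per optimization step with the current ELBO estimate, breaking out of the loop when it returns `True`.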
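The support-bounds item can be illustrated with a small check that compares a node's declared bounds, rather than its current value, against the support a distribution parameter requires. Everything here is hypothetical: the `Node` dataclass only mirrors the `obj`, `_lb`, and `_ub` names from the item above, and `check_support` and `normal_logpdf` are illustrative helpers, not Bayinx functions.

```python
import math
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical node carrying its value and the bounds of its support."""

    obj: float
    _lb: float = -math.inf
    _ub: float = math.inf


def check_support(node: Node, lb: float = -math.inf, ub: float = math.inf) -> None:
    """Raise if the node's declared support [_lb, _ub] is not within [lb, ub]."""
    if node._lb < lb or node._ub > ub:
        raise ValueError(
            f"node support [{node._lb}, {node._ub}] is not within the "
            f"required support [{lb}, {ub}]"
        )


def normal_logpdf(x: float, loc: Node, scale: Node) -> float:
    # The scale of a normal distribution must be non-negative, so require
    # scale._lb >= 0 before using it (checked from bounds, not the value).
    check_support(scale, lb=0.0)
    z = (x - loc.obj) / scale.obj
    return -0.5 * z * z - math.log(scale.obj) - 0.5 * math.log(2.0 * math.pi)
```

Checking the declared bounds means an invalid model is rejected when it is built, even if the current value of `scale.obj` happens to be positive.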