Skip to content

Commit

Permalink
Major overhaul of API, Docs and Tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
tttc3 committed Jan 24, 2024
1 parent 84a5d2f commit b74a26f
Show file tree
Hide file tree
Showing 66 changed files with 2,560 additions and 2,962 deletions.
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Other gitignore
img/
*_build/
experiments/
.DS_Store

## Python gitignore
# Byte-compiled / optimized / DLL files
Expand Down Expand Up @@ -142,4 +140,4 @@ dmypy.json
cython_debug/

# Ruff
.ruff_cache/
.ruff_cache/
13 changes: 6 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
rev: v0.1.14
hooks:
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format
types_or: [ python, pyi, jupyter ]
# TODO: Enable in next release.
# - repo: https://github.com/RobertCraigie/pyright-python
# rev: v1.1.342
# hooks:
# - id: pyright
# additional_dependencies: [equinox, diffrax, jax, jaxtyping, pytest]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.348
hooks:
- id: pyright
additional_dependencies: [equinox, diffrax, scikit-learn, jax, jaxtyping, pytest]
137 changes: 51 additions & 86 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,22 @@

<!-- Add the badges in here -->
[![Documentation Status](https://readthedocs.org/projects/mccube/badge/?version=latest)](https://mccube.readthedocs.io/en/latest/?badge=latest)
[![CI](https://github.com/tttc3/MCCube/actions/workflows/tests.yml/badge.svg)](https://github.com/tttc3/MCCube/actions/workflows/tests.yml/)
[![pypi version](https://img.shields.io/pypi/v/mccube.svg)](https://pypi.org/project/mccube/)

MCCube is a [JAX](https://jax.readthedocs.io) library for constructing Markov Chain
Cubatures (MCCs) that (weakly) solve certain SDEs, and thus, can be used for performing
Bayesian inference.
MCCube provides the tools for performing Markov chain cubature in [diffrax](https://github.com/patrick-kidger/diffrax).

The core features of MCCube are:
- Approximate Bayesian inference of JAX transformable functions (support for PyTorch, Tensorflow and Numpy functions is provided via [Ivy](https://unify.ai/docs/ivy/compiler/transpiler.html));
- A simple Markov chain cubature inference loop, [mccubaturesolve](https://mccube.readthedocs.io/api/_inference);
- A framework for constructing/defining cubature kernels and cubature formulae;
- Visualization tools for evaluating and debugging inference/solving performance;
- A custom solver for using MCC in [Diffrax](https://docs.kidger.site/diffrax/)
**Key features:**

> [!warning]\
> This package is currently a work-in-progress/experimental. Expect bugs, API instability, and treat all results with a healthy degree of skepticism.
## Who should use MCCube?
MCCube should appeal to:
- Users of [Blackjax](https://github.com/blackjax-devs/blackjax#who-should-use-blackjax) (people who need/want modular GPU/TPU capable samplers);
- Users of [Diffrax](https://github.com/patrick-kidger/diffrax) (people who need to solve SDEs/CDEs);
- Markov chain cubature researchers/developers.
- Custom terms, paths, and solvers that provide a painless means to perform MCC in diffrax.
- A small library of recombination kernels, convential cubature formulae, and metrics.

## Installation
To install the base pacakge:
```bash
pip install mccube
```
If you want all the extras provided in `mccube.extensions`:
```bash
pip install mccube[extras]
```

Requires Python 3.9+, JAX 0.4.23+, and Equinox 0.11.1+.
Requires Python 3.10+, Diffrax 0.4.1+, and Equinox 0.11.1+.

By default, a CPU only version of JAX will be installed. To make use of other JAX/XLA
compatible accelerators (GPUs/TPUs) please follow [these installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier).
Expand All @@ -50,83 +34,64 @@ using JAX on Windows.
Available at [https://mccube.readthedocs.io/](https://mccube.readthedocs.io/).

## What is Markov chain cubature?
MCC is an approach to constructing a [Cubature on Wiener Space](https://www.jstor.org/stable/4143098) which does not suffer from exponential scaling in time (particle count explosion), thanks to the utilization of (partitioned) recombination in the cubature kernel (Markov transition kernel).

## Quick Example
# TODO:
# 1. IMPROVE THIS EXAMPLE AND ADD MORE TO THE DOCS.
# 2. GET ALL TESTS AND PRE-COMMIT TO PASS.
# 3. FIX TESTS BENCHMARKING OUTPUT.
# 4. ADD KERNEL TESTS.
# 5. UPDATED METRICS.
# 6. UPDATED VISUALIZERS.
# 7. IMPROVE CI/CD

The below toy example demonstrates MCCube for inferring the moments of a ten dimensional
Gaussian, with mean two and diagonal covariance six, given its logdensity function.
More **in-depth examples are coming soon**.
MCC is an approach to constructing a [Cubature on Wiener Space](https://www.jstor.org/stable/4143098)
which does not suffer from exponential scaling in time (particle count explosion),
thanks to the utilization of (partitioned) recombination in the (approximate) cubature
kernel.

### Example
```Python
import diffrax
import jax
import numpy as np
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal

from mccube import OverdampedLangevinKernel, MonteCarloKernel, WienerSpace, LyonsVictoir04_512, mccubaturesolve
from mccube.metrics import cubature_target_error

# Setup the problem.
n_particles = 1024
target_dimension = 10
rng = np.random.default_rng(42)
prior_particles = rng.uniform(size=(n_particles, target_dimension))
target_mean = 2 * np.ones(target_dimension)
target_cov = 6 * np.eye(target_mean)
# MCCube expects the log-density to have call signature (t, p(t), args), allowing the
# density to be time dependant, or to rely on some other generic args.
def target_logdensity(t, p, args):
return multivariate_normal.logpdf(p, target_mean, target_cov)

# Setup the MCCubature.
recombinator_key = jax.random.PRNGKey(42)
cfv = LyonsVictoir04_512(WienerSpace(target_dimension))
approximate_cubature_kernel = [
OverdampedLangevinKernel(target_logdensity, cfv.evaluate),
MonteCarloKernel(n_particles, recombinator_key)
]
# Construct the MCCubature/solve for the MCCubature paths.
mccubature_paths = mccubaturesolve(
transition_kernel=approximate_cubature_kernel
initial_particles=prior_particles,
from mccube import (
GaussianRegion,
Hadamard,
LocalLinearCubaturePath,
MCCSolver,
MCCTerm,
MonteCarloKernel,
gaussian_wasserstein_metric,
)
# Compare mean and covariance of the inferred cubature to the target.
posterior_particles = mccubature_paths.particles[-1, :, :]
mean_err, cov_err = cubature_target_error(posterior_particles, target_mean, target_cov)
print(f"Mean Error: {mean_err}\n", f"Cov Error: {cov_err}")
```

Note that `mccubaturesolve` returns the cubature paths, but does not return any other
intermediate step information. If such information is required, a 'visualizer' callback
can be used, for example:
key = jr.PRNGKey(42)
n, d = 512, 10
t0 = 0.0
epochs = 512
dt0 = 0.05
t1 = t0 + dt0 * epochs
y0 = jnp.ones((n, d))

```python
from mccube.extensions.visualizers import TensorboardVisualizer
target_mean = 2 * jnp.ones(d)
target_cov = 3 * jnp.eye(d)

with TensorboardVisualizer() as tbv:
mccubature_paths = mccubaturesolve(..., callbacks=[tbv])
```

To make use of the Tensorboard visualization suite remember to run the following command
either during/after each experimental run:
def logdensity(p):
return multivariate_normal.logpdf(p, mean=target_mean, cov=target_cov)

```bash
tensorboard --logdir=experiments
```

Note that the Tensorboard package must be installed separately.
ode = diffrax.ODETerm(lambda t, p, args: jax.vmap(jax.grad(logdensity))(p))
cde = diffrax.WeaklyDiagonalControlTerm(
lambda t, p, args: jnp.sqrt(2.0),
LocalLinearCubaturePath(Hadamard(GaussianRegion(d))),
)
terms = MCCTerm(ode, cde)
solver = MCCSolver(diffrax.Euler(), MonteCarloKernel(n, key=key))

sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0)
res_mean = jnp.mean(sol.ys[-1], axis=0)
res_cov = jnp.cov(sol.ys[-1], rowvar=False)
metric = gaussian_wasserstein_metric((target_mean, res_mean), (target_cov, res_cov))

print(f"Result 2-Wasserstein distance: {metric}")
```

## Citation
Please cite this repository if it has been useful in your work:
```
```bibtex
@software{mccube2023github,
author={},
title={{MCC}ube: Markov chain cubature via {JAX}},
Expand All @@ -144,4 +109,4 @@ Some other Python/JAX packages that you may find interesting:
- [Equinox](https://github.com/patrick-kidger/equinox) A JAX library for parameterised functions.
- [Diffrax](https://github.com/patrick-kidger/diffrax) A JAX library providing numerical differential equation solvers.
- [Lineax](https://github.com/google/lineax) A JAX library for linear solves and linear least squares.
- [OTT-JAX](https://github.com/ott-jax/ott) A JAX library for optimal transport.
- [OTT-JAX](https://github.com/ott-jax/ott) A JAX library for optimal transport.
111 changes: 92 additions & 19 deletions docs/_static/custom.css
Original file line number Diff line number Diff line change
@@ -1,27 +1,100 @@
#site-navigation{
h1.site-logo {
font-weight: bold;
}
/* Custom colors */
:root {
--ds-blue: #33BCAD;
--ds-mid-blue: #28968A;
--ds-dark-blue: #0A1446;
--ds-green: #DDEF00;
--ds-grey: #E6E6DC;
--ds-dark-grey: #D8D8C9;
--ds-white: #FFFFFF;
}

/* Fixes overflow from align math environments */
.math-wrapper {
overflow-x: visible;
[data-md-color-scheme="mccube"] {
color-scheme: light;

/* Primary */
--md-primary-fg-color: var(--ds-dark-blue);
--md-primary-fg-color--dark: var(--ds-dark-grey);

/* Accent */
--md-accent-fg-color: var(--ds-blue);
--md-accent-bg-color: var(--ds-green);

/* Default */
--md-default-bg-color: var(--ds-grey);

/* Code */
--md-code-bg-color: var(--ds-dark-grey);

/* Typeset */
--md-typeset-a-color: var(--ds-mid-blue);

/* Admonition */
--md-admonition-bg-color: var(--ds-grey);

/* Footer */
--md-footer-fg-color: var(--ds-dark-blue);
--md-footer-fg-color--light: var(--ds-blue);
--md-footer-fg-color--ligher: var(--ds-blue);
--md-footer-bg-color: var(--ds-dark-grey);
--md-footer-bg-color--dark: var(--ds-dark-blue);
}


/* Typeset */
.md-typeset a {
text-decoration: underline;
}

/* Navigation. */
.md-nav__source {
border-top: 0.1rem solid var(--md-accent-bg-color);
color: var(--md-primary-fg-color);
}

.md-sidebar--secondary .md-sidebar__inner {
border-left: 0.1rem solid var(--md-typeset-a-color);
}

/* Header. */
.md-header {
border-bottom: 0.15rem solid var(--md-accent-bg-color);
}

.md-header__button.md-logo img, .md-header__button.md-logo svg {
height: 2.0rem;
width: 2.0rem;
}

/* Add a line break between the arguments of each function.
Taken from: https://github.com/sphinx-doc/sphinx/issues/1514#issuecomment-742703082
*/
/* .sig-param::before {
content: "\a\20\20\20\20\20\20";
white-space: pre;
/* Footer. */
html .md-footer-meta.md-typeset a:hover {
color: var(--md-typeset-a-color);
}

dt em.sig-param:last-of-type::after {
content: "\a";
white-space: pre;
/* Indentation. */
div.doc-contents:not(.first) {
padding-left: 25px;
border-left: .15rem solid var(--md-typeset-table-color);
}

dl.class > dt:first-of-type {
display: block !important;
} */
/* Mark external links as such. */
a.external::after,
a.autorefs-external::after {
/* https://primer.style/octicons/arrow-up-right-24 */
mask-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18.25 15.5a.75.75 0 00.75-.75v-9a.75.75 0 00-.75-.75h-9a.75.75 0 000 1.5h7.19L6.22 16.72a.75.75 0 101.06 1.06L17.5 7.56v7.19c0 .414.336.75.75.75z"></path></svg>');
-webkit-mask-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18.25 15.5a.75.75 0 00.75-.75v-9a.75.75 0 00-.75-.75h-9a.75.75 0 000 1.5h7.19L6.22 16.72a.75.75 0 101.06 1.06L17.5 7.56v7.19c0 .414.336.75.75.75z"></path></svg>');
content: ' ';

display: inline-block;
vertical-align: middle;
position: relative;

height: 1em;
width: 1em;
background-color: var(--md-typeset-a-color);
}

a.external:hover::after,
a.autorefs-external:hover::after {
background-color: var(--md-accent-fg-color);
}
19 changes: 19 additions & 0 deletions docs/_static/mathjax.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
window.MathJax = {
tex: {
inlineMath: [["\\(", "\\)"]],
displayMath: [["\\[", "\\]"]],
processEscapes: true,
processEnvironments: true
},
options: {
ignoreHtmlClass: ".*|",
processHtmlClass: "arithmatex"
}
};

document$.subscribe(() => {
MathJax.startup.output.clearCache()
MathJax.typesetClear()
MathJax.texReset()
MathJax.typesetPromise()
})
2 changes: 1 addition & 1 deletion docs/_static/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ @article{mysovskikh1975
url = {https://www.sciencedirect.com/science/article/pii/0041555375902189}
}

@book{rabinowitz1984
@book{rabinowitz1984,
title = {Methods of Numerical Integration},
author = {Davis, P. J. and Rabinowitz, P.},
year = {1984},
Expand Down
1 change: 1 addition & 0 deletions docs/api/_custom_types.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._custom_types
1 change: 1 addition & 0 deletions docs/api/_formulae.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._formulae
1 change: 1 addition & 0 deletions docs/api/_kernels/base.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._kernels.base
1 change: 1 addition & 0 deletions docs/api/_kernels/random.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._kernels.random
1 change: 1 addition & 0 deletions docs/api/_kernels/stratified.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._kernels.stratified
1 change: 1 addition & 0 deletions docs/api/_kernels/tree.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._kernels.tree
1 change: 1 addition & 0 deletions docs/api/_metrics.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._metrics
1 change: 1 addition & 0 deletions docs/api/_path.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._path
1 change: 1 addition & 0 deletions docs/api/_regions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._regions
1 change: 1 addition & 0 deletions docs/api/_solvers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._solvers
1 change: 1 addition & 0 deletions docs/api/_term.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._term
1 change: 1 addition & 0 deletions docs/api/_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: mccube._utils
Loading

0 comments on commit b74a26f

Please sign in to comment.