Stochastic interpolants

This repository contains a minimal JAX implementation of some concepts related to stochastic interpolants, based on the paper by Michael S. Albergo, Nicholas M. Boffi, and Eric Vanden-Eijnden.

Disclaimer: This implementation is meant to be didactic. For a more complete version (in PyTorch), see the repository published by the authors of the paper.
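To give a rough idea of the central object, here is a minimal sketch of a linear stochastic interpolant x_t = (1 - t) x_0 + t x_1 + γ(t) z, with z a standard Gaussian and γ(t) = √(2t(1 - t)), together with a quadratic regression loss for the velocity field b(t, x). The choice of γ, the function names, and the velocity model b are assumptions made for illustration; this is not the API exposed by stochint.

```python
import jax
import jax.numpy as jnp

def gamma(t):
    # Noise scale (illustrative choice); zero at t = 0 and t = 1,
    # so the interpolant recovers x0 and x1 exactly at the endpoints.
    return jnp.sqrt(2.0 * t * (1.0 - t))

def interpolant(t, x0, x1, z):
    # x_t = (1 - t) x0 + t x1 + gamma(t) z
    return (1.0 - t) * x0 + t * x1 + gamma(t) * z

def dinterpolant_dt(t, x0, x1, z):
    # d/dt x_t = x1 - x0 + gamma'(t) z, via forward-mode differentiation in t.
    # gamma'(t) is singular at t = 0 and t = 1.
    _, dx = jax.jvp(lambda s: interpolant(s, x0, x1, z), (t,), (jnp.ones_like(t),))
    return dx

def velocity_loss(b, key, x0, x1):
    """Monte Carlo regression loss for a (hypothetical) velocity model b(t, x).

    x0, x1: paired batches of shape (n, d); b maps (t of shape (1,), x of shape (d,))
    to a vector of shape (d,). The minimizer is b(t, x) = E[d/dt x_t | x_t = x].
    """
    n = x0.shape[0]
    key_t, key_z = jax.random.split(key)
    # Sample t away from the endpoints, where gamma'(t) blows up.
    t = jax.random.uniform(key_t, (n, 1), minval=1e-3, maxval=1.0 - 1e-3)
    z = jax.random.normal(key_z, x0.shape)
    xt = interpolant(t, x0, x1, z)
    target = dinterpolant_dt(t, x0, x1, z)
    pred = jax.vmap(b)(t, xt)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))
```

In practice b would be a neural network (for example a Flax module) trained by minimizing velocity_loss with an Optax optimizer; samples are then generated by integrating dx/dt = b(t, x) from t = 0 to t = 1.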

Installation

Before installing this project, and after creating and activating your virtual environment, you must install JAX yourself, because the CPU and GPU backends require different installation commands. See the official JAX installation guide for instructions. For the small examples, pip install "jax[cpu]" is sufficient (the quotes stop shells such as zsh from interpreting the brackets). For the bigger demos, a GPU is helpful.
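A quick way to confirm which backend JAX picked up after installation is the short check below; on a CPU-only install it should report the "cpu" backend. This is only a sanity check and is not part of the repository.

```python
import jax

# Name of the default backend JAX is using (e.g. "cpu" or "gpu").
backend = jax.default_backend()
# List of devices visible to JAX on that backend.
devices = jax.devices()
print(f"JAX {jax.__version__}: backend={backend!r}, devices={devices}")
```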

Then, from the root of the repository, run

pip install .

This command installs the remaining requirements (Flax, Optax, etc.).

You can then import the package contents with

from stochint import *

Demonstrations

The demos are in the demos/ directory.

Acknowledgements

Thanks to Paul Jeha (@pablo2909) for teaching us how to write a name with 2D samples.