Skip to content

Commit

Permalink
Add sphinx and rtd support for Numpyro (#138)
Browse files Browse the repository at this point in the history
* Add sphinx and rtd support for Numpyro

* fix lint

* add missing docs
  • Loading branch information
neerajprad authored and fehiepsi committed May 8, 2019
1 parent 2801afd commit 85a8a79
Show file tree
Hide file tree
Showing 12 changed files with 359 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ numpyro/examples/.data

# tmp files
.*.swp

# docs
docs/build
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,7 @@ test: lint FORCE
clean: FORCE
git clean -dfx -e numpyro.egg-info

docs: FORCE
$(MAKE) -C docs html

FORCE:
20 changes: 20 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SPHINXPROJ = numpyro
SOURCEDIR = source
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
3 changes: 3 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
jax>=0.1.25
jaxlib>=0.1.12
tqdm
192 changes: 192 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import os
import sys

import sphinx_rtd_theme

# import pkg_resources

# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
sys.path.insert(0, os.path.abspath('../..'))

# -- Project information -----------------------------------------------------

project = u'Numpyro'
copyright = u'2019, Uber Technologies, Inc'
author = u'Uber AI Labs'

# The short X.Y version
version = u'0.0'
# The full version, including alpha/beta/rc tags
release = u'0.0'


# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
]

# Disable documentation inheritance so as to avoid inheriting
# docstrings in a different format, e.g. when the parent class
# is a PyTorch class.

autodoc_inherit_docstrings = False

autodoc_default_options = {
'member-order': 'bysource',
'show-inheritance': True,
'special-members': True,
'undoc-members': True,
'exclude-members': '__dict__,__module__,__weakref__',
}

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path .
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'


# do not prepend module name to functions
add_module_names = False

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'numpyrodoc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',

# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',

# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',

# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'Numpyro.tex', u'Numpyro Documentation', u'Uber AI Labs', 'manual'),
]

# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'Numpyro', u'Numpyro Documentation',
[author], 1)
]

# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'Numpyro', u'Numpyro Documentation',
author, 'Numpyro', 'Pyro PPL on Numpy',
'Miscellaneous'),
]


# -- Extension configuration -------------------------------------------------

# -- Options for intersphinx extension ---------------------------------------

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/3/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'jax': ('https://jax.readthedocs.io/en/latest/', None),
'pyro': ('http://docs.pyro.ai/en/stable/', None),
}
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Distributions
=============

.. automodule:: numpyro.distributions
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
8 changes: 8 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Handlers
========

.. automodule:: numpyro.handlers
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
38 changes: 38 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
:github_url: https://github.com/pyro-ppl/numpyro


Numpyro documentation
=====================

.. toctree::
:glob:
:maxdepth: 2
:caption: Inference:

mcmc
svi


.. toctree::
:glob:
:maxdepth: 2
:caption: Distributions:

distributions


.. toctree::
:glob:
:maxdepth: 2
:caption: Utilities:

handlers


Indices and tables
==================

* :ref:`genindex`
* :ref:`search`

.. * :ref:`modindex`
8 changes: 8 additions & 0 deletions docs/source/mcmc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Markov Chain Monte Carlo (MCMC)
===============================

.. automodule:: numpyro.mcmc
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
8 changes: 8 additions & 0 deletions docs/source/svi.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Stochastic Variational Inference (SVI)
======================================

.. automodule:: numpyro.svi
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
67 changes: 67 additions & 0 deletions numpyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ def _euclidean_ke(inverse_mass_matrix, r):


def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
r"""
Hamiltonian Monte Carlo inference, using either fixed number of
steps or the No U-Turn Sampler (NUTS) with adaptive path length.
**References**
[1] `MCMC Using Hamiltonian Dynamics`,
Radford M. Neal
[2] `The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo`,
Matthew D. Hoffman, and Andrew Gelman.
:param potential_fn: Python callable that computes the potential energy
given input parameters. The input parameters to `potential_fn` can be
any python collection type, provided that ``init_samples`` argument to
``init_kernel`` has the same type.
:param kinetic_fn: Python callable that returns the kinetic energy given
inverse mass matrix and momentum. If not provided, the default is
euclidean kinetic energy.
:param str algo: Whether to run ``HMC`` with fixed number of steps or ``NUTS``
with adaptive path length. Default is ``NUTS``.
:return init_kernel, sample_kernel: Returns a tuple of callables, the first
one to initialize the sampler, and the second one to generate samples
given an existing one.
"""
if kinetic_fn is None:
kinetic_fn = _euclidean_ke
vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
Expand All @@ -61,6 +85,41 @@ def init_kernel(init_samples,
progbar=True,
heuristic_step_size=True,
rng=PRNGKey(0)):
r"""
Initializes the HMC sampler.
:param init_samples: Initial parameters to begin sampling. The type can
must be consistent with the input type to ``potential_fn``.
:param int num_warmup_steps: Number of warmup steps; samples generated
during warmup are discarded.
:param float step_size: Determines the size of a single step taken by the
verlet integrator while computing the trajectory using Hamiltonian
dynamics. If not specified, it will be set to 1.
:param bool adapt_step_size: A flag to decide if we want to adapt step_size
during warm-up phase using Dual Averaging scheme.
:param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
matrix during warm-up phase using Welford scheme.
:param bool diag_mass: A flag to decide if mass matrix is diagonal (default)
or dense (if set to ``False``).
:param float target_accept_prob: Target acceptance probability for step size
adaptation using Dual Averaging. Increasing this value will lead to a smaller
step size, hence the sampling will be slower but more robust. Default to 0.8.
:param float trajectory_length: Length of a MCMC trajectory for HMC. Default
value is :math:`2\pi`.
:param int max_tree_depth: Max depth of the binary tree created during the doubling
scheme of NUTS sampler. Default to 10.
:param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
`init_kernel` returns an initial :func:`~numpyro.mcmc.HMCState` that
can be used to generate samples using MCMC. Else, returns the arguments
and callable that does the initial adaptation.
:param bool progbar: Whether to enable progress bar updates. Defaults to
``True``.
:param bool heuristic_step_size: If ``True``, a coarse grained adjustment of
step size is done at the beginning of each adaptation window to achieve
`target_acceptance_prob`.
:param jax.random.PRNGKey rng: random key to be used as the source of
randomness.
"""
step_size = float(step_size)
nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
trajectory_len = float(trajectory_length)
Expand Down Expand Up @@ -143,6 +202,14 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):

@jit
def sample_kernel(hmc_state):
r"""
Given a :func:`~numpyro.mcmc.HMCState`, run HMC with fixed (possibly
adapted) step size and return :func:`~numpyro.mcmc.HMCState`.
:param hmc_state: Current sample (and associated state).
:return: new proposed :func:`~numpyro.mcmc.HMCState` from simulating
Hamiltonian dynamics given existing state.
"""
rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
r = momentum_generator(hmc_state.inverse_mass_matrix, rng_momentum)
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
Expand Down

0 comments on commit 85a8a79

Please sign in to comment.