Skip to content

Commit

Permalink
Minor pending fixes (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored and fehiepsi committed May 31, 2019
1 parent 2945e40 commit 868068f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 20 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Probabilistic programming with Numpy powered by [JAX](https://github.com/google/

## What is NumPyro?

NumPyro is a small probabilistic programming library built on [JAX](https://github.com/google/jax). It essentially provides a NumPy backend for [Pyro](https://github.com/pyro-ppl/pyro), with some minor changes to the inference API and syntax. Since we use JAX, we get autograd and JIT compilation to GPU / CPU for free. This is an alpha release, and the API is subject to change as the design evolves.
NumPyro is a small probabilistic programming library built on [JAX](https://github.com/google/jax). It essentially provides a NumPy backend for [Pyro](https://github.com/pyro-ppl/pyro), with some minor changes to the inference API and syntax. Since we use JAX, we get autograd and JIT compilation to GPU / CPU for free. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.

NumPyro is designed to be *lightweight* and focuses on providing a flexible substrate that users can build on:

Expand Down Expand Up @@ -40,7 +40,7 @@ pip install -e .[dev]

For some examples on specifying models and doing inference in NumPyro:

- [Bayesian Regression in Numpyro](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/bayesian_regression.ipynb) - Start here to get acquainted with writing a simple model in NumPyro, MCMC inference API, effect handlers and writing custom inference utilities.
- [Bayesian Regression in NumPyro](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/bayesian_regression.ipynb) - Start here to get acquainted with writing a simple model in NumPyro, MCMC inference API, effect handlers and writing custom inference utilities.
- [Time Series Forecasting](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/time_series_forecasting.ipynb) - Illustrates how to convert for loops in the model to JAX's `lax.scan` primitive for fast inference.
- [Baseball example](https://github.com/pyro-ppl/numpyro/blob/master/examples/baseball.py) - Using NUTS for a simple hierarchical model. Compare this with the baseball example in [Pyro](https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py).
- [Hidden Markov Model](https://github.com/pyro-ppl/numpyro/blob/master/examples/hmm.py) in NumPyro as compared to [Stan](https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html).
Expand All @@ -53,6 +53,7 @@ Users will note that the API for model specification is largely the same as Pyro

In the near term, we plan to work on the following. Please open new issues for feature requests and enhancements:

- Improving robustness of inference on different models, profiling and performance tuning.
- More inference algorithms, particularly those that require second order derivaties or use HMC.
- Integration with [Funsor](https://github.com/pyro-ppl/funsor) to support inference algorithms with delayed sampling.
- Supporting more distributions, extending the distributions API, and adding more samplers to JAX.
Expand Down
26 changes: 13 additions & 13 deletions notebooks/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bayesian Regression Using Numpyro\n",
"# Bayesian Regression Using NumPyro\n",
"\n",
"In this tutorial, we will explore how to do bayesian regression in Numpyro, using a simple example adapted from Statistical Rethinking [[1](#References)]. In particular, we would like to explore the following:\n",
" - Write a simple model using the `sample` Numpyro primitive.\n",
" - Run inference using MCMC in Numpyro, in particular, using the No U-Turn Sampler (NUTS) to get a posterior distribution over our regression parameters of interest.\n",
"In this tutorial, we will explore how to do bayesian regression in NumPyro, using a simple example adapted from Statistical Rethinking [[1](#References)]. In particular, we would like to explore the following:\n",
" - Write a simple model using the `sample` NumPyro primitive.\n",
" - Run inference using MCMC in NumPyro, in particular, using the No U-Turn Sampler (NUTS) to get a posterior distribution over our regression parameters of interest.\n",
" - Learn about utilities such as `initialize_model` that are useful for running HMC.\n",
" - Learn how we can use effect-handlers in Numpyro to generate execution traces, condition on sample sites, seed models with RNG seeds, etc., and use this to implement various utilities that will be useful for MCMC. e.g. computing model log likelihood, generating empirical distribution over the posterior predictive, etc.\n",
" - Learn how we can use effect-handlers in NumPyro to generate execution traces, condition on sample sites, seed models with RNG seeds, etc., and use this to implement various utilities that will be useful for MCMC. e.g. computing model log likelihood, generating empirical distribution over the posterior predictive, etc.\n",
"\n",
"**Tutorial Outline:**\n",
"1. [Dataset](#Dataset)\n",
Expand Down Expand Up @@ -1161,7 +1161,7 @@
"source": [
"## Regression Model to Predict Divorce Rate\n",
"\n",
"Let us now write a regressionn model in *Numpyro* to predict the divorce rate as a linear function of marriage rate and median age of marriage in each of the states. \n",
"Let us now write a regressionn model in *NumPyro* to predict the divorce rate as a linear function of marriage rate and median age of marriage in each of the states. \n",
"\n",
"First, note that our predictor variables have somewhat different scales. It is a good practice to standardize our predictors and response variables to mean `0` and standard deviation `1`, which should result in faster inference. Refer to this [note](https://mc-stan.org/docs/2_19/stan-users-guide/standardizing-predictors-and-outputs.html) in the Stan manual for more details."
]
Expand All @@ -1181,9 +1181,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We write the Numpyro model as follows. While the code should largely be self-explanatory, take note of the following:\n",
" - In Numpyro, model code is any Python callable that can accept arguments and keywords. For HMC which we will be using for this tutorial, these arguments and keywords cannot change during model execution. This is convenient for passing in numpy arrays, or boolean arguments that might affect the execution path.\n",
" - In addition to regular Python statements, the model code also contains primitives like `sample`. These primitives can be interpreted with various side-effects by effect handlers used by inference algorithms in Numpyro. For more on effect handlers, refer to [[3](#References)], [[4](#References)]. For now, just remember that a `sample` statement makes this a stochastic function by sampling from some distribution of interest.\n",
"We write the NumPyro model as follows. While the code should largely be self-explanatory, take note of the following:\n",
" - In NumPyro, model code is any Python callable that can accept arguments and keywords. For HMC which we will be using for this tutorial, these arguments and keywords cannot change during model execution. This is convenient for passing in numpy arrays, or boolean arguments that might affect the execution path.\n",
" - In addition to regular Python statements, the model code also contains primitives like `sample`. These primitives can be interpreted with various side-effects by effect handlers used by inference algorithms in NumPyro. For more on effect handlers, refer to [[3](#References)], [[4](#References)]. For now, just remember that a `sample` statement makes this a stochastic function by sampling from some distribution of interest.\n",
" - The reason why we have kept our predictors as optional keyword arguments is to be able to reuse the same model as we vary the set of predictors. Likewise, the reason why the response variable is optional is that we would like to reuse this model to sample from the posterior predictive distribution. See the [section](#Posterior-Predictive-Distribution) on plotting the posterior predictive distribution, as an example."
]
},
Expand Down Expand Up @@ -1216,14 +1216,14 @@
"We first try to model the divorce rate as depending on a single variable, marriage rate. As mentioned above, we can use the same `model` code as earlier, but only pass values for `marriage` and `divorce` keyword arguments. We will use the No U-Turn Sampler (see [[5](#References)] for more details on the NUTS algorithm) to run inference on this simple model.\n",
"\n",
"\n",
"Note the following requirements for running HMC and NUTS in Numpyro:\n",
"Note the following requirements for running HMC and NUTS in NumPyro:\n",
" - The Hamiltonian Monte Carlo (or, the NUTS) implementation in Pyro takes in a potential energy function. This is the negative log joint density for the model. \n",
" - The verlet integrator in HMC (or, NUTS) returns sample values simulated using Hamiltonian dynamics in the unconstrained space. As such, continuous variables with bounded support need to be transformed into unconstrained space using bijective transforms. We also need to transform these samples back to their constrained support before returning these values to the user.\n",
" \n",
"Thankfully, all of this is handled on the backend for us. Let us go through the steps one by one.\n",
" - JAX uses functional PRNGs. Unlike other languages / frameworks which maintain a global random state, in JAX, every call to a sampler requires an [explicit PRNGKey](https://github.com/google/jax#random-numbers-are-different). We will split our initial random seed for subsequent operations, so that we do not accidentally reuse the same seed.\n",
" - The function [initialize_model](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.hmc_util.initialize_model) takes a model along with model arguments (and keyword arguments), and returns a tuple of initial parameters, potential energy function, and constrain function. The initial parameters are used to initiate the MCMC chain, the potential energy function is a callable that when given unconstrained sample values returns the potential energy at these sample values. This is used by the verlet integrator in HMC. Lastly, `constrain_fn` is a callable that transforms the unconstrained samples returned by HMC/NUTS to sample values that lie within the constrained support.\n",
" - Finally, we use the [mcmc](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.mcmc.mcmc) function to run inference using the default `NUTS` sampler. Note that to run vanilla HMC, all you need to do is to pass `sampler='hmc'` as argument to `mcmc` instead. This is a convenience utility that does all of the following:\n",
" - Finally, we use the [mcmc](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.mcmc.mcmc) function to run inference using the default `NUTS` sampler. Note that to run vanilla HMC, all you need to do is to pass `algo='hmc'` as argument to `mcmc` instead. This is a convenience utility that does all of the following:\n",
" - Runs warmup - adapts steps size and mass matrix.\n",
" - Uses the sample from the warmup phase to start MCMC.\n",
" - Return samples from the posterior distribution and print diagnostic information."
Expand Down Expand Up @@ -1283,9 +1283,9 @@
"\n",
"During warmup, the aim is to adapt or learn values for hyper-parameters such as step size and mass matrix (the HMC algorithm is very sensitive to these hyper-parameters), and to reach the typical set (see [[6](#References)] for more details). If there are any issues in the model specification, it might be reflected in low acceptance probabilities or very high number of steps. We use the sample from the end of the warmup phase to seed the MCMC chain (denoted by the second `sample` progress bar) from which we generate the desired number of samples from our target distribution.\n",
"\n",
"At the end of inference, Numpyro prints the mean, std and 90% CI values for each of the latent parameters. Note that since we standardized our predictors and response variable, we would expect the intercept to have mean 0, as can be seen here. It also prints other convergence diagnostics on the latent parameters in the model, including [effective sample size](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.effective_sample_size) and the [gelman rubin diagnostic](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.gelman_rubin) ($\\hat{R}$). The value for these diagnostics indicates that the chain has converged to the target distribution. In our case, the \"target distribution\" is the posterior distribution over the latent parameters that we are interested in. Note that this is often worth verifying with multiple chains on more complicated models. In the end, `samples_1` is a collection (in our case, a `dict` since `init_samples` was a `dict`) containing samples from the posterior distribution for each of the latent parameters in the model.\n",
"At the end of inference, NumPyro prints the mean, std and 90% CI values for each of the latent parameters. Note that since we standardized our predictors and response variable, we would expect the intercept to have mean 0, as can be seen here. It also prints other convergence diagnostics on the latent parameters in the model, including [effective sample size](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.effective_sample_size) and the [gelman rubin diagnostic](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.gelman_rubin) ($\\hat{R}$). The value for these diagnostics indicates that the chain has converged to the target distribution. In our case, the \"target distribution\" is the posterior distribution over the latent parameters that we are interested in. Note that this is often worth verifying with multiple chains on more complicated models. In the end, `samples_1` is a collection (in our case, a `dict` since `init_samples` was a `dict`) containing samples from the posterior distribution for each of the latent parameters in the model.\n",
"\n",
"To look at our regression fit, let us plot the regression line using our posterior estimates for the regression parameters, along with the 90% Credibility Interval (CI). Note that the [hpdi](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.hpdi) function in Numpyro's diagnostics module can be used to compute CI. In the functions below, note that the collected samples from the posterior are all along the leading axis.\n",
"To look at our regression fit, let us plot the regression line using our posterior estimates for the regression parameters, along with the 90% Credibility Interval (CI). Note that the [hpdi](https://numpyro.readthedocs.io/en/latest/diagnostics.html#numpyro.diagnostics.hpdi) function in NumPyro's diagnostics module can be used to compute CI. In the functions below, note that the collected samples from the posterior are all along the leading axis.\n",
"\n",
"We can see from the plot, that the CI broadens towards the tails where values of the predictor variables are sparse, as can be expected."
]
Expand Down
4 changes: 2 additions & 2 deletions notebooks/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@
"| ------------- |:---------:|:---------:|\n",
"| Edward2 (CPU) | | 68.4 ms |\n",
"| Edward2 (GPU) | | 9.7 ms |\n",
"| Numpyro (CPU) | 30.3 ms | 29.8 ms |\n",
"| Numpyro (GPU) | 4.3 ms | 4.7 ms |\n",
"| NumPyro (CPU) | 30.3 ms | 29.8 ms |\n",
"| NumPyro (GPU) | 4.3 ms | 4.7 ms |\n",
"\n",
"*Note:* Edward 2 results are obtained from reference [1], which is run under a different environment system."
]
Expand Down
6 changes: 3 additions & 3 deletions notebooks/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The time series has a length of 114 (a data point for each year), and by looking at the plot, we can observe [seasonality](https://en.wikipedia.org/wiki/Seasonality) in this dataset, which is the recurrence of similar patterns at specific time periods. e.g. in this dataset, we observe a cyclical pattern every 10 years, but there is also a less obvious but clear spike in the number of trappings every 40 years. Let us see if we can model this effect in Numpyro.\n",
"The time series has a length of 114 (a data point for each year), and by looking at the plot, we can observe [seasonality](https://en.wikipedia.org/wiki/Seasonality) in this dataset, which is the recurrence of similar patterns at specific time periods. e.g. in this dataset, we observe a cyclical pattern every 10 years, but there is also a less obvious but clear spike in the number of trappings every 40 years. Let us see if we can model this effect in NumPyro.\n",
"\n",
"In this tutorial, we will use the first 80 values for training and the last 34 values for testing."
]
Expand Down Expand Up @@ -160,7 +160,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `level` and `s` are updated recursively while we collect the expected value at each time step. Numpyro uses [JAX](https://github.com/google/jax) in the backend to JIT compile many critical parts of the NUTS algorithm, including the verlet integrator and the tree building process. However, doing so using Python's `for` loop in the model will result in a long compilation time for the model, so we use `jax.lax.scan` instead. A detailed explanation for using this utility can be found in [lax.scan documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan). Here we use it to collect expected values while the pair `(level, s)` plays the role of carrying state."
"Note that `level` and `s` are updated recursively while we collect the expected value at each time step. NumPyro uses [JAX](https://github.com/google/jax) in the backend to JIT compile many critical parts of the NUTS algorithm, including the verlet integrator and the tree building process. However, doing so using Python's `for` loop in the model will result in a long compilation time for the model, so we use `jax.lax.scan` instead. A detailed explanation for using this utility can be found in [lax.scan documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan). Here we use it to collect expected values while the pair `(level, s)` plays the role of carrying state."
]
},
{
Expand Down Expand Up @@ -198,7 +198,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"With our utility function defined above, we are ready to specify the model using *NumPyro* primitives. In NumPyro, we use the primitive `sample(name, prior)` to declare a latent random variable with a corresponding `prior`. These primitives can have custom interpretations depending on the effect handlers that are used by Numpyro inference algorithms in the backend. e.g. we can condition on specific values using the `substitute` handler, or record values at these sample sites in the execution trace using the `trace` handler. Note that these details are not important for specifying the model, or running inference, but curious readers are encouraged to read the [tutorial on effect handlers](http://pyro.ai/examples/effect_handlers.html) in Pyro."
"With our utility function defined above, we are ready to specify the model using *NumPyro* primitives. In NumPyro, we use the primitive `sample(name, prior)` to declare a latent random variable with a corresponding `prior`. These primitives can have custom interpretations depending on the effect handlers that are used by NumPyro inference algorithms in the backend. e.g. we can condition on specific values using the `substitute` handler, or record values at these sample sites in the execution trace using the `trace` handler. Note that these details are not important for specifying the model, or running inference, but curious readers are encouraged to read the [tutorial on effect handlers](http://pyro.ai/examples/effect_handlers.html) in Pyro."
]
},
{
Expand Down

0 comments on commit 868068f

Please sign in to comment.