In [1]:
%matplotlib inline

In [2]:
%run notebook_setup

# Sampling

`pymc3-ext` comes with some functions that make sampling more flexible in some cases and improve the default parameter choices for the types of problems encountered in astrophysics.
These features are accessed through the `pymc3_ext.sample` function, which behaves mostly like `pymc3.sample` but with a few different arguments.
The two main differences for all users are that `pymc3_ext.sample` defaults to a target acceptance fraction of `0.9` (which works better for many models in astrophysics) and that it adapts a full dense mass matrix instead of a diagonal one.
Therefore, if there are covariances between parameters, this method will generally perform better than the PyMC3 defaults, because a diagonal mass matrix can only rescale each parameter independently and cannot account for correlations between them.
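To see why a dense mass matrix helps, recall that HMC with inverse mass matrix `inv_mass` behaves like HMC with an identity mass matrix on the rescaled variables `y = A^{-1} x`, where `A @ A.T == inv_mass`. The NumPy sketch below (not part of `pymc3-ext`) compares the condition number of a correlated 2D Gaussian after that change of variables, for an identity, a diagonal, and a dense estimate of its covariance. Using the condition number as a proxy for how hard the geometry is for NUTS is a simplification; a value near 1 means the rescaled target looks isotropic.

```python
import numpy as np


def preconditioned_condition_number(cov, inv_mass):
    """Condition number of the target covariance after the change of
    variables y = A^{-1} x, where A is the Cholesky factor of inv_mass."""
    A_inv = np.linalg.inv(np.linalg.cholesky(inv_mass))
    return np.linalg.cond(A_inv @ cov @ A_inv.T)


rho = 0.99                  # strong correlation between two parameters
sd = np.array([1.0, 10.0])  # ...with very different scales
cov = rho * np.outer(sd, sd)
np.fill_diagonal(cov, sd**2)

diag_estimate = np.diag(np.diag(cov))  # best a diagonal adaptation can do
dense_estimate = cov.copy()            # best a dense adaptation can do

print("identity:", preconditioned_condition_number(cov, np.eye(2)))
print("diagonal:", preconditioned_condition_number(cov, diag_estimate))
print("dense:   ", preconditioned_condition_number(cov, dense_estimate))
```

The diagonal estimate fixes the difference in scales but leaves the correlation (here a condition number of about 199), while the dense estimate whitens the target completely.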

## Correlated parameters

A thorough discussion of this [can be found elsewhere online](https://dfm.io/posts/pymc3-mass-matrix/), but here is a simple demo where we sample a correlated multivariate Gaussian using `pymc3_ext.sample`.

First, we generate a random positive definite covariance matrix for the Gaussian:

In [3]:
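import numpy as np

ndim = 5
np.random.seed(42)
# Start from a random square matrix...
L = np.random.randn(ndim, ndim)
# ...force its diagonal to be strictly positive...
L[np.diag_indices_from(L)] = 0.1 * np.exp(L[np.diag_indices_from(L)])
# ...and zero the upper triangle, so that L is a valid Cholesky factor
L[np.triu_indices_from(L, 1)] = 0.0
cov = np.dot(L, L.T)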
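Because `L` is lower triangular with a strictly positive diagonal, it is exactly the (unique) Cholesky factor of `cov`, which is why it can be passed directly as `chol=L` when building the model. A quick NumPy check that repeats the same construction:

```python
import numpy as np

ndim = 5
np.random.seed(42)
L = np.random.randn(ndim, ndim)
L[np.diag_indices_from(L)] = 0.1 * np.exp(L[np.diag_indices_from(L)])
L[np.triu_indices_from(L, 1)] = 0.0
cov = np.dot(L, L.T)

# cov is symmetric and all of its eigenvalues are positive...
assert np.allclose(cov, cov.T)
eigvals = np.linalg.eigvalsh(cov)
print("smallest eigenvalue:", eigvals.min())

# ...and NumPy's Cholesky decomposition recovers L itself.
print("cholesky(cov) == L:", np.allclose(np.linalg.cholesky(cov), L))
```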

And then we can set up this model using PyMC3:

In [4]:
import pymc3 as pm

with pm.Model() as model:
    pm.MvNormal("x", mu=np.zeros(ndim), chol=L, shape=ndim)

If we sample this using PyMC3's default sampling method, things don't go so well. We only run a small number of steps here so that this doesn't take forever, but running for longer doesn't fix the problem:

In [5]:
with model:
    trace = pm.sample(tune=500, draws=500, chains=2, cores=2)

Auto-assigning NUTS sampler...

Initializing NUTS using jitter+adapt_diag...

Multiprocess sampling (2 chains in 2 jobs)

NUTS: [x]

Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 148 seconds.

The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.

The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.

The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.

The estimated number of effective samples is smaller than 200 for some parameters.


But we can use `pymc3_ext.sample` as a drop-in replacement to get much better performance:

In [6]:
import pymc3_ext as pmx

with model:
    tracex = pmx.sample(tune=1000, draws=1000, chains=2, cores=2)

Multiprocess sampling (2 chains in 2 jobs)

NUTS: [x]

Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 9 seconds.


As you can see, this is substantially faster (9 seconds instead of 148), even though we generated twice as many samples.

We can compare the sampling summaries to confirm that the default method did not produce reliable results in this case, while the `pymc3_ext` version did:

In [7]:
pm.summary(trace).head()

| parameter | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| x[0] | 0.035 | 0.153 | -0.277 | 0.283 | 0.033 | 0.023 | 22.0 | 22.0 | 22.0 | 136.0 | 1.08 |
| x[1] | -0.023 | 0.557 | -1.023 | 1.06 | 0.069 | 0.049 | 65.0 | 65.0 | 64.0 | 93.0 | 1.02 |
| x[2] | -0.131 | 0.621 | -1.251 | 1.042 | 0.146 | 0.105 | 18.0 | 18.0 | 18.0 | 100.0 | 1.11 |
| x[3] | -0.191 | 1.153 | -2.037 | 2.126 | 0.167 | 0.119 | 48.0 | 48.0 | 49.0 | 95.0 | 1.09 |
| x[4] | 0.24 | 2.074 | -3.891 | 3.796 | 0.421 | 0.302 | 24.0 | 24.0 | 24.0 | 185.0 | 1.07 |


In [8]:
pm.summary(tracex).head()

| parameter | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| x[0] | -0.003 | 0.163 | -0.307 | 0.314 | 0.004 | 0.004 | 1834.0 | 963.0 | 1820.0 | 1357.0 | 1.0 |
| x[1] | 0.006 | 0.535 | -0.989 | 1.001 | 0.011 | 0.012 | 2496.0 | 1070.0 | 2495.0 | 1656.0 | 1.0 |
| x[2] | 0.007 | 0.653 | -1.253 | 1.221 | 0.015 | 0.015 | 1922.0 | 966.0 | 1931.0 | 1457.0 | 1.0 |
| x[3] | 0.005 | 1.175 | -2.088 | 2.335 | 0.025 | 0.026 | 2172.0 | 1056.0 | 2161.0 | 1523.0 | 1.0 |
| x[4] | 0.026 | 2.066 | -3.604 | 4.135 | 0.044 | 0.044 | 2186.0 | 1090.0 | 2156.0 | 1496.0 | 1.0 |
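One rough way to put "substantially faster" into numbers is the worst-case bulk effective sample size per second of wall-clock time. The sketch below only does arithmetic on the `ess_bulk` columns and timings reported in the two summaries and logs; timings depend on the machine, and because the default run has `r_hat` values above 1.05, its ESS estimates are themselves unreliable, so treat the ratio as indicative only.

```python
# ess_bulk values and wall-clock times copied from the outputs in this notebook
ess_bulk_default = [22.0, 64.0, 18.0, 49.0, 24.0]
ess_bulk_pmx = [1820.0, 2495.0, 1931.0, 2161.0, 2156.0]
seconds_default = 148.0
seconds_pmx = 9.0


def min_ess_per_second(ess_bulk, seconds):
    """Worst-case effective samples per second across all parameters."""
    return min(ess_bulk) / seconds


rate_default = min_ess_per_second(ess_bulk_default, seconds_default)
rate_pmx = min_ess_per_second(ess_bulk_pmx, seconds_pmx)
print(f"default:   {rate_default:.2f} ESS/s")
print(f"pymc3_ext: {rate_pmx:.1f} ESS/s")
print(f"ratio:     {rate_pmx / rate_default:.0f}x")
```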


In this particular case, you could get similar performance by passing `init="adapt_full"` to PyMC3's `sample` function, but the implementation in `pymc3-ext` is somewhat more flexible.
Specifically, `pymc3_ext` implements a tuning procedure that is more similar to [the one implemented by the Stan project](https://mc-stan.org/docs/2_24/reference-manual/hmc-algorithm-parameters.html).
The relevant parameters are:

- `warmup_window`: The length of the initial "fast" window. This is called "initial buffer" in the Stan docs.
- `adapt_window`: The length of the initial "slow" window. This is called "window" in the Stan docs.
- `cooldown_window`: The length of the final "fast" window. This is called "term buffer" in the Stan docs.
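In Stan's scheme, the mass matrix is re-estimated at the end of each slow window, the slow windows double in length, and a window that would leave too little room for the next (doubled) one is stretched to the end of the slow phase. The sketch below computes those slow-window boundaries. It is a simplified, hypothetical helper whose arguments merely mirror the parameter names above; it assumes `pymc3-ext` follows the same doubling rule, and it does not handle very short tuning runs the way Stan does.

```python
def slow_windows(n_tune, warmup_window, adapt_window, cooldown_window):
    """Return (start, end) step indices of the doubling "slow" windows.

    Hypothetical helper: a simplified version of Stan's windowed adaptation.
    """
    slow_start = warmup_window
    slow_end = n_tune - cooldown_window
    if slow_end - slow_start < adapt_window:
        raise ValueError("tuning too short for these window sizes")

    windows = []
    start, size = slow_start, adapt_window
    while start < slow_end:
        end = start + size
        # If the following (doubled) window would not fit, stretch this one
        # to the end of the slow phase instead of leaving a short remainder.
        if end + 2 * size >= slow_end:
            end = slow_end
        windows.append((start, end))
        start, size = end, 2 * size
    return windows


# Stan's default buffer sizes (75, 25, 50) with 1000 tuning steps
windows = slow_windows(1000, 75, 25, 50)
for start, end in windows:
    print(f"slow window: steps {start}-{end - 1} ({end - start} steps)")
```

With these settings the slow windows have 25, 50, 100, 200, and 500 steps, framed by 75 initial fast steps and 50 final fast steps.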

Unlike the Stan implementation, `pymc3_ext` can update the mass matrix estimate every `recompute_interval` steps, using the previous window together with all of the steps in the current window so far.
This can substantially improve warmup performance, so the default value is `1`, but it might be intractable for high-dimensional models.
To recompute the estimate only at the end of each window, set `recompute_interval=0`.
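Frequent re-estimation is affordable if the sample statistics are accumulated online instead of being recomputed from stored samples. Below is a minimal NumPy sketch of that idea, assuming the estimate pools the previous window with the steps seen so far in the current one. The `RunningCovariance` class is hypothetical, not the `pymc3-ext` implementation; it uses Welford's update for each new sample and the Chan et al. formula to merge two accumulators.

```python
import numpy as np


class RunningCovariance:
    """Online mean/covariance accumulator (hypothetical helper)."""

    def __init__(self, ndim):
        self.n = 0
        self.mean = np.zeros(ndim)
        self.m2 = np.zeros((ndim, ndim))  # sum of outer products of deviations

    def update(self, x):
        """Welford update with one new sample x."""
        x = np.asarray(x, dtype=float)
        self.n += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.n
        self.m2 = self.m2 + np.outer(delta, x - self.mean)

    def merged(self, other):
        """Pool two accumulators without revisiting their samples."""
        out = RunningCovariance(self.mean.size)
        out.n = self.n + other.n
        if out.n == 0:
            return out
        delta = other.mean - self.mean
        out.mean = self.mean + delta * other.n / out.n
        out.m2 = self.m2 + other.m2 + np.outer(delta, delta) * self.n * other.n / out.n
        return out

    def covariance(self):
        return self.m2 / (self.n - 1)


rng = np.random.default_rng(0)
previous_window = rng.normal(size=(200, 3))
current_window = rng.normal(size=(50, 3))

prev = RunningCovariance(3)
for sample in previous_window:
    prev.update(sample)

cur = RunningCovariance(3)
recompute_interval = 10  # pymc3-ext's 0 ("only at window ends") is not modeled here
inv_mass = None
for step, sample in enumerate(current_window, start=1):
    cur.update(sample)
    if recompute_interval > 0 and step % recompute_interval == 0:
        inv_mass = prev.merged(cur).covariance()

# The pooled online estimate matches a batch estimate over both windows.
batch = np.cov(np.vstack([previous_window, current_window]), rowvar=False)
print(np.allclose(inv_mass, batch))
```

Updating the accumulator costs O(d²) per step for d parameters, but turning each new dense estimate into a usable mass matrix (for example via a Cholesky factorization) costs O(d³), which is why recomputing at every step can become intractable for high-dimensional models.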

If you run into numerical issues, you can try increasing `adapt_window` or use the `regularization_steps` and `regularization_variance` arguments to regularize the mass matrix estimator.
The `regularization_steps` parameter sets the effective number of steps that are used for regularization and `regularization_variance` is the effective variance for those steps.
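One natural reading of these two parameters is a shrinkage estimator that blends the sample covariance from `n` steps with `regularization_steps` pseudo-steps whose variance is `regularization_variance` in every direction. This form is an assumption, and the exact weighting in `pymc3-ext` may differ. With `regularization_steps=5` and `regularization_variance=1e-3` it reduces to the shrinkage Stan applies to its own estimates. The hypothetical helper below shows why regularization helps: with fewer samples than parameters the raw estimate is singular, while the regularized one is positive definite.

```python
import numpy as np


def regularized_covariance(samples, regularization_steps, regularization_variance):
    """Shrink the sample covariance toward regularization_variance * I.

    Hypothetical helper: treats the regularization as `regularization_steps`
    pseudo-samples with variance `regularization_variance` along each axis.
    """
    samples = np.asarray(samples, dtype=float)
    n, ndim = samples.shape
    sample_cov = np.cov(samples, rowvar=False)
    weight = n / (n + regularization_steps)
    return weight * sample_cov + (1.0 - weight) * regularization_variance * np.eye(ndim)


rng = np.random.default_rng(1)
samples = rng.normal(size=(4, 6))  # only 4 steps for 6 parameters

raw = np.cov(samples, rowvar=False)
reg = regularized_covariance(samples, regularization_steps=5, regularization_variance=1e-3)

print("raw rank:", np.linalg.matrix_rank(raw))  # at most n - 1 = 3
print("smallest eigenvalue (regularized):", np.linalg.eigvalsh(reg).min())
np.linalg.cholesky(reg)  # succeeds because reg is positive definite
```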

## Parameter groups

If you are fitting a model with a large number of parameters, it might not be computationally or numerically tractable to estimate the full dense mass matrix.
But sometimes you might know something about the covariance structure of the problem that you can exploit.
Perhaps some parameters are correlated with each other, but not with others.
In this case, you can use the `parameter_groups` argument to exploit this structure.
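Conceptually, parameter groups give a block-diagonal mass matrix: a dense or diagonal block per group and zeros between groups. The NumPy sketch below builds such a matrix from samples with a hypothetical `block_inverse_mass` helper (the `(indices, kind)` group format is an assumption for illustration, not the `pymc3-ext` API) and counts the distinct entries that must be estimated, compared with a single dense matrix.

```python
import numpy as np


def block_inverse_mass(samples, groups):
    """Estimate a block-diagonal inverse mass matrix from samples.

    groups: list of (column_indices, "full" or "diag") pairs (hypothetical format).
    """
    ndim = samples.shape[1]
    inv_mass = np.zeros((ndim, ndim))
    for idx, kind in groups:
        idx = np.asarray(idx)
        block = np.atleast_2d(np.cov(samples[:, idx], rowvar=False))
        if kind == "diag":
            block = np.diag(np.diag(block))  # keep only the variances
        inv_mass[np.ix_(idx, idx)] = block
    return inv_mass


def n_free_entries(groups):
    """Number of distinct entries that have to be estimated."""
    total = 0
    for idx, kind in groups:
        k = len(idx)
        total += k if kind == "diag" else k * (k + 1) // 2
    return total


ndim = 5
groups = [
    (range(0, ndim), "full"),             # a correlated block, like x
    (range(ndim, 2 * ndim), "full"),      # another correlated block, like y
    (range(2 * ndim, 3 * ndim), "diag"),  # uncorrelated parameters, like z
]
rng = np.random.default_rng(2)
samples = rng.normal(size=(500, 3 * ndim))
inv_mass = block_inverse_mass(samples, groups)

print("block-diagonal entries:", n_free_entries(groups))
print("single dense matrix entries:", n_free_entries([(range(3 * ndim), "full")]))
```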

Here is an example where `x`, `y`, and `z` are mutually independent, each with its own covariance structure.
We can take advantage of this by passing groups of variables in the `parameter_groups` argument: a plain list defines a group, and `pmx.ParameterGroup` additionally lets us choose how that group's mass matrix is adapted.
Note that by default each group internally estimates a dense mass matrix, but here we explicitly estimate only a diagonal mass matrix for `z`.

In [9]:
with pm.Model():
    x = pm.MvNormal("x", mu=np.zeros(ndim), chol=L, shape=ndim)
    y = pm.MvNormal("y", mu=np.zeros(ndim), chol=L, shape=ndim)
    z = pm.Normal("z", shape=ndim)  # Uncorrelated

    tracex2 = pmx.sample(
        tune=1000,
        draws=1000,
        chains=2,
        cores=2,
        parameter_groups=[
            [x],
            [y],
            pmx.ParameterGroup([z], "diag"),
        ],
    )

Multiprocess sampling (2 chains in 2 jobs)

NUTS: [z, y, x]

Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 15 seconds.
