# Posterior geometry

HMC and its variant NUTS use gradient information to draw (approximate) samples from a posterior distribution. 
These gradients are computed in a particular coordinate system, and different choices of coordinate system can make HMC more or less efficient. 
For this reason it is important to pay attention to the *geometry* of the posterior distribution. 
Reparameterizing the model (i.e. changing the coordinate system) can make a big practical difference for many complex models. 
For the most complex models it can be absolutely essential. For the same reason it can be important to pay attention to some of the hyperparameters that control NUTS (in particular the `max_tree_depth`). 

In this tutorial we explore models with bad posterior geometries---and what one can do to get achieve better performance---with a few examples.

In [None]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

In [69]:
from functools import partial

import numpy as np

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import summary

from numpyro.infer import MCMC, NUTS, init_to_uniform
assert numpyro.__version__.startswith('0.7.2')

# NB: replace gpu by gpu to run this notebook on gpu
numpyro.set_platform("cpu")

We begin by writing a helper function to do NUTS inference.

In [70]:
def run_inference(model, 
                  num_warmup=1000, 
                  num_samples=1000,
                  max_tree_depth=10, 
                  dense_mass=False, 
                  init_strategy=init_to_uniform):
    kernel = NUTS(model, 
                  init_strategy=init_strategy, 
                  max_tree_depth=max_tree_depth,
                  dense_mass=dense_mass)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))
    summary_dict = summary(mcmc.get_samples(), 0.90, group_by_chain=False) 
    
    # print the largest r_hat for each variable
    for k, v in summary_dict.items():
        spaces = " " * max(15 - len(k), 0)
        print("[{}] {} \t max r_hat: {:.4f}".format(k, spaces, np.max(v['r_hat'])))

# Evaluating HMC/NUTS

In general it is difficult to assess whether the samples returned from HMC or NUTS represent accurate (approximate) posterior samples from the posterior. 
Two general rules of thumb, however, are to look at the effective sample size (ESS) and r_hat diagnostics returned by `mcmc.print_summary()`.
If we see values of r_hat in the range `(1.0, 1.05)` and effective sample sizes that are comparable to the total number of samples `num_samples` (assuming `thinning=1`) then we have good reason to believe that HMC is doing a good job. 
If, however, we see low effective sample sizes or large r_hats for some of the variables (e.g. r_hat = 1.51) then HMC is likely struggling with the posterior geometry. 

# Model reparameterization

We begin with an example (horseshoe regression; see also examples/horseshoe.py) where reparameterization helps. 
This particular example demonstrates a general reparameterization strategy that is useful in many models with hierarchical/multi-level structure.

In [71]:
# In this unreparameterized model some of the parameters of the distributions
# explicitly depend on other parameters (i.e. beta depends on lamba and tau).
# This kind of coordinate system can be a challenge for HMC.
def _unrep_hs_model(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    betas = numpyro.sample("betas", dist.Normal(tau * lambdas))
    mean_function = jnp.dot(X, betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. 
# These two models are exactly equivalent but are expressed 
# in different coordinate systems.
def _rep_hs_model(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    unscaled_betas = numpyro.sample("unscaled_betas", dist.Normal(jnp.ones(X.shape[1])))
    scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)
    mean_function = jnp.dot(X, scaled_betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)    
 
# create fake dataset
X = np.random.RandomState(0).randn(100, 500)
Y = X[:, 0]

print("unreparameterized model (bad r_hats)")
run_inference(partial(_unrep_hs_model, X, Y))

print("\nreparameterized model (good r_hats)")
run_inference(partial(_rep_hs_model, X, Y))

unreparameterized model (bad r_hats)
[betas]            	 max r_hat: 1.1636
[lambdas]          	 max r_hat: 1.0351
[tau]              	 max r_hat: 1.0884

reparameterized model (good r_hats)
[betas]            	 max r_hat: 1.0108
[lambdas]          	 max r_hat: 1.0132
[tau]              	 max r_hat: 1.0003
[unscaled_betas]   	 max r_hat: 1.0041


# Mass matrices
By default HMC/NUTS use diagonal mass matrices. 
For models with complex geometries it can pay to use a richer set of mass matrices.
In this simple example we show that using a full-rank (i.e. dense) mass matrix leads to better r_hats.

In [68]:
# because rho is almost 1.0 the posterior geometry is extremely skewed and using
# the "diagonal" coordinate system implied by dense_mass=False leads to bad results
rho = 0.9999
true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

def mvn_model():
    numpyro.sample("x", 
        dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
    )
    
print("dense_mass = False (bad r_hats)")
run_inference(mvn_model, dense_mass=False, max_tree_depth=3)

print("dense_mass = True (good r_hats)")
run_inference(mvn_model, dense_mass=True, max_tree_depth=3)

dense_mass = False (bad r_hats)
[x]                	 max r_hat: 1.3810
dense_mass = True (good r_hats)
[x]                	 max r_hat: 0.9992
