# Posterior geometry

HMC and its variant NUTS use gradient information to draw (approximate) samples from a posterior distribution. 
These gradients are computed in a particular coordinate system, and different choices of coordinate system can make HMC more or less efficient. 
For this reason it is important to pay attention to the *geometry* of the posterior distribution. 
Reparameterizing the model (i.e. changing the coordinate system) can make a big practical difference for many complex models. 
For the most complex models it can be absolutely essential. For the same reason it can be important to pay attention to some of the hyperparameters that control HMC/NUTS (in particular the `max_tree_depth` and `dense_mass`). 

In this tutorial we explore models with bad posterior geometries---and what one can do to get achieve better performance---with a few examples.

In [1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

In [2]:
from functools import partial

import numpy as np

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import summary

from numpyro.infer import MCMC, NUTS, init_to_uniform
assert numpyro.__version__.startswith('0.7.2')

# NB: replace cpu by gpu to run this notebook on gpu
numpyro.set_platform("cpu")

We begin by writing a helper function to do NUTS inference.

In [3]:
def run_inference(model, 
                  num_warmup=1000, 
                  num_samples=1000,
                  max_tree_depth=10, 
                  dense_mass=False, 
                  init_strategy=init_to_uniform):
    
    kernel = NUTS(model, 
                  init_strategy=init_strategy, 
                  max_tree_depth=max_tree_depth,
                  dense_mass=dense_mass)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))
    summary_dict = summary(mcmc.get_samples(), group_by_chain=False) 
    
    # print the largest r_hat for each variable
    for k, v in summary_dict.items():
        spaces = " " * max(15 - len(k), 0)
        print("[{}] {} \t max r_hat: {:.4f}".format(k, spaces, np.max(v['r_hat'])))

# Evaluating HMC/NUTS

In general it is difficult to assess whether the samples returned from HMC or NUTS represent accurate (approximate) samples from the posterior. 
Two general rules of thumb, however, are to look at the effective sample size (ESS) and r_hat diagnostics returned by `mcmc.print_summary()`.
If we see values of r_hat in the range `(1.0, 1.05)` and effective sample sizes that are comparable to the total number of samples `num_samples` (assuming `thinning=1`) then we have good reason to believe that HMC is doing a good job. 
If, however, we see low effective sample sizes or large r_hats for some of the variables (e.g. r_hat = 1.15) then HMC is likely struggling with the posterior geometry. 
In the following we will use `r_hat` as our primary diagnostic metric.

# Model reparameterization

### Example #1

We begin with an example (horseshoe regression; see [examples/horseshoe.py](https://github.com/pyro-ppl/numpyro/blob/master/examples/horseshoe.py) for a fully example script) where reparameterization helps a lot. 
This particular example demonstrates a general reparameterization strategy that is useful in many models with hierarchical/multi-level structure. 
For more discussion of some of the issues that can arise in hierarhical see reference [1].


In [4]:
# In this unreparameterized model some of the parameters of the distributions
# explicitly depend on other parameters (i.e. beta depends on lamba and tau).
# This kind of coordinate system can be a challenge for HMC.
def _unrep_hs_model(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    betas = numpyro.sample("betas", dist.Normal(tau * lambdas))
    mean_function = jnp.dot(X, betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

To deal with the bad geometry that results form this coordinate system we change coordinates using the following re-write logic.
Instead of 

$$ \beta \sim {\rm Normal}(0, \lambda \tau) $$
we write
$$ \beta^\prime \sim {\rm Normal}(0, 1) $$
and
$$ \beta \equiv \lambda \tau \beta^\prime  $$

where $\beta$ is now defined *deterministically* in terms of $\lambda$, $\tau$,
and $\beta^\prime$. In effect we've changed to a coordinate system where the different
latent variables are less correlated with one another. 
In this new coordinate system we can expect HMC with a diagonal mass matrix to behave much better than it would in the original coordinate system.

In [None]:
# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. These two models are exactly equivalent 
# but are expressed in different coordinate systems.
def _rep_hs_model(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    unscaled_betas = numpyro.sample("unscaled_betas", dist.Normal(jnp.ones(X.shape[1])))
    scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)
    mean_function = jnp.dot(X, scaled_betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)    
 
# create fake dataset
X = np.random.RandomState(0).randn(100, 500)
Y = X[:, 0]

print("unreparameterized model (bad r_hats)")
run_inference(partial(_unrep_hs_model, X, Y))

print("\nreparameterized model (good r_hats)")
run_inference(partial(_rep_hs_model, X, Y))

unreparameterized model (bad r_hats)
[betas]            	 max r_hat: 1.1636
[lambdas]          	 max r_hat: 1.0351
[tau]              	 max r_hat: 1.0884

reparameterized model (good r_hats)


### Aside: numpyro.deterministic

In `_rep_hs_model` above we used [`numpyro.deterministic`](http://num.pyro.ai/en/stable/primitives.html?highlight=deterministic#numpyro.primitives.deterministic) to define `scaled_betas`.
We note that using this primitive is not strictly necessary; however, it has the consequence that `scaled_betas` will appear in the trace and will thus appear in the summary reported by `mcmc.print_summary()`. In other words we could also have written:

```
scaled_betas = tau * lambdas * unscaled_betas
```

without invoking the `deterministic` primitive.

# Mass matrices
By default HMC/NUTS use diagonal mass matrices. 
For models with complex geometries it can pay to use a richer set of mass matrices.


### Example #2
In this first simple example we show that using a full-rank (i.e. dense) mass matrix leads to a better r_hat.

In [None]:
# Because rho is very close to 1.0 the posterior geometry is extremely skewed and using
# the "diagonal" coordinate system implied by dense_mass=False leads to bad results
rho = 0.9999
true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

def mvn_model():
    numpyro.sample("x", 
        dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
    )
    
print("dense_mass = False (bad r_hat)")
run_inference(mvn_model, dense_mass=False, max_tree_depth=3)

print("dense_mass = True (good r_hat)")
run_inference(mvn_model, dense_mass=True, max_tree_depth=3)

### Example #3

Using `dense_mass=True` can be very expensive when the dimension of the latent space `D` is very large. In addition it can be difficult to estimate a full-rank mass matrix with `D^2` parameters using a moderate number of samples if `D` is large. In these cases `dense_mass=True` can be a poor choice.  Luckily, the argument `dense_mass` can also be used to specify structured mass matrices that are richer than a diagonal mass matrix but more constrained (i.e. have fewer parameters) than a full-rank mass matrix ([see the docs](http://num.pyro.ai/en/stable/mcmc.html#hmc)).
In this second example we show how we can use `dense_mass` to specify such a structured mass matrix.

In [None]:
rho = 0.9
true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

# In this model x1 and x2 are highly correlated with one another
# but not correlated at all with y.
def partially_correlated_model():
    x1 = numpyro.sample("x1", 
        dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
    )
    x2 = numpyro.sample("x2", 
        dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
    )    
    y = numpyro.sample("y", dist.Normal(jnp.zeros(100), 1.0))
    numpyro.sample("obs", dist.Normal(x1 - x2, 0.1), jnp.ones(2))

In [None]:
print("dense_mass = False (very bad r_hats)")
run_inference(partially_correlated_model, dense_mass=False, max_tree_depth=3)

print("\ndense_mass = True (bad r_hats)")
run_inference(partially_correlated_model, dense_mass=True, max_tree_depth=3)

# We use dense_mass=[("x1", "x2")] to specify
# a structured mass matrix in which the y-part of the mass matrix is diagonal
# and the (x1, x2) block of the mass matrix is full-rank.

# Graphically:
#      x1 x2 y
#  x1 | * * 0 |
#  x2 | * * 0 |
#  y  | 0 0 * |

print("\nstructured mass matrix (good r_hats)")
run_inference(partially_correlated_model, dense_mass=[("x1", "x2")], max_tree_depth=3)

# Other strategies

- In some cases it can make sense to use variational inference to *learn* a new coordinate system. For details see [examples/neutra.py](https://github.com/pyro-ppl/numpyro/blob/master/examples/neutra.py) and reference [3].

# References

[1] "Hamiltonian Monte Carlo for Hierarchical Models,"
    M. J. Betancourt, Mark Girolami.

[2] "Slice sampling," R. M. Neal.

[3] "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport,"
    Matthew Hoffman, Pavel Sountsov, Joshua V. Dillon, Ian Langmore, Dustin Tran, Srinivas Vasudevan.
    
[4] "Reparameterization" in the Stan user's manual.
    https://mc-stan.org/docs/2_27/stan-users-guide/reparameterization-section.html