# Interfacing with Stan
This notebook shows how to interface with [Stan](https://mc-stan.org/) to make use of its log-probability calculations, its gradient calculations (via automatic differentiation) and its large library of probability distributions.

One thing to be mindful of is that the interface below only allows the log probability to be accessed from Stan objects up to an additive constant.
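
Here is what "up to an additive constant" means in practice. A Stan sampling statement such as `y ~ normal(theta, sigma)` drops terms that do not depend on the parameters. For a normal likelihood with known `sigma`, those terms are $-\tfrac{1}{2}\log(2\pi) - \log\sigma$. The sketch below mimics this with plain NumPy/SciPy (no Stan involved) and a single illustrative observation. The difference between the full and the reduced log density is the same for every $\theta$, so it cancels in MCMC acceptance ratios and vanishes from gradients.

```python
import numpy as np
import scipy.stats

y, sigma = 28.0, 15.0  # one illustrative observation and its known standard error


def full_log_lik(theta):
    # complete normal log density, including the normalising terms
    return scipy.stats.norm.logpdf(y, loc=theta, scale=sigma)


def reduced_log_lik(theta):
    # only the term that depends on theta, as kept by a Stan `~` statement
    return -0.5 * ((y - theta) / sigma) ** 2


thetas = np.linspace(-10.0, 30.0, 5)
diffs = full_log_lik(thetas) - reduced_log_lik(thetas)
expected = -0.5 * np.log(2 * np.pi) - np.log(sigma)
print(diffs)     # the same value for every theta
print(expected)  # ...which equals the dropped constant
```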

In this notebook, we use the [Eight Schools example](http://pints.readthedocs.io/en/latest/toy/eight_schools.html) and show how the model can be defined in Stan but called from Pints.
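
For reference, the centered Eight Schools model defined in Stan below can be written as follows. Stan's `normal(m, s)` takes a standard deviation $s$, and the `<lower=0>` constraint turns the Cauchy prior on $\tau$ into a half-Cauchy:

$$
\begin{aligned}
\mu &\sim \mathcal{N}(0, 5^2), \\
\tau &\sim \text{Half-Cauchy}(0, 5), \\
\theta_j \mid \mu, \tau &\sim \mathcal{N}(\mu, \tau^2), \quad j = 1, \ldots, J, \\
y_j \mid \theta_j &\sim \mathcal{N}(\theta_j, \sigma_j^2), \quad j = 1, \ldots, J,
\end{aligned}
$$

where $y_j$ and $\sigma_j$ are the observed effect and its standard error for school $j$, and $J = 8$.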

In [1]:
import pystan

Define the Stan model using Stan's modelling language.

In [2]:
code="""
data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta[J];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}
"""

Next, compile the Stan model. Compilation is slow, so the code below also saves the compiled model with `pickle` and reloads it on later runs.

In [3]:
import os
import pickle

# Compiling a Stan model is slow, so cache the compiled model with pickle
# and reuse it on later runs
if os.path.exists('model.pkl'):
    with open('model.pkl', 'rb') as f:
        sm = pickle.load(f)
else:
    sm = pystan.StanModel(model_code=code)
    with open('model.pkl', 'wb') as f:
        pickle.dump(sm, f)

Import the Eight Schools ("centered" parameterisation) model from Pints.

In [4]:
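import os
# move to the root of the Pints repository (assumed to be two levels up)
# so that the local pints package is the one imported
os.chdir("../../")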
import pints
import pints.toy
import numpy as np
import matplotlib.pyplot as plt
import time
import scipy.stats

model = pints.toy.EightSchoolsCenteredLogPDF()

Use the data provided with Pints' toy model to run Stan's NUTS sampler for a few iterations. This is only needed to obtain a stanfit object, whose bound methods are used below, so any warnings from this short run can be ignored.

In [5]:
fit = sm.sampling(data=model.data(), iter=10, chains=1, verbose=True, refresh=10)
names = fit.unconstrained_param_names()



Define a `pints.LogPDF` that wraps the stanfit object. It uses the `log_prob` and `grad_log_prob` methods bound to the stanfit object to calculate the log probability and its gradient (both of which drop constant terms from the log probability). These methods take parameter values in Stan's unconstrained space, where, for example, $\tau > 0$ is represented by $\log\tau$. We therefore use the stanfit object's `unconstrain_pars` method to map the constrained parameter values proposed by Pints into that space.
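
Stan maps a parameter declared `<lower=0>` to the real line via $u = \log\tau$. The sketch below uses the half-Cauchy prior on $\tau$ as an example density (plain NumPy/SciPy, illustrative function names) to show two relationships between the $\tau$ and $u$ scales. First, the density of $u$ picks up the Jacobian term $\log|d\tau/du| = u$, which is the term `adjust_transform=True` adds. Second, the gradient of the unadjusted log density with respect to $u$, divided by $\tau$, gives the gradient with respect to $\tau$.

```python
import numpy as np
import scipy.stats


def log_p_tau(tau):
    # half-Cauchy(0, 5) log density on the constrained scale, tau > 0
    return np.log(2.0) + scipy.stats.cauchy.logpdf(tau, loc=0.0, scale=5.0)


def dlog_p_tau(tau):
    # analytic derivative of log_p_tau with respect to tau
    return -2.0 * tau / (25.0 + tau ** 2)


def log_p_u(u):
    # density of u = log(tau): constrained density plus log|d tau / d u| = u
    return log_p_tau(np.exp(u)) + u


def dlog_p_u_unadjusted(u):
    # gradient w.r.t. u of log_p_tau(exp(u)), i.e. without the Jacobian term
    tau = np.exp(u)
    return dlog_p_tau(tau) * tau


tau = 3.0
u = np.log(tau)

# chain rule: d/dtau = (d/du) / tau, because du/dtau = 1 / tau
print(dlog_p_u_unadjusted(u) / tau, dlog_p_tau(tau))

# the Jacobian-adjusted density is a proper density in u (integrates to ~1)
grid = np.linspace(-20.0, 20.0, 200001)
dens = np.exp(log_p_u(grid))
mass = np.sum(0.5 * (dens[1:] + dens[:-1]) * np.diff(grid))  # trapezoid rule
print(mass)
```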

In [6]:
class EightSchoolsStanLogPDF(pints.LogPDF):
    """
    Wraps a stanfit object as a Pints LogPDF over the constrained parameters
    [mu, tau, theta_1, ..., theta_8], where tau > 0.
    """
    def __init__(self, stanfit):
        self._fit = stanfit
        self._log_prob = stanfit.log_prob
        self._grad_log_prob = stanfit.grad_log_prob
        # Stan takes a dictionary of (constrained) parameter values
        self._dict_dynamic = {'mu': 1, 'tau': 1, 'theta': [2] * 8}
        # maps constrained parameter values to Stan's unconstrained space
        self._c_to_u = stanfit.unconstrain_pars

    def __call__(self, x):
        # tau must be positive (otherwise Stan throws an error)
        if x[1] <= 0:
            return -np.inf
        self._dict_update(x)
        # adjust_transform=False: log density of the constrained parameters,
        # without the Jacobian term for the tau -> log(tau) transform
        return self._log_prob(
            self._c_to_u(self._dict_dynamic), adjust_transform=False)

    def _dict_update(self, x):
        self._dict_dynamic['mu'] = x[0]
        self._dict_dynamic['tau'] = x[1]
        self._dict_dynamic['theta'] = x[2:]

    def evaluateS1(self, x):
        # tau must be positive (otherwise Stan throws an error)
        if x[1] <= 0:
            return -np.inf, np.repeat(1e6, 10)
        self._dict_update(x)
        uncons = self._c_to_u(self._dict_dynamic)
        log_p = self._log_prob(uncons, adjust_transform=False)
        # Stan differentiates w.r.t. the unconstrained parameters, in which
        # tau is represented as log(tau); by the chain rule
        # d/dtau = (d/dlog(tau)) / tau
        grad = np.array(self._grad_log_prob(uncons, adjust_transform=False))
        grad[1] /= x[1]
        return log_p, grad

    def n_parameters(self):
        return 10

Run the [relativistic HMC](../sampling/relativistic-mcmc.ipynb) sampler on this model.
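
Relativistic HMC differs from standard HMC in its kinetic energy, $K(p) = m c^2 \sqrt{p^\top p/(m^2 c^2) + 1}$. Its gradient caps the speed of every position update at $c$. The self-contained NumPy sketch below implements a leapfrog integrator with that kinetic energy on a standard normal target. It is an illustration rather than Pints' implementation, and `mass`, `c`, `step_size` and `n_steps` are illustrative values; `step_size` plays the role of the value passed to `set_leapfrog_step_size` in the next cell.

```python
import numpy as np


def kinetic_energy(p, mass=1.0, c=10.0):
    # relativistic kinetic energy K(p) = m c^2 sqrt(p.p / (m^2 c^2) + 1)
    return mass * c ** 2 * np.sqrt(p @ p / (mass ** 2 * c ** 2) + 1.0)


def velocity(p, mass=1.0, c=10.0):
    # dK/dp; its norm stays below c however large the momentum p becomes
    return p / (mass * np.sqrt(p @ p / (mass ** 2 * c ** 2) + 1.0))


def leapfrog(x, p, grad_log_pdf, step_size, n_steps, mass=1.0, c=10.0):
    # half step in momentum, alternating full steps, final half step
    p = p + 0.5 * step_size * grad_log_pdf(x)
    for i in range(n_steps):
        x = x + step_size * velocity(p, mass, c)
        if i < n_steps - 1:
            p = p + step_size * grad_log_pdf(x)
    p = p + 0.5 * step_size * grad_log_pdf(x)
    return x, p


def log_pdf(x):
    # standard normal target, up to an additive constant
    return -0.5 * x @ x


def grad_log_pdf(x):
    return -x


x0, p0 = np.array([1.0, -0.5]), np.array([0.5, 0.3])
x1, p1 = leapfrog(x0, p0, grad_log_pdf, step_size=0.1, n_steps=20)

H0 = -log_pdf(x0) + kinetic_energy(p0)
H1 = -log_pdf(x1) + kinetic_energy(p1)
print(abs(H1 - H0))  # small: leapfrog nearly conserves the Hamiltonian

speed = np.linalg.norm(velocity(np.array([1e6, 1e6])))
print(speed)  # just under c = 10 despite the huge momentum
```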

In [None]:
# instantiate Pints version of Stan model
stanmodel = EightSchoolsStanLogPDF(fit)

# initialise
xs = [np.random.normal(size=10) for chain in range(4)]
# tau (element 1 of each starting point) is a scale parameter, so make it positive
for x in xs:
    x[1] = np.random.uniform()

mcmc = pints.MCMCController(stanmodel, len(xs), xs, method=pints.RelativisticMCMC)

# Add stopping criterion
mcmc.set_max_iterations(2000)

# Set up modest logging
mcmc.set_log_to_screen(True)
mcmc.set_log_interval(400)

# Set the leapfrog step size used by each chain's sampler
for sampler in mcmc.samplers():
    sampler.set_leapfrog_step_size(0.2)

start = time.time()
# Run!
print('Running...')
full_chains = mcmc.run()
print('Done!')
end = time.time()

Running...
Using Relativistic MCMC
Generating 4 chains.
Running in sequential mode.
Iter. Eval. Accept.   Accept.   Accept.   Accept.   Time m:s
0     4      0         0         0         0          0:00.0
1     84     0.333     0.333     0.333     0.333      0:00.0
2     164    0.5       0.25      0.5       0.5        0:00.0
3     244    0.6       0.4       0.6       0.4        0:00.0
400   32004  0.627     0.657     0.888     0.821      0:05.8


In [None]:
results = pints.MCMCSummary(chains=full_chains, time=(end-start), parameter_names=fit.unconstrained_param_names())
print(results)

The chains wander rather than mixing well, which illustrates how difficult this posterior is to sample from in the centered parameterisation.

In [None]:
import pints.plot
pints.plot.trace(full_chains)
plt.show()

To hammer home the difference between Stan's log probability and Pints', we can compare the two for the same parameter values. Because Stan drops constant terms, the values need not agree.

In [None]:
params = np.random.uniform(size=10)

pintsmodel = pints.toy.EightSchoolsCenteredLogPDF()

print('Stan log prob:', stanmodel(params))
print('Pints log prob:', pintsmodel(params))

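We can also compare the sensitivities. An additive constant vanishes under differentiation, so the derivatives with respect to $\mu$ should agree.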
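
A common way to check a wrapper's `evaluateS1` is to compare its gradient with central finite differences of its log density. `check_sensitivities` below is a hypothetical helper (not part of Pints). It is shown on a simple Gaussian log density so that it runs without Stan.

```python
import numpy as np


def check_sensitivities(log_pdf, evaluate_s1, x, h=1e-6):
    """Return the largest absolute difference between the gradient from
    evaluate_s1 and a central finite-difference approximation."""
    x = np.asarray(x, dtype=float)
    _, grad = evaluate_s1(x)
    fd = np.empty_like(x)
    for i in range(len(x)):
        step = np.zeros_like(x)
        step[i] = h
        fd[i] = (log_pdf(x + step) - log_pdf(x - step)) / (2 * h)
    return np.max(np.abs(np.asarray(grad) - fd))


# example: isotropic Gaussian with mean m, log density up to a constant
m = np.array([1.0, -2.0, 0.5])


def log_pdf(x):
    return -0.5 * np.sum((x - m) ** 2)


def evaluate_s1(x):
    return log_pdf(x), -(x - m)


max_err = check_sensitivities(log_pdf, evaluate_s1, [0.3, 0.1, -1.0])
print(max_err)  # close to zero for a correct gradient
```

With the Stan wrapper, the analogous call would be `check_sensitivities(stanmodel, stanmodel.evaluateS1, params)`, as long as `tau` in `params` is positive.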

In [None]:
p1, dp1 = stanmodel.evaluateS1(params)
p2, dp2 = pintsmodel.evaluateS1(params)

print('Stan d log(prob)/dmu:', dp1[0])
print('Pints d log(prob)/dmu:', dp2[0])

## Non-centered model

To speed things up, we can switch to the non-centered parameterisation. (Pints also has a version of this model: ) This model introduces auxiliary variables $\tilde{\theta}_j \sim \mathcal{N}(0, 1)$ and sets $\theta_j = \mu + \tau\tilde{\theta}_j$. The joint distribution $p(\mu, \tau, \boldsymbol{\theta})$ is unchanged, but because the prior on $\tilde{\theta}_j$ no longer depends on $\tau$, the posterior is easier to sample from.
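
The reparameterisation relies on the fact that if $\tilde{\theta}_j \sim \mathcal{N}(0, 1)$, then $\mu + \tau\tilde{\theta}_j \sim \mathcal{N}(\mu, \tau^2)$. Here is a quick Monte Carlo check with NumPy for one arbitrary choice of $\mu$ and $\tau$:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, tau = 4.0, 3.0  # arbitrary illustrative values
n = 100000

theta_centered = rng.normal(mu, tau, size=n)   # theta ~ N(mu, tau^2) directly
theta_tilde = rng.normal(0.0, 1.0, size=n)     # theta_tilde ~ N(0, 1)
theta_noncentered = mu + tau * theta_tilde     # mapped back to theta

print(theta_centered.mean(), theta_noncentered.mean())  # both close to mu
print(theta_centered.std(), theta_noncentered.std())    # both close to tau
```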

In [None]:
code="""
data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta_tilde[J];
}

transformed parameters {
  real theta[J];
  for (j in 1:J)
    theta[j] = mu + tau * theta_tilde[j];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta_tilde ~ normal(0, 1);
  y ~ normal(theta, sigma);
}
"""

Compile the Stan model and run it for a few iterations to obtain a stanfit object.

In [None]:
import os
import pickle

# Compile the non-centered model, or reuse a cached compiled version
if os.path.exists('model_ncp.pkl'):
    with open('model_ncp.pkl', 'rb') as f:
        sm = pickle.load(f)
else:
    sm = pystan.StanModel(model_code=code)
    with open('model_ncp.pkl', 'wb') as f:
        pickle.dump(sm, f)

# Run the Stan model for a few iterations to obtain a stanfit object
fit = sm.sampling(data=model.data(), iter=10, chains=1, verbose=True, refresh=10)

Wrap the Stan model in Pints. Note that the parameter name has changed from `theta` to `theta_tilde`.
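
The two wrapper classes differ only in the name of the third parameter. One possible refactoring uses `vector_to_stan_dict`, a hypothetical helper that is not part of PyStan or Pints. It builds the parameter dictionary from a flat Pints vector and a list of `(name, size)` pairs, so a single wrapper class could serve both parameterisations:

```python
import numpy as np


def vector_to_stan_dict(x, layout):
    """Split a flat parameter vector into a dict keyed by Stan parameter name.

    layout is a list of (name, size) pairs in declaration order; a size of 0
    marks a scalar parameter.
    """
    x = np.asarray(x, dtype=float)
    expected = sum(max(size, 1) for _, size in layout)
    if len(x) != expected:
        raise ValueError(f'expected {expected} values, got {len(x)}')
    out, i = {}, 0
    for name, size in layout:
        if size == 0:
            out[name] = x[i]
            i += 1
        else:
            out[name] = x[i:i + size]
            i += size
    return out


ncp_layout = [('mu', 0), ('tau', 0), ('theta_tilde', 8)]
d = vector_to_stan_dict(np.arange(10.0), ncp_layout)
print(d['mu'], d['tau'], d['theta_tilde'])
```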

In [None]:
class EightSchoolsNonCenteredStanLogPDF(pints.LogPDF):
    """
    Wraps a stanfit object as a Pints LogPDF over the constrained parameters
    [mu, tau, theta_tilde_1, ..., theta_tilde_8], where tau > 0.
    """
    def __init__(self, stanfit):
        self._fit = stanfit
        self._log_prob = stanfit.log_prob
        self._grad_log_prob = stanfit.grad_log_prob
        # Stan takes a dictionary of (constrained) parameter values
        self._dict_dynamic = {'mu': 1, 'tau': 1, 'theta_tilde': [2] * 8}
        # maps constrained parameter values to Stan's unconstrained space
        self._c_to_u = stanfit.unconstrain_pars

    def __call__(self, x):
        # tau must be positive (otherwise Stan throws an error)
        if x[1] <= 0:
            return -np.inf
        self._dict_update(x)
        # adjust_transform=False: log density of the constrained parameters,
        # without the Jacobian term for the tau -> log(tau) transform
        return self._log_prob(
            self._c_to_u(self._dict_dynamic), adjust_transform=False)

    def _dict_update(self, x):
        self._dict_dynamic['mu'] = x[0]
        self._dict_dynamic['tau'] = x[1]
        self._dict_dynamic['theta_tilde'] = x[2:]

    def evaluateS1(self, x):
        # tau must be positive (otherwise Stan throws an error)
        if x[1] <= 0:
            return -np.inf, np.repeat(1e6, 10)
        self._dict_update(x)
        uncons = self._c_to_u(self._dict_dynamic)
        log_p = self._log_prob(uncons, adjust_transform=False)
        # Stan differentiates w.r.t. the unconstrained parameters, in which
        # tau is represented as log(tau); by the chain rule
        # d/dtau = (d/dlog(tau)) / tau
        grad = np.array(self._grad_log_prob(uncons, adjust_transform=False))
        grad[1] /= x[1]
        return log_p, grad

    def n_parameters(self):
        return 10

Retry sampling, this time with the non-centered parameterisation.

In [None]:
# instantiate Pints version of Stan model
stanmodel = EightSchoolsNonCenteredStanLogPDF(fit)

# initialise
xs = [np.random.normal(size=10) for chain in range(4)]
# tau (element 1 of each starting point) is a scale parameter, so make it positive
for x in xs:
    x[1] = 5+np.random.uniform()

mcmc = pints.MCMCController(stanmodel, len(xs), xs, method=pints.HamiltonianMCMC)

# Add stopping criterion
mcmc.set_max_iterations(4000)

# Set up modest logging
mcmc.set_log_to_screen(True)
mcmc.set_log_interval(200)

# Set the leapfrog step size used by each chain's sampler
for sampler in mcmc.samplers():
    sampler.set_leapfrog_step_size(0.2)

start = time.time()
# Run!
print('Running...')
full_chains = mcmc.run()
print('Done!')
end = time.time()

Transform the samples back to the original parameterisation using $\theta_j = \mu + \tau\tilde{\theta}_j$.

In [None]:
full_chains_transformed = []
for chain in full_chains:
    mu = chain[:, 0]
    tau = chain[:, 1]
    theta_tilde = chain[:, 2:]
    # theta_j = mu + tau * theta_tilde_j, broadcast over the 8 schools
    theta = mu[:, np.newaxis] + tau[:, np.newaxis] * theta_tilde
    full_chains_transformed.append(np.column_stack((mu, tau, theta)))

Aaaaahh, that's better...

In [None]:
results = pints.MCMCSummary(chains=full_chains_transformed,
                            time=(end-start), parameter_names=names)
print(results)

...much more efficient sampling!
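
"More efficient" can be made concrete with the effective sample size (ESS): the number of independent draws that carry the same information as a correlated chain. `pints.MCMCSummary` reports ESS estimates in its output. The sketch below is a simplified single-chain estimator of our own, not Pints' exact method: it sums autocorrelations until the first non-positive one. It is checked on an AR(1) chain with coefficient $\phi$, whose ESS is known to be roughly $n(1-\phi)/(1+\phi)$.

```python
import numpy as np


def effective_sample_size(chain):
    x = np.asarray(chain, dtype=float)
    n = len(x)
    x = x - x.mean()
    # autocovariance via FFT, zero-padded to length 2n to avoid wrap-around
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conjugate(f), n=2 * n)[:n] / n
    rho = acov / acov[0]
    # integrated autocorrelation time, truncated at the first rho_k <= 0
    tau_int = 1.0
    for k in range(1, n):
        if rho[k] <= 0:
            break
        tau_int += 2.0 * rho[k]
    return n / tau_int


# AR(1) chain x_t = phi * x_{t-1} + noise
rng = np.random.default_rng(1)
n, phi = 20000, 0.5
chain = np.empty(n)
chain[0] = rng.normal()
for t in range(1, n):
    chain[t] = phi * chain[t - 1] + rng.normal()

ess = effective_sample_size(chain)
print(ess, n * (1 - phi) / (1 + phi))  # estimate vs. theoretical value
```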

In [None]:
import pints.plot
pints.plot.trace(full_chains_transformed)
plt.show()