# Interfacing with Stan
This notebook explains how it is possible to interface with Stan to make use of their log-probability calculations, gradient calculation (via autodiff) and their large library of probability distributions.

In [1]:
import pystan
import pickle

In [2]:
code="""data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta[J];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}
"""

In [3]:
sm = pickle.load(open('model.pkl', 'rb'))
# sm = pystan.StanModel(model_code=code)

In [4]:
# with open('model.pkl', 'wb') as f:
#     pickle.dump(sm, f)

In [5]:
import os
os.chdir("../../")
import pints
import pints.toy
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy
import time
import scipy.stats

In [6]:
model = pints.toy.EightSchoolsCenteredLogPDF()

In [7]:
fit = sm.sampling(data=model.data(), iter=10, chains=1, verbose=True, refresh=10)



In [92]:
print(fit.unconstrained_param_names())
constrain_to_unconstrained = fit.unconstrain_pars
log_prob = fit.log_prob
print(log_prob(constrain_to_unconstrained({'mu':1, 'tau':1, 'theta': [2] * 8}), adjust_transform=True))
fit.grad_log_prob(constrain_to_unconstrained({'mu':1, 'tau':1, 'theta': [2] * 8}))
print(constrain_to_unconstrained({'mu':1, 'tau':1, 'theta': [2] * 8}))

['mu', 'tau', 'theta.1', 'theta.2', 'theta.3', 'theta.4', 'theta.5', 'theta.6', 'theta.7', 'theta.8']
-7.387585620114306
[1. 0. 2. 2. 2. 2. 2. 2. 2. 2.]


In [137]:
class EightSchoolsStanLogPDF(pints.LogPDF):
    def __init__(self, stanfit):
        self._fit = stanfit
        self._log_prob = stanfit.log_prob
        self._grad_log_prob = stanfit.grad_log_prob
        # Stan takes dictionary of parameter values
        self._dict_dynamic = {'mu':1, 'tau':1, 'theta': [2] * 8}
        # convert variables from unconstrained to constrained space
        self._u_to_c = stanfit.unconstrain_pars

    def __call__(self, x):
        self._dict_update(x)
        return self._log_prob(self._u_to_c(self._dict_dynamic))
    
    def _dict_update(self, x):
        self._dict_dynamic["mu"] = x[0]
        self._dict_dynamic["tau"] = x[1]
        self._dict_dynamic["theta"] = x[2:]
    
    def evaluateS1(self, x):
        self._dict_update(x)
        uncons = self._u_to_c(self._dict_dynamic)
        return self._log_prob(uncons, adjust_transform=True), self._grad_log_prob(uncons, adjust_transform=True)

    def n_parameters(self):
        return 10

In [139]:
stanmodel = EightSchoolsStanLogPDF(fit)
stanmodel.evaluateS1([xs])
1+1

IndexError: list index out of range

In [158]:
# initialise
xs = [np.random.normal(size=10) for chain in range(4)]
# set 1st element of each list to positive number since corresponds to a scale parameter
for x in xs:
    x[1] = np.random.uniform()

mcmc = pints.MCMCController(stanmodel, len(xs), xs, method=pints.HamiltonianMCMC)

# Add stopping criterion
mcmc.set_max_iterations(1000)

# Set up modest logging
mcmc.set_log_to_screen(True)
mcmc.set_log_interval(100)

# # Update step sizes used by individual samplers
for sampler in mcmc.samplers():
    sampler.set_leapfrog_step_size(0.1)

start = time.time()
# Run!
print('Running...')
full_chains = mcmc.run()
print('Done!')
end = time.time()

Running...
Using Hamiltonian Monte Carlo
Generating 4 chains.
Running in sequential mode.
Iter. Eval. Accept.   Accept.   Accept.   Accept.   Time m:s
0     4      0         0         0         0          0:00.0
1     84     0.333     0.333     0.333     0.333      0:00.0
2     164    0.5       0.5       0.5       0.25       0:00.0
3     244    0.6       0.6       0.6       0.4        0:00.0
100   8004   0.922     0.892     0.922     0.872549   0:01.1
200   16004  0.931     0.851     0.911     0.911      0:02.3
300   24004  0.937     0.877     0.934     0.901      0:03.5
400   32004  0.928     0.89801   0.920398  0.908      0:04.6
500   40004  0.92      0.914     0.914     0.894      0:05.8
600   48004  0.929     0.925     0.904     0.904      0:06.9
700   56004  0.917     0.91      0.883     0.896      0:08.1
800   64004  0.925187  0.909     0.867     0.89       0:09.2
900   72004  0.926     0.907     0.874     0.882      0:10.4
1000  79924  0.926     0.907     0.865     0.89       0:

In [159]:
results = pints.MCMCSummary(chains=full_chains, time=(end-start), parameter_names=fit.unconstrained_param_names())
print(results)

param    mean    std.    2.5%    25%    50%    75%    97.5%    rhat    ess    ess per sec.
-------  ------  ------  ------  -----  -----  -----  -------  ------  -----  --------------
mu       -0.13   1.61    -4.16   -0.83  0.09   1.02   2.29     1.44    28.44  2.48
tau      2.37    1.30    0.58    1.46   2.10   2.88   5.87     1.28    32.34  2.82
theta.1  0.77    1.84    -3.44   -0.38  1.05   2.08   3.81     1.00    20.35  1.77
theta.2  0.47    2.04    -3.31   -0.88  0.53   1.54   4.87     1.28    26.06  2.27
theta.3  -0.26   2.54    -7.16   -1.57  0.13   1.23   4.05     1.34    30.91  2.69
theta.4  0.74    2.05    -2.81   -0.62  0.46   1.89   5.77     1.23    38.47  3.35
theta.5  -1.33   3.96    -8.93   -5.15  -0.13  1.71   3.96     1.70    17.08  1.49
theta.6  -0.31   1.95    -4.94   -1.56  -0.03  1.12   2.77     1.54    33.30  2.90
theta.7  -0.81   2.06    -5.96   -1.75  -0.58  0.60   2.66     1.26    28.81  2.51
theta.8  -0.13   1.69    -2.85   -1.37  -0.42  1.04   3.49     1.24  