In [1]:
import warnings
warnings.filterwarnings("ignore")
import pymc4 as pm4
import tensorflow as tf
from tensorflow_probability import distributions as tfd
import numpy as np
import arviz as az

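Original Stan model (eight schools, non-centered parameterization). Note that it puts no explicit priors on `mu` and `tau`, so they are flat (with `tau` constrained positive). The PyMC4 models below instead use `Normal(0, 1)` and `HalfNormal(2.)` priors for them.

```stan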
data {
  int<lower=0> J; // number of schools
  real y[J]; // estimated treatment effects
  real<lower=0> sigma[J]; // s.e. of effect estimates
}
parameters {
  real mu;
  real<lower=0> tau;
  real eta[J];
}
transformed parameters {
  real theta[J];
  for (j in 1:J)
    theta[j] <- mu + tau * eta[j];
}
model {
  eta ~ normal(0, 1);
  y ~ normal(theta, sigma);
}
```

In [2]:
# Eight-schools data: estimated treatment effects `y` and their standard
# errors `sigma`, one entry per school. float32 is TensorFlow's default
# floating dtype.
y = np.array([28,  8, -3,  7, -1,  1, 18, 12], dtype=np.float32)
sigma = np.array([15, 10, 16, 11,  9, 11, 10, 18], dtype=np.float32)
# Number of schools, derived from the data rather than hard-coded separately.
J = len(y)
assert sigma.shape == y.shape == (J,), "y and sigma must have one entry per school"

In [3]:
@pm4.model
def schools_pm4():
    """Eight-schools model written with explicitly shaped prior parameters.

    Prior parameters are passed as float32 tensors of the required shape
    (length J for eta, length 1 for mu and tau) instead of Python scalars.

    NOTE: the next cell redefines ``schools_pm4`` using ``plate=`` arguments,
    so this version is shadowed as soon as that cell runs; it is kept only
    for comparison.
    """
    # Original author's note (paraphrased): the parameters unfortunately have
    # to be given as tensors here — maybe pm4 could default to plate=1?
    school_offsets = yield pm4.Normal("eta", tf.zeros(J), tf.ones(J))
    overall_mean = yield pm4.Normal("mu", tf.zeros(1), tf.ones(1))
    effect_spread = yield pm4.HalfNormal('tau', 2. * tf.ones(1))

    # Non-centered parameterization: theta = mu + tau * eta.
    school_means = overall_mean + effect_spread * school_offsets

    yield pm4.Normal('obs', school_means, scale=sigma, observed=y)

In [4]:
@pm4.model
def schools_pm4():
    """Non-centered eight-schools model using ``plate=`` to set batch shapes.

    eta has one standard-normal entry per school (plate=J); mu and tau are
    declared with plate=1. The observed effects ``y`` are modeled as
    Normal(theta, sigma) with theta = mu + tau * eta.
    """
    school_offsets = yield pm4.Normal("eta", 0, 1, plate=J)
    overall_mean = yield pm4.Normal("mu", 0, 1, plate=1)
    effect_spread = yield pm4.HalfNormal('tau', 2., plate=1)

    # Scale and shift the standard offsets into per-school means.
    school_means = overall_mean + effect_spread * school_offsets

    yield pm4.Normal('obs', school_means, scale=sigma, observed=y)

In [5]:
%%time
tf_trace = pm4.inference.sampling.sample(schools_pm4(), 
                                         step_size=.28,
                                         num_chains=5, 
                                         num_samples=100,
                                         xla=False)

CPU times: user 9.52 s, sys: 766 ms, total: 10.3 s
Wall time: 8.1 s


In [6]:
%%time
tf_trace = pm4.inference.sampling.sample(schools_pm4(), 
                                         step_size=.28,
                                         num_chains=50, 
                                         num_samples=100,
                                         xla=False)

CPU times: user 10.1 s, sys: 887 ms, total: 10.9 s
Wall time: 7.82 s


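TODO: Make this work well with XLA. The `xla=True` run below completes, but its wall time (17.1 s) is roughly twice that of the eager 50-chain run above (7.82 s); this timing may include compilation.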

In [7]:
%%time
tf_trace = pm4.inference.sampling.sample(schools_pm4(), 
                                         step_size=.28,
                                         num_chains=50, 
                                         num_samples=100,
                                         xla=True)

CPU times: user 16.7 s, sys: 207 ms, total: 16.9 s
Wall time: 17.1 s


In [None]:
import arviz as az

In [None]:
np.shape(tf_trace[1].numpy())  # layout of sampler output 1 (treated as `eta` below)

In [None]:
np.shape(tf_trace[0].numpy())  # layout of sampler output 0 (treated as `mu` below)

In [None]:
# TODO: Order changes for different runs, should fix order
# Sampler outputs are picked positionally (0 -> mu, 1 -> eta, 2 -> tau).
# NOTE(review): per the TODO this mapping is not stable — verify each run.
# This code assumes each output is laid out (draw, chain, ...), while ArviZ
# expects (chain, draw, ...), so the first two axes are swapped; mu and tau
# first drop their trailing length-1 plate axis.
trace_tfp = az.from_dict(dict(
    eta=np.swapaxes(tf_trace[1].numpy(), 0, 1),
    mu=tf_trace[0].numpy()[..., 0].T,
    tau=tf_trace[2].numpy()[..., 0].T,
))

In [None]:
# eta draws as an xarray DataArray (ArviZ default dims: chain, draw, eta_dim_0).
trace_tfp.posterior["eta"]

In [None]:
# Trace and density plots for all posterior variables of the TF-sampler run
# (trailing ';' suppresses the returned axes' repr).
az.plot_trace(data=trace_tfp);

## Using the Python NUTS sampler (`pymc4.hmc`)

In [None]:
from pymc4 import hmc

In [None]:
# Instantiate the plate-based model and obtain its log-density function.
# `tensors` presumably holds the model's free-variable tensors — confirm.
model = schools_pm4()
logp_func, tensors = pm4.inference.sampling.build_logp_function(
    model
)

In [None]:
def logp_array(input_tensors):
    """Evaluate the model log-density and its gradient at a flat parameter vector.

    The flat vector is split as ``[eta (J entries), mu, tau]``, i.e. this
    assumes ``logp_func`` takes its free variables in that order.
    NOTE(review): the sampler-output TODO says variable order can change
    between runs — confirm this layout matches ``build_logp_function``.
    NOTE(review): the model declares mu and tau with plate=1, but here they
    are sliced out as scalars — confirm ``logp_func`` accepts that.

    Parameters
    ----------
    input_tensors : tf.Tensor
        1-D tensor of length ``J + 2``.

    Returns
    -------
    (logp, grad)
        ``logp`` as returned by ``logp_func``, and a list of gradients with
        one entry per slice (eta, mu, tau).
    """
    eta = input_tensors[:J]
    mu = input_tensors[J]
    tau = input_tensors[J + 1]
    parts = [eta, mu, tau]
    with tf.GradientTape() as tape:
        # The slices are plain tensors, not tf.Variables, so the tape must be
        # told explicitly to track them.
        tape.watch(parts)
        # logp_func receives the whole list as a single argument.
        logp = logp_func(parts)
    grad = tape.gradient(logp, parts)

    return logp, grad

# logp_array works on TF tensors, but PyMC3-style samplers exchange NumPy
# arrays, so adapt its inputs and outputs here.
def logp_wrapper(arr):
    """NumPy-facing adapter around ``logp_array`` for the Python NUTS sampler.

    Parameters
    ----------
    arr : np.ndarray
        Flat parameter vector; converted to a TF tensor before evaluation.

    Returns
    -------
    (logp, dlogp)
        ``logp`` converted to NumPy, and the gradient pieces concatenated
        into a single 1-D NumPy array in parameter-vector order.
    """
    logp, grad_parts = logp_array(tf.convert_to_tensor(arr))
    # Scalar gradients (mu, tau) are promoted to 1-D so they can be
    # concatenated with the vector-valued eta gradient.
    dlogp = np.concatenate([np.atleast_1d(g.numpy()) for g in grad_parts])
    # Hand back NumPy rather than a TF tensor, as the sampler expects.
    return logp.numpy(), dlogp

In [None]:
# Flat parameter-vector length for the Python sampler: J etas, then mu and tau.
size = J + 2
n_samples = 500

# NOTE(review): presumably pm4's hmc module draws momenta from NumPy's global
# RNG; seeding it makes the Python-NUTS run repeatable — confirm.
SEED = 42
np.random.seed(SEED)

sampler = hmc.NUTS(logp_dlogp_func=logp_wrapper,
                   size=size,
                   dtype=np.float32)

# Start every coordinate at 0.05 (float32, matching the sampler dtype);
# presumably a small positive start keeps tau away from its boundary at 0.
curr = np.full(size, 0.05, dtype=np.float32)
posterior_samples = []
stats = []

In [None]:
%%time
for i in range(n_samples):
    curr, stat = sampler.step(curr)
    posterior_samples.append(curr)
    stats.append(stat)
    if i % 20 == 0:
        print(i)
    
trace = np.array(posterior_samples)

In [None]:
# ArviZ expects arrays shaped (chain, draw, *event_shape). The Python sampler
# ran a single chain, so prepend a chain axis of length 1. Columns of `trace`
# follow the sampler's parameter layout: J etas, then mu, then tau.
trace_python = az.from_dict({'eta': trace[np.newaxis, :, :J],
                             'mu':  trace[np.newaxis, :, J],
                             'tau': trace[np.newaxis, :, J + 1]})

In [None]:
# Trace and density plots for the single-chain Python-NUTS run
# (trailing ';' suppresses the returned axes' repr).
az.plot_trace(data=trace_python);