# Transport maps (Triangular, NeuTra)
Based on "Transport map accelerated Markov chain Monte Carlo" by Parno and Marzouk (SIAM/ASA Journal on Uncertainty Quantification 2018) and "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport" by Hoffman et al. (Symposium on Approximate Bayesian Inference 2018)

In [1]:
import jax
import jax.numpy as np
import numpy as onp
import numpyro
from numpyro import distributions as dist
from numpyro.examples.runge_kutta import runge_kutta_4

In [2]:
rng_key = jax.random.PRNGKey(242)

## Predator Prey Model

In [3]:
def predator_prey(t, state, r=0.6, k=100, s=1.2, a=25, u=0.5, v=0.3):
    prey = state[..., 0]
    predator = state[..., 1]
    sh = (prey * predator)/(a + prey)
    prey_upd = r * prey * (1 - prey / k) - s * sh
    predator_upd = u * sh - v * predator
    return np.stack((prey_upd, predator_upd), axis=-1)

In [4]:
indices = np.array([1, 11, 21, 31, 41])
res = runge_kutta_4(predator_prey, np.array([50., 5.]), 0.1, 50)
noise = jax.random.normal(rng_key, (1000,5,2)) * 10
data = (indices, res[indices] + noise)
data

(DeviceArray([ 1, 11, 21, 31, 41], dtype=int32),
 DeviceArray([[[ 39.024696 , -12.725592 ],
               [ 66.320984 ,   6.957094 ],
               [ 68.026535 ,   3.555342 ],
               [ 91.41165  ,  12.742411 ],
               [ 72.754944 ,  -2.7839265]],
 
              [[ 70.52715  ,   8.503686 ],
               [ 82.815445 ,   7.1307592],
               [ 47.49247  ,  -3.0286903],
               [ 79.045616 ,  12.486261 ],
               [ 61.746765 ,   7.427475 ]],
 
              [[ 44.7713   ,  -8.647638 ],
               [ 68.058876 ,  10.351557 ],
               [ 62.73179  ,  20.398611 ],
               [ 76.33872  ,  24.19844  ],
               [ 68.47187  ,  11.550174 ]],
 
              ...,
 
              [[ 65.181244 ,  17.313404 ],
               [ 62.664757 ,   5.912359 ],
               [ 77.01439  ,   7.073277 ],
               [ 80.45603  ,  16.42875  ],
               [ 73.00238  ,   4.452697 ]],
 
              [[ 60.07134  ,  -6.554132 ],
               

In [5]:
# %%
def model(indices, observations):
    prey0 = numpyro.sample('prey0', dist.Uniform(0.1, 50.))
    predator0 = numpyro.sample('predator0', dist.Uniform(0.1, 50.))
    r = numpyro.sample('r', dist.Uniform(0.1, 50.))
    k = numpyro.sample('k', dist.Uniform(0.1, 150.))
    s = numpyro.sample('s', dist.Uniform(0.1, 50.))
    a = numpyro.sample('a', dist.Uniform(0.1, 50.))
    u = numpyro.sample('u', dist.Uniform(0.1, 50.))
    v = numpyro.sample('v', dist.Uniform(0.1, 50.))
    stddev = numpyro.sample('stddev', dist.InverseGamma(1.))
    ppres = runge_kutta_4(lambda t, state: predator_prey(t, state, r, k, s, a, u, v),
                          np.array([prey0, predator0]), 0.01, 50)
    numpyro.sample('obs', dist.Normal(ppres[indices], stddev).to_event(2), obs=observations)

### NUTS MCMC Sampling

In [6]:
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 1_000, 10_000, num_chains=2, chain_method='sequential')
mcmc.run(rng_key, *data)
mcmc.print_summary()

sample: 100%|██████████| 11000/11000 [09:33<00:00, 19.18it/s, 7 steps of size 4.62e-03. acc. prob=0.76]
sample: 100%|██████████| 11000/11000 [09:08<00:00, 20.04it/s, 383 steps of size 5.48e-03. acc. prob=0.80]

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
          a     22.92     15.00     21.85      0.11     44.48   1778.86      1.01
          k    104.82     10.59    103.15     88.58    120.91   2312.04      1.00
  predator0      5.35      0.27      5.35      4.88      5.77   1652.85      1.00
      prey0     49.84      0.14     49.88     49.65     50.00   1904.03      1.00
          r      6.64      1.10      6.43      4.98      8.21   2481.80      1.00
          s     18.50     12.89     15.93      0.11     38.60   2148.65      1.00
     stddev     10.12      0.07     10.12     10.00     10.24   4402.94      1.00
          u      9.88     10.04      6.50      0.55     23.51    681.40      1.01
          v      7.68      9.63      4.19      0.10

### Guide and Stein with Transport Maps