# Transport maps (Triangular, NeuTra)
Based on "Transport map accelerated Markov chain Monte Carlo" by Parno and Marzouk (SIAM/ASA Journal on Uncertainty Quantification 2018) and "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport" by Hoffman et al. (Symposium on Approximate Bayesian Inference 2018)

In [1]:
import jax
import jax.numpy as np
import numpy as onp
import numpyro
from numpyro.contrib.autoguide import AutoDelta
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.examples.runge_kutta import runge_kutta_4
from tqdm import tqdm

In [2]:
rng_key = jax.random.PRNGKey(242)

## Predator Prey Model

In [3]:
def predator_prey_step(t, state, r=0.6, k=100, s=1.2, a=25, u=0.5, v=0.3):
    prey = state[..., 0]
    predator = state[..., 1]
    sh = (prey * predator)/(a + prey)
    prey_upd = r * prey * (1 - prey / k) - s * sh
    predator_upd = u * sh - v * predator
    return np.stack((prey_upd, predator_upd), axis=-1)
num_time = 5
step_size = 0.1
num_steps = int(num_time / step_size)
dampening_rate = 0.9
lyapunov_scale = 10e-4
predator_prey = jax.partial(runge_kutta_4, predator_prey_step, step_size, num_steps, dampening_rate,
                            lyapunov_scale, rng_key)

In [4]:
indices = np.array([1, 11, 21, 31, 41])
res, lyapunov_loss = predator_prey(np.array([50., 5.]))
noise = jax.random.normal(rng_key, (1000,5,2)) * 10
data = (indices, res[indices] + noise)
data

(DeviceArray([ 1, 11, 21, 31, 41], dtype=int32),
 DeviceArray([[[ 39.024696 , -12.725592 ],
               [ 66.320984 ,   6.957094 ],
               [ 68.026535 ,   3.555342 ],
               [ 91.41165  ,  12.742411 ],
               [ 72.754944 ,  -2.7839265]],
 
              [[ 70.52715  ,   8.503686 ],
               [ 82.815445 ,   7.1307592],
               [ 47.49247  ,  -3.0286903],
               [ 79.045616 ,  12.486261 ],
               [ 61.746765 ,   7.427475 ]],
 
              [[ 44.7713   ,  -8.647638 ],
               [ 68.058876 ,  10.351557 ],
               [ 62.73179  ,  20.398611 ],
               [ 76.33872  ,  24.19844  ],
               [ 68.47187  ,  11.550174 ]],
 
              ...,
 
              [[ 65.181244 ,  17.313402 ],
               [ 62.664757 ,   5.912359 ],
               [ 77.01439  ,   7.073277 ],
               [ 80.45603  ,  16.42875  ],
               [ 73.00238  ,   4.452697 ]],
 
              [[ 60.07134  ,  -6.554132 ],
               

In [5]:
def model(indices, observations):
    prior_dist = dist.TransformedDistribution(dist.HalfNormal(1000), dist.transforms.AffineTransform(0.1, 1.0))
    prey0 = numpyro.sample('prey0', prior_dist)
    predator0 = numpyro.sample('predator0', prior_dist)
    r = numpyro.sample('r', prior_dist)
    k = numpyro.sample('k', prior_dist)
    s = numpyro.sample('s', prior_dist)
    a = numpyro.sample('a', prior_dist)
    u = numpyro.sample('u', prior_dist)
    v = numpyro.sample('v', prior_dist)
    ppres, lyapunov_loss = predator_prey(np.array([prey0, predator0]), r=r, k=k, s=s, a=a, u=u, v=v)
    numpyro.factor('lyapunov_loss', lyapunov_loss)
    numpyro.sample('obs', dist.Normal(ppres[indices], 10.0).to_event(2), obs=observations)

### NUTS MCMC Sampling

In [6]:
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 1000, 5000, num_chains=2, chain_method='vectorized')
mcmc.run(rng_key, *data)
mcmc.print_summary()

sample: 100%|██████████| 6000/6000 [12:27<00:00,  8.03it/s, None]

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
          a     19.20     18.74     10.39      0.80     43.43      1.06      4.94
          k     83.84     10.13     82.23     72.92     95.09      1.01     10.73
  predator0      5.88      0.89      5.76      4.88      7.61      1.15      3.18
      prey0     35.54     15.14     38.83     17.00     51.11      1.02     12.23
          r      5.22      4.76      4.39      0.45     10.97      1.02     14.28
          s      1.62      1.65      0.78      0.29      4.29      1.54      3.22
          u      1.04      0.71      0.93      0.29      1.81      1.01     12.78
          v      0.73      0.47      0.71      0.23      1.23      1.00     14.64

Number of divergences: 0


### Guide and Stein with Transport Maps

In [7]:
svgd = numpyro.infer.SVGD(model, AutoDelta(model), 
                          numpyro.optim.Adam(0.01), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), num_stein_particles=10, repulsion_temperature=1.0)
state, loss = svgd.run(rng_key, 100000, *data)

1000
{'auto_a': DeviceArray([3.725022 , 2.4536145, 2.5342782, 0.9883258, 0.2730669,
             0.3243681, 4.35455  , 3.1677094, 1.8741413, 0.6096844],            dtype=float32), 'auto_k': DeviceArray([1.2957686, 0.6810942, 0.5173149, 4.1821365, 2.4269762,
             0.4937738, 6.6752315, 0.519132 , 4.7745185, 1.6039019],            dtype=float32), 'auto_predator0': DeviceArray([0.31267384, 3.843379  , 0.43577906, 1.2506862 , 0.32599968,
             2.45361   , 1.0649745 , 0.43088964, 4.516339  , 0.5544297 ],            dtype=float32), 'auto_prey0': DeviceArray([3.8734853 , 1.6689409 , 0.28280067, 3.3449383 , 2.8067923 ,
             0.23854902, 0.9464084 , 0.64569545, 1.2890145 , 1.1071557 ],            dtype=float32), 'auto_r': DeviceArray([1.7923366 , 0.7232873 , 1.0218931 , 0.45040405, 3.1459153 ,
             2.7192106 , 2.011129  , 0.25398123, 6.068802  , 0.3152902 ],            dtype=float32), 'auto_s': DeviceArray([0.31551304, 0.27015498, 6.01794   , 6.7138968 , 7.0042653 ,

KeyboardInterrupt: 

In [None]:
svgd.get_params(state)

In [None]:
def transmap_guide(indices, observations):
    prey0val = numpyro.param('prey0val', 25)
    predator0val = numpyro.param('predator0val', 25)
    rval = numpyro.param('rval', 25)
    kval = numpyro.param('kval', 25)
    sval = numpyro.param('sval', 25)
    aval = numpyro.param('aval', 25)
    uval = numpyro.param('uval', 25)
    vval = numpyro.param('vval', 25)
    tmapp = numpyro.param('tmapp', np.zeros(8 * (8 + 1) // 2) + 1e-3)
    tril_idx = np.tril_indices(8)
    tmap = jax.ops.index_update(np.zeros((8, 8)), tril_idx, tmapp)
    prey0val, predator0val, rval, kval, sval, aval, uval, vval = tmap @ np.array([prey0val, predator0val, rval, kval, sval, aval, uval, vval])
    prey0 = numpyro.sample('prey0', dist.Delta(prey0val))
    predator0 = numpyro.sample('predator0', dist.Delta(predator0val))
    r = numpyro.sample('r', dist.Delta(rval))
    k = numpyro.sample('k', dist.Delta(kval))
    s = numpyro.sample('s', dist.Delta(sval))
    a = numpyro.sample('a', dist.Delta(aval))
    u = numpyro.sample('u', dist.Delta(uval))
    v = numpyro.sample('v', dist.Delta(vval))

In [None]:
svgd = numpyro.infer.SVGD(model, numpyro.guides.WrappedGuide(transmap_guide), 
                          numpyro.optim.SGD(1e-100), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), repulsion_temperature=data[1].shape[0],
                          classic_guide_params_fn=lambda n: n in {'tmapp'})
state, loss = svgd.run(rng_key, 100000, *data)