# Transport maps (Triangular, NeuTra)
Based on "Transport map accelerated Markov chain Monte Carlo" by Parno and Marzouk (SIAM/ASA Journal on Uncertainty Quantification 2018) and "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport" by Hoffman et al. (Symposium on Approximate Bayesian Inference 2018)

In [31]:
import jax
import jax.numpy as np
import numpy as onp
import numpyro
from numpyro.contrib.autoguide import AutoDelta
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.examples.runge_kutta import runge_kutta_4
from tqdm import tqdm

In [2]:
rng_key = jax.random.PRNGKey(242)

## Predator Prey Model

In [3]:
def predator_prey_step(t, state, r=0.6, k=100, s=1.2, a=25, u=0.5, v=0.3):
    prey = state[..., 0]
    predator = state[..., 1]
    sh = (prey * predator)/(a + prey)
    prey_upd = r * prey * (1 - prey / k) - s * sh
    predator_upd = u * sh - v * predator
    return np.stack((prey_upd, predator_upd), axis=-1)
num_time = 5
step_size = 0.1
num_steps = int(num_time / step_size)
dampening_rate = 0.9
lyapunov_scale = 10e-4
predator_prey = jax.partial(runge_kutta_4, predator_prey_step, step_size, num_steps, dampening_rate,
                            lyapunov_scale, rng_key)

In [4]:
indices = np.array([1, 11, 21, 31, 41])
res, lyapunov_loss = predator_prey(np.array([50., 5.]))
noise = jax.random.normal(rng_key, (1000,5,2)) * 10
data = (indices, res[indices] + noise)
data

(DeviceArray([ 1, 11, 21, 31, 41], dtype=int32),
 DeviceArray([[[ 39.024696 , -12.725592 ],
               [ 66.320984 ,   6.957094 ],
               [ 68.026535 ,   3.555342 ],
               [ 91.41165  ,  12.742411 ],
               [ 72.754944 ,  -2.7839265]],
 
              [[ 70.52715  ,   8.503686 ],
               [ 82.815445 ,   7.1307592],
               [ 47.49247  ,  -3.0286903],
               [ 79.045616 ,  12.486261 ],
               [ 61.746765 ,   7.427475 ]],
 
              [[ 44.7713   ,  -8.647638 ],
               [ 68.058876 ,  10.351557 ],
               [ 62.73179  ,  20.398611 ],
               [ 76.33872  ,  24.19844  ],
               [ 68.47187  ,  11.550174 ]],
 
              ...,
 
              [[ 65.181244 ,  17.313402 ],
               [ 62.664757 ,   5.912359 ],
               [ 77.01439  ,   7.073277 ],
               [ 80.45603  ,  16.42875  ],
               [ 73.00238  ,   4.452697 ]],
 
              [[ 60.07134  ,  -6.554132 ],
               

In [48]:
def model(indices, observations):
    prior_dist = dist.TransformedDistribution(dist.HalfNormal(1000), dist.transforms.AffineTransform(0.1, 1.0))
    prey0 = numpyro.sample('prey0', prior_dist)
    predator0 = numpyro.sample('predator0', prior_dist)
    r = numpyro.sample('r', prior_dist)
    k = numpyro.sample('k', prior_dist)
    s = numpyro.sample('s', prior_dist)
    a = numpyro.sample('a', prior_dist)
    u = numpyro.sample('u', prior_dist)
    v = numpyro.sample('v', prior_dist)
    ppres, lyapunov_loss = predator_prey(np.array([prey0, predator0]))
    numpyro.factor('lyapunov_loss', lyapunov_loss)
    numpyro.sample('obs', dist.Normal(ppres[indices], 10.0).to_event(2), obs=observations)

### NUTS MCMC Sampling

In [50]:
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 1000, 5000, num_chains=2, chain_method='vectorized')
mcmc.run(rng_key, *data)
mcmc.print_summary()

sample: 100%|██████████| 6000/6000 [10:17<00:00,  9.72it/s, None]

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
          a      2.91      2.43      1.86      0.59      6.31      1.09      3.63
          k      3.04      3.12      0.92      0.77      8.40      2.10      1.65
  predator0      5.13      0.08      5.13      5.00      5.27    544.67      1.00
      prey0     50.08      0.23     50.08     49.67     50.44    467.59      1.01
          r      2.11      0.63      1.92      1.41      3.29      3.24      2.94
          s      2.02      0.97      1.62      0.91      3.66      1.52      4.27
          u      0.74      0.40      0.65      0.32      1.23      1.03      6.42
          v      1.72      1.55      1.20      0.21      3.80      1.06      7.47

Number of divergences: 0


### Guide and Stein with Transport Maps

In [37]:
svgd = numpyro.infer.SVGD(model, AutoDelta(standard_guide), 
                          numpyro.optim.Adam(0.01), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), num_stein_particles=10, repulsion_temperature=1e6)
print(data[1].shape[0])                        
state = svgd.init(rng_key, *data)
print(svgd.get_params(state))
state, loss = svgd.run(rng_key, 100000, *data)

1000
{'auto_a': DeviceArray([ 0.2810459 , -1.1512799 , -0.9791584 , -1.1150136 ,
             -1.3962798 , -0.4815421 ,  0.22602224, -0.05310202,
              1.7047496 ,  1.584774  ], dtype=float32), 'auto_k': DeviceArray([ 1.7867632 , -0.46652842,  0.01522875,  1.6880732 ,
             -0.81017447,  1.5335546 , -0.27158308,  1.5721688 ,
             -0.6820264 , -0.90884256], dtype=float32), 'auto_predator0': DeviceArray([ 1.5217471 ,  0.38708544, -0.47504282,  1.3877826 ,
              1.2827177 , -1.6553111 , -1.3955574 ,  1.2053485 ,
              1.5027761 ,  1.949708  ], dtype=float32), 'auto_prey0': DeviceArray([ 0.96609116, -0.74285316,  1.6176782 ,  0.9832182 ,
             -0.3874774 ,  1.1533241 , -1.8001752 ,  1.5888176 ,
              0.90665674,  1.3857479 ], dtype=float32), 'auto_r': DeviceArray([ 0.35203552,  0.7305212 ,  1.3158989 ,  1.9379086 ,
             -0.4936204 ,  0.26758862, -1.0140486 ,  1.8725686 ,
              0.42908382, -1.3745251 ], dtype=float32), 'a

In [38]:
svgd.get_params(state)

{'auto_a': DeviceArray([0.09644218, 0.09193711, 0.09941586, 0.09488524, 0.10837403,
              0.10440823, 0.09379473, 0.10275237, 0.11022998, 0.09776031],            dtype=float32),
 'auto_k': DeviceArray([0.10018528, 0.10520142, 0.09967141, 0.09982951, 0.10028005,
              0.11348224, 0.099717  , 0.10032418, 0.09479452, 0.08651443],            dtype=float32),
 'auto_predator0': DeviceArray([5.2504826, 5.250642 , 5.2506332, 5.2501235, 5.250519 ,
              5.250584 , 5.2505174, 5.2505965, 5.2505927, 5.250656 ],            dtype=float32),
 'auto_prey0': DeviceArray([50.36169 , 50.36152 , 50.361534, 50.36204 , 50.361645,
              50.361584, 50.361656, 50.361576, 50.36158 , 50.36151 ],            dtype=float32),
 'auto_r': DeviceArray([0.102716  , 0.10103273, 0.08652895, 0.10255948, 0.10378079,
              0.09827673, 0.09490088, 0.11215659, 0.09764636, 0.10040153],            dtype=float32),
 'auto_s': DeviceArray([0.10307536, 0.08938336, 0.09729765, 0.10449575, 0.0917

In [47]:
runge_kutta_4(lambda t, y: predator_prey_step(t, y, r=0.1, k=0.3, s=0.1,a=0.1,u=0.1, v=0.1), step_size, num_steps, dampening_rate,
                            lyapunov_scale, rng_key, np.array([50., 5.]))

(DeviceArray([[4.3650956e+00, 4.9997859e+00],
              [3.8034420e+00, 4.9985857e+00],
              [3.3639469e+00, 4.9972243e+00],
              [3.0097077e+00, 4.9956999e+00],
              [2.7173209e+00, 4.9940104e+00],
              [2.4712257e+00, 4.9921536e+00],
              [2.2606721e+00, 4.9901261e+00],
              [2.0779948e+00, 4.9879246e+00],
              [1.9175801e+00, 4.9855442e+00],
              [1.7752224e+00, 4.9829807e+00],
              [1.6477084e+00, 4.9802279e+00],
              [1.5325407e+00, 4.9772792e+00],
              [1.4277488e+00, 4.9741278e+00],
              [1.3317573e+00, 4.9707646e+00],
              [1.2432915e+00, 4.9671807e+00],
              [1.1613094e+00, 4.9633651e+00],
              [1.0849509e+00, 4.9593062e+00],
              [1.0135007e+00, 4.9549899e+00],
              [9.4635916e-01, 4.9504008e+00],
              [8.8302100e-01, 4.9455214e+00],
              [8.2305807e-01, 4.9403319e+00],
              [7.6610649e-01, 4.93

In [9]:
def transmap_guide(indices, observations):
    prey0val = numpyro.param('prey0val', 25)
    predator0val = numpyro.param('predator0val', 25)
    rval = numpyro.param('rval', 25)
    kval = numpyro.param('kval', 25)
    sval = numpyro.param('sval', 25)
    aval = numpyro.param('aval', 25)
    uval = numpyro.param('uval', 25)
    vval = numpyro.param('vval', 25)
    tmapp = numpyro.param('tmapp', np.zeros(8 * (8 + 1) // 2) + 1e-3)
    tril_idx = np.tril_indices(8)
    tmap = jax.ops.index_update(np.zeros((8, 8)), tril_idx, tmapp)
    prey0val, predator0val, rval, kval, sval, aval, uval, vval = tmap @ np.array([prey0val, predator0val, rval, kval, sval, aval, uval, vval])
    prey0 = numpyro.sample('prey0', dist.Delta(prey0val))
    predator0 = numpyro.sample('predator0', dist.Delta(predator0val))
    r = numpyro.sample('r', dist.Delta(rval))
    k = numpyro.sample('k', dist.Delta(kval))
    s = numpyro.sample('s', dist.Delta(sval))
    a = numpyro.sample('a', dist.Delta(aval))
    u = numpyro.sample('u', dist.Delta(uval))
    v = numpyro.sample('v', dist.Delta(vval))

In [8]:
svgd = numpyro.infer.SVGD(model, numpyro.guides.WrappedGuide(transmap_guide), 
                          numpyro.optim.SGD(1e-100), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), repulsion_temperature=data[1].shape[0],
                          classic_guide_params_fn=lambda n: n in {'tmapp'})
state, loss = svgd.run(rng_key, 100000, *data)

NameError: name 'transmap_guide' is not defined