# Transport maps (Triangular, NeuTra)
Based on "Transport map accelerated Markov chain Monte Carlo" by Parno and Marzouk (SIAM/ASA Journal on Uncertainty Quantification 2018) and "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport" by Hoffman et al. (Symposium on Approximate Bayesian Inference 2018)

In [1]:
import jax
import jax.numpy as np
import numpy as onp
import numpyro
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.examples.runge_kutta import runge_kutta_4
from tqdm import tqdm

In [2]:
rng_key = jax.random.PRNGKey(242)

## Predator Prey Model

In [3]:
def predator_prey(t, state, r=0.6, k=100, s=1.2, a=25, u=0.5, v=0.3):
    prey = state[..., 0]
    predator = state[..., 1]
    sh = (prey * predator)/(a + prey)
    prey_upd = r * prey * (1 - prey / k) - s * sh
    predator_upd = u * sh - v * predator
    return np.stack((prey_upd, predator_upd), axis=-1)

In [4]:
indices = np.array([1, 11, 21, 31, 41])
res = runge_kutta_4(predator_prey, np.array([50., 5.]), 0.1, 50)
noise = jax.random.normal(rng_key, (1000,5,2)) * 10
data = (indices, res[indices] + noise)
data

(DeviceArray([ 1, 11, 21, 31, 41], dtype=int32),
 DeviceArray([[[ 39.024696 , -12.725592 ],
               [ 66.320984 ,   6.957094 ],
               [ 68.026535 ,   3.555342 ],
               [ 91.41165  ,  12.742411 ],
               [ 72.754944 ,  -2.7839265]],
 
              [[ 70.52715  ,   8.503686 ],
               [ 82.815445 ,   7.1307592],
               [ 47.49247  ,  -3.0286903],
               [ 79.045616 ,  12.486261 ],
               [ 61.746765 ,   7.427475 ]],
 
              [[ 44.7713   ,  -8.647638 ],
               [ 68.058876 ,  10.351557 ],
               [ 62.73179  ,  20.398611 ],
               [ 76.33872  ,  24.19844  ],
               [ 68.47187  ,  11.550174 ]],
 
              ...,
 
              [[ 65.181244 ,  17.313402 ],
               [ 62.664757 ,   5.912359 ],
               [ 77.01439  ,   7.073277 ],
               [ 80.45603  ,  16.42875  ],
               [ 73.00238  ,   4.452697 ]],
 
              [[ 60.07134  ,  -6.554132 ],
               

In [5]:
# %%
def model(indices, observations):
    prey0 = numpyro.sample('prey0', dist.Uniform(0.1, 50.))
    predator0 = numpyro.sample('predator0', dist.Uniform(0.1, 50.))
    r = numpyro.sample('r', dist.Uniform(0.1, 50.))
    k = numpyro.sample('k', dist.Uniform(0.1, 150.))
    s = numpyro.sample('s', dist.Uniform(0.1, 50.))
    a = numpyro.sample('a', dist.Uniform(0.1, 50.))
    u = numpyro.sample('u', dist.Uniform(0.1, 50.))
    v = numpyro.sample('v', dist.Uniform(0.1, 50.))
    stddev = numpyro.sample('stddev', dist.InverseGamma(1.))
    ppres = runge_kutta_4(lambda t, state: predator_prey(t, state, r, k, s, a, u, v),
                          np.array([prey0, predator0]), 0.01, 50)
    numpyro.sample('obs', dist.Normal(ppres[indices], stddev).to_event(2), obs=observations)

### NUTS MCMC Sampling

In [6]:
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 100, 1000, num_chains=2, chain_method='sequential')
mcmc.run(rng_key, *data)
mcmc.print_summary()

sample: 100%|██████████| 1100/1100 [00:56<00:00, 19.34it/s, 511 steps of size 2.00e-03. acc. prob=0.93]
sample: 100%|██████████| 1100/1100 [00:26<00:00, 41.55it/s, 575 steps of size 4.63e-03. acc. prob=0.83]

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
          a     24.39     14.84     24.59      0.13     44.92    127.62      1.01
          k    105.48     11.03    103.96     88.85    121.91    112.86      1.05
  predator0      5.34      0.28      5.34      4.91      5.82   1289.45      1.00
      prey0     49.85      0.14     49.89     49.66     50.00   1328.83      1.00
          r      6.68      1.13      6.45      5.10      8.37    173.96      1.04
          s     19.45     13.50     17.08      0.10     40.82    119.50      1.05
     stddev     10.12      0.07     10.12     10.01     10.24   2184.77      1.00
          u      9.23      9.54      5.84      0.34     20.89     89.78      1.01
          v      7.02      9.09      3.69      0.10  

### Guide and Stein with Transport Maps

In [7]:
def standard_guide(indices, observations):
    prey0val = numpyro.param('prey0val', 25, constraint=constraints.interval(0.1, 50))
    prey0 = numpyro.sample('prey0', dist.Delta(prey0val))
    predator0val = numpyro.param('predator0val', 25, constraint=constraints.interval(0.1, 50))
    predator0 = numpyro.sample('predator0', dist.Delta(predator0val))
    rval = numpyro.param('rval', 25, constraint=constraints.interval(0.1, 50))
    r = numpyro.sample('r', dist.Delta(rval))
    kval = numpyro.param('kval', 25, constraint=constraints.interval(0.1, 50))
    k = numpyro.sample('k', dist.Delta(kval))
    sval = numpyro.param('sval', 25, constraint=constraints.interval(0.1, 50))
    s = numpyro.sample('s', dist.Delta(sval))
    aval = numpyro.param('aval', 25, constraint=constraints.interval(0.1, 50))
    a = numpyro.sample('a', dist.Delta(aval))
    uval = numpyro.param('uval', 25, constraint=constraints.interval(0.1, 50))
    u = numpyro.sample('u', dist.Delta(uval))
    vval = numpyro.param('vval', 25, constraint=constraints.interval(0.1, 50))
    v = numpyro.sample('v', dist.Delta(vval))

In [11]:
svgd = numpyro.infer.SVGD(model, numpyro.guides.WrappedGuide(standard_guide), 
                          numpyro.optim.SGD(1e-100), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), repulsion_temperature=data[1].shape[0])
state, loss = svgd.run(rng_key, 100000, *data)

SVGD 1.929e+05: 100%|██████████| 100000/100000 [03:11<00:00, 522.46it/s]


In [14]:
def transmap_guide(indices, observations):
    prey0val = numpyro.param('prey0val', 25)
    predator0val = numpyro.param('predator0val', 25)
    rval = numpyro.param('rval', 25)
    kval = numpyro.param('kval', 25)
    sval = numpyro.param('sval', 25)
    aval = numpyro.param('aval', 25)
    uval = numpyro.param('uval', 25)
    vval = numpyro.param('vval', 25)
    tmapp = numpyro.param('tmapp', np.zeros(8 * (8 + 1) // 2) + 1e-3)
    tril_idx = np.tril_indices(8)
    tmap = jax.ops.index_update(np.zeros((8, 8)), tril_idx, tmapp)
    prey0val, predator0val, rval, kval, sval, aval, uval, vval = tmap @ np.array([prey0val, predator0val, rval, kval, sval, aval, uval, vval])
    prey0 = numpyro.sample('prey0', dist.Delta(prey0val))
    predator0 = numpyro.sample('predator0', dist.Delta(predator0val))
    r = numpyro.sample('r', dist.Delta(rval))
    k = numpyro.sample('k', dist.Delta(kval))
    s = numpyro.sample('s', dist.Delta(sval))
    a = numpyro.sample('a', dist.Delta(aval))
    u = numpyro.sample('u', dist.Delta(uval))
    v = numpyro.sample('v', dist.Delta(vval))

In [16]:
svgd = numpyro.infer.SVGD(model, numpyro.guides.WrappedGuide(transmap_guide), 
                          numpyro.optim.SGD(1e-100), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), repulsion_temperature=data[1].shape[0],
                          classic_guide_params_fn=lambda n: n in {'tmapp'})
with jax.disable_jit():
    state, loss = svgd.run(rng_key, 100000, *data)

0%|          | 0/100000 [00:00<?, ?it/s]


ValueError: Incompatible shapes for broadcasting: (10, 36) and requested shape (36,)