# Transport maps (Triangular, NeuTra)
Based on "Transport map accelerated Markov chain Monte Carlo" by Parno and Marzouk (SIAM/ASA Journal on Uncertainty Quantification 2018) and "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport" by Hoffman et al. (Symposium on Approximate Bayesian Inference 2018)

In [1]:
import jax
import jax.numpy as np
import numpy as onp
import numpyro
from numpyro.contrib.autoguide import AutoDelta
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.examples.runge_kutta import runge_kutta_4
from tqdm import tqdm

In [2]:
rng_key = jax.random.PRNGKey(242)

## Predator Prey Model

In [3]:
def predator_prey_step(t, state, r=0.6, k=100, s=1.2, a=25, u=0.5, v=0.3):
    prey = state[..., 0]
    predator = state[..., 1]
    sh = (prey * predator)/(a + prey)
    prey_upd = r * prey * (1 - prey / k) - s * sh
    predator_upd = u * sh - v * predator
    return np.stack((prey_upd, predator_upd), axis=-1)
num_time = 5
step_size = 0.1
num_steps = int(num_time / step_size)
dampening_rate = 0.9
lyapunov_scale = 10e-4
predator_prey = jax.partial(runge_kutta_4, predator_prey_step, step_size, num_steps, dampening_rate,
                            lyapunov_scale, rng_key)

In [4]:
indices = np.array([1, 11, 21, 31, 41])
res, lyapunov_loss = predator_prey(np.array([50., 5.]))
noise = jax.random.normal(rng_key, (1000,5,2)) * 10
data = (indices, res[indices] + noise)
data

(DeviceArray([ 1, 11, 21, 31, 41], dtype=int32),
 DeviceArray([[[ 39.024696 , -12.725592 ],
               [ 66.320984 ,   6.957094 ],
               [ 68.026535 ,   3.555342 ],
               [ 91.41165  ,  12.742411 ],
               [ 72.754944 ,  -2.7839265]],
 
              [[ 70.52715  ,   8.503686 ],
               [ 82.815445 ,   7.1307592],
               [ 47.49247  ,  -3.0286903],
               [ 79.045616 ,  12.486261 ],
               [ 61.746765 ,   7.427475 ]],
 
              [[ 44.7713   ,  -8.647638 ],
               [ 68.058876 ,  10.351557 ],
               [ 62.73179  ,  20.398611 ],
               [ 76.33872  ,  24.19844  ],
               [ 68.47187  ,  11.550174 ]],
 
              ...,
 
              [[ 65.181244 ,  17.313402 ],
               [ 62.664757 ,   5.912359 ],
               [ 77.01439  ,   7.073277 ],
               [ 80.45603  ,  16.42875  ],
               [ 73.00238  ,   4.452697 ]],
 
              [[ 60.07134  ,  -6.554132 ],
               

In [5]:
def model(indices, observations):
    prior_dist = dist.TransformedDistribution(dist.HalfNormal(1000), dist.transforms.AffineTransform(0.1, 1.0))
    prey0 = numpyro.sample('prey0', prior_dist)
    predator0 = numpyro.sample('predator0', prior_dist)
    r = numpyro.sample('r', prior_dist)
    k = numpyro.sample('k', prior_dist)
    s = numpyro.sample('s', prior_dist)
    a = numpyro.sample('a', prior_dist)
    u = numpyro.sample('u', prior_dist)
    v = numpyro.sample('v', prior_dist)
    ppres, lyapunov_loss = predator_prey(np.array([prey0, predator0]), r=r, k=k, s=s, a=a, u=u, v=v)
    numpyro.factor('lyapunov_loss', lyapunov_loss)
    numpyro.sample('obs', dist.Normal(ppres[indices], 10.0).to_event(2), obs=observations)

### NUTS MCMC Sampling

In [6]:
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 100, 500, num_chains=2, chain_method='vectorized')
mcmc.run(rng_key, *data)
mcmc.print_summary()

KeyboardInterrupt: 

### Guide and Stein with Transport Maps

In [18]:
svgd = numpyro.infer.SVGD(model, AutoDelta(model), 
                          numpyro.optim.Adam(0.0001), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), num_stein_particles=20,
                          repulsion_temperature=0.001 * data[1].shape[0])
state, loss = svgd.run(rng_key, int(1e5), *data)  # rounds 100000

  0%|          | 0/100000 [00:00<?, ?it/s]

attr force Traced<ShapedArray(float32[20,8]):JaxprTrace(level=-1/1)>
repul force Traced<ShapedArray(float32[20,8]):JaxprTrace(level=-1/1)>


SVGD 1.5576e+05:   0%|          | 1/100000 [00:20<575:13:39, 20.71s/it]

attr force Traced<ShapedArray(float32[20,8]):JaxprTrace(level=-1/1)>
repul force Traced<ShapedArray(float32[20,8]):JaxprTrace(level=-1/1)>


SVGD 3.7956e+04: 100%|██████████| 100000/100000 [06:07<00:00, 271.98it/s]


In [19]:
svgd.get_params(state)

{'auto_a': DeviceArray([1.05216011e+02, 4.00399628e+01, 1.83080077e-01,
              1.08468078e-01, 1.26189232e+02, 2.17729479e-01,
              3.13297760e+02, 3.63027573e+00, 1.25300636e+01,
              1.45243719e-01, 2.48445068e+02, 6.51441269e+01,
              2.11934326e+02, 1.78412838e+01, 2.72933563e+02,
              1.08531494e+02, 4.83667731e-01, 1.09312925e+01,
              1.05666615e-01, 1.73268909e+01], dtype=float32),
 'auto_k': DeviceArray([ 70.36696 ,  69.714134, 105.80192 ,  78.8281  ,  69.64293 ,
               74.71989 , 101.29064 ,  96.049065,  91.80952 , 262.88666 ,
               94.11305 ,  91.243126,  69.640114,  91.506355,  93.65896 ,
               70.36578 , 110.99934 ,  73.76675 ,  74.13523 ,  69.656334],            dtype=float32),
 'auto_predator0': DeviceArray([5.1660438, 5.1174564, 5.1531987, 3.1182866, 5.1126046,
              4.5123105, 5.1697626, 5.1291137, 5.1260285, 4.452349 ,
              5.1690364, 5.1568727, 5.1400986, 5.1300516, 5.17124

In [20]:
def transmap_guide(indices, observations):
    prey0val = numpyro.param('prey0val', 25, constraints=dist.constraints.greater_than(0.1))
    predator0val = numpyro.param('predator0val', 25, constraints=dist.constraints.greater_than(0.1))
    rval = numpyro.param('rval', 25, constraints=dist.constraints.greater_than(0.1))
    kval = numpyro.param('kval', 25, constraints=dist.constraints.greater_than(0.1))
    sval = numpyro.param('sval', 25, constraints=dist.constraints.greater_than(0.1))
    aval = numpyro.param('aval', 25, constraints=dist.constraints.greater_than(0.1))
    uval = numpyro.param('uval', 25, constraints=dist.constraints.greater_than(0.1))
    vval = numpyro.param('vval', 25, constraints=dist.constraints.greater_than(0.1))
    tmapp = numpyro.param('tmapp', np.zeros(8 * (8 + 1) // 2) + 1e-3)
    tril_idx = np.tril_indices(8)
    tmap = jax.ops.index_update(np.zeros((8, 8)), tril_idx, tmapp)
    prey0val, predator0val, rval, kval, sval, aval, uval, vval = tmap @ np.array([prey0val, predator0val, rval, kval, sval, aval, uval, vval])
    prey0 = numpyro.sample('prey0', dist.Delta(prey0val))
    predator0 = numpyro.sample('predator0', dist.Delta(predator0val))
    r = numpyro.sample('r', dist.Delta(rval))
    k = numpyro.sample('k', dist.Delta(kval))
    s = numpyro.sample('s', dist.Delta(sval))
    a = numpyro.sample('a', dist.Delta(aval))
    u = numpyro.sample('u', dist.Delta(uval))
    v = numpyro.sample('v', dist.Delta(vval))

In [26]:
svgd = numpyro.infer.SVGD(model, numpyro.guides.WrappedGuide(transmap_guide), 
                          numpyro.optim.SGD(1e-100), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), repulsion_temperature=0.001 *data[1].shape[0],
                          classic_guide_params_fn=lambda n: n in {'tmapp'})
state, loss = svgd.run(rng_key, 1, *data) # rounds 1e5

  0%|          | 0/1 [00:00<?, ?it/s]

attr force Traced<ShapedArray(float32[10,8]):JaxprTrace(level=-1/1)>
repul force Traced<ShapedArray(float32[10,8]):JaxprTrace(level=-1/1)>


SVGD nan: 100%|██████████| 1/1 [00:31<00:00, 31.16s/it]


In [24]:
svgd.get_params(state)


{'aval': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32),
 'kval': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32),
 'predator0val': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32),
 'prey0val': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32),
 'rval': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32),
 'sval': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32),
 'tmapp': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
              nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],            dtype=float32),
 'uval': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32),
 'vval': DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32)}