# Transport maps (Triangular, NeuTra)
Based on "Transport map accelerated Markov chain Monte Carlo" by Parno and Marzouk (SIAM/ASA Journal on Uncertainty Quantification 2018) and "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport" by Hoffman et al. (Symposium on Approximate Bayesian Inference 2018)

In [1]:
import jax
import jax.numpy as np
import numpy as onp
import numpyro
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.examples.runge_kutta import runge_kutta_4
from tqdm import tqdm

In [2]:
rng_key = jax.random.PRNGKey(242)

## Predator Prey Model

In [3]:
def predator_prey_step(t, state, r=0.6, k=100, s=1.2, a=25, u=0.5, v=0.3):
    prey = state[..., 0]
    predator = state[..., 1]
    sh = (prey * predator)/(a + prey)
    prey_upd = r * prey * (1 - prey / k) - s * sh
    predator_upd = u * sh - v * predator
    return np.stack((prey_upd, predator_upd), axis=-1)
num_time = 5
step_size = 0.1
num_steps = int(num_time / step_size)
dampening_rate = 0.9
lyapunov_scale = 10e-4
predator_prey = jax.partial(runge_kutta_4, predator_prey_step, step_size, num_steps, dampening_rate,
                            lyapunov_scale, rng_key)

In [4]:
indices = np.array([1, 11, 21, 31, 41])
res, lyapunov_loss = predator_prey(np.array([50., 5.]))
noise = jax.random.normal(rng_key, (1000,5,2)) * 10
data = (indices, res[indices] + noise)
data

(DeviceArray([ 1, 11, 21, 31, 41], dtype=int32),
 DeviceArray([[[ 39.024696 , -12.725592 ],
               [ 66.320984 ,   6.957094 ],
               [ 68.026535 ,   3.555342 ],
               [ 91.41165  ,  12.742411 ],
               [ 72.754944 ,  -2.7839265]],
 
              [[ 70.52715  ,   8.503686 ],
               [ 82.815445 ,   7.1307592],
               [ 47.49247  ,  -3.0286903],
               [ 79.045616 ,  12.486261 ],
               [ 61.746765 ,   7.427475 ]],
 
              [[ 44.7713   ,  -8.647638 ],
               [ 68.058876 ,  10.351557 ],
               [ 62.73179  ,  20.398611 ],
               [ 76.33872  ,  24.19844  ],
               [ 68.47187  ,  11.550174 ]],
 
              ...,
 
              [[ 65.181244 ,  17.313404 ],
               [ 62.664757 ,   5.912359 ],
               [ 77.01439  ,   7.073277 ],
               [ 80.45603  ,  16.42875  ],
               [ 73.00238  ,   4.452697 ]],
 
              [[ 60.07134  ,  -6.554132 ],
               

In [5]:
def model(indices, observations):
    prior_dist = dist.TransformedDistribution(dist.Normal(1.5, 100),
                                              dist.transforms.ComposeTransform([
                                                  dist.transforms.ExpTransform(),
                                                  dist.transforms.AffineTransform(0.1, 1.0)
                                              ]))
    prey0 = numpyro.sample('prey0', prior_dist)
    predator0 = numpyro.sample('predator0', prior_dist)
    r = numpyro.sample('r', prior_dist)
    k = numpyro.sample('k', prior_dist)
    s = numpyro.sample('s', prior_dist)
    a = numpyro.sample('a', prior_dist)
    u = numpyro.sample('u', prior_dist)
    v = numpyro.sample('v', prior_dist)
    ppres, lyapunov_loss = predator_prey(np.array([prey0, predator0]))
    numpyro.factor('lyapunov_loss', lyapunov_loss)
    numpyro.sample('obs', dist.Normal(ppres[indices], 1.0).to_event(2), obs=observations)

### NUTS MCMC Sampling

In [6]:
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 100, 1000, num_chains=2, chain_method='vectorized')
mcmc.run(rng_key, *data)
mcmc.print_summary()

sample: 100%|██████████| 1100/1100 [03:32<00:00,  5.19it/s, None]

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
          a      2.96      2.37      2.96      0.59      5.33      1.00   6756.00
          k      1.10      0.18      1.10      0.92      1.28      1.00   1232.37
  predator0      3.30      0.20      3.30      3.10      3.50      1.00    192.67
      prey0     47.83      1.22     47.83     46.61     49.05      1.00   2763.85
          r      2.17      0.60      2.17      1.57      2.78      1.00   2193.87
          s      1.41      0.16      1.41      1.25      1.57      1.00    913.34
          u      0.76      0.38      0.76      0.39      1.14       nan   4268.08
          v      1.47      1.22      1.47      0.25      2.70      1.00   6874.50

Number of divergences: 0


### Guide and Stein with Transport Maps

In [7]:
def standard_guide(indices, observations):
    prey0val = numpyro.param('prey0val', 25, constraint=constraints.interval(0.1, 50))
    prey0 = numpyro.sample('prey0', dist.Delta(prey0val))
    predator0val = numpyro.param('predator0val', 25, constraint=constraints.interval(0.1, 50))
    predator0 = numpyro.sample('predator0', dist.Delta(predator0val))
    rval = numpyro.param('rval', 25, constraint=constraints.interval(0.1, 50))
    r = numpyro.sample('r', dist.Delta(rval))
    kval = numpyro.param('kval', 25, constraint=constraints.interval(0.1, 50))
    k = numpyro.sample('k', dist.Delta(kval))
    sval = numpyro.param('sval', 25, constraint=constraints.interval(0.1, 50))
    s = numpyro.sample('s', dist.Delta(sval))
    aval = numpyro.param('aval', 25, constraint=constraints.interval(0.1, 50))
    a = numpyro.sample('a', dist.Delta(aval))
    uval = numpyro.param('uval', 25, constraint=constraints.interval(0.1, 50))
    u = numpyro.sample('u', dist.Delta(uval))
    vval = numpyro.param('vval', 25, constraint=constraints.interval(0.1, 50))
    v = numpyro.sample('v', dist.Delta(vval))

In [12]:
svgd = numpyro.infer.SVGD(model, numpyro.guides.WrappedGuide(standard_guide), 
                          numpyro.optim.Adam(1e-3), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), repulsion_temperature=data[1].shape[0])
print(data[1].shape[0])                        
state = svgd.init(rng_key, *data)
print(svgd.get_params(state))
state, loss = svgd.run(rng_key, 100000, *data)

1000


In [10]:
svgd.get_params(state)

{'aval': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 'kval': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 'predator0val': DeviceArray([5.161248 , 5.160422 , 5.161248 , 5.1612544, 5.1612153,
              5.161208 , 5.16125  , 5.1612535, 5.1612525, 5.1612473],            dtype=float32),
 'prey0val': DeviceArray([49.999992, 49.999992, 49.999992, 49.999992, 49.999992,
              49.999992, 49.999992, 49.999992, 49.999992, 49.999992],            dtype=float32),
 'rval': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 'sval': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 'uval': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 'vval': DeviceArray([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32)}

In [9]:
def transmap_guide(indices, observations):
    prey0val = numpyro.param('prey0val', 25)
    predator0val = numpyro.param('predator0val', 25)
    rval = numpyro.param('rval', 25)
    kval = numpyro.param('kval', 25)
    sval = numpyro.param('sval', 25)
    aval = numpyro.param('aval', 25)
    uval = numpyro.param('uval', 25)
    vval = numpyro.param('vval', 25)
    tmapp = numpyro.param('tmapp', np.zeros(8 * (8 + 1) // 2) + 1e-3)
    tril_idx = np.tril_indices(8)
    tmap = jax.ops.index_update(np.zeros((8, 8)), tril_idx, tmapp)
    prey0val, predator0val, rval, kval, sval, aval, uval, vval = tmap @ np.array([prey0val, predator0val, rval, kval, sval, aval, uval, vval])
    prey0 = numpyro.sample('prey0', dist.Delta(prey0val))
    predator0 = numpyro.sample('predator0', dist.Delta(predator0val))
    r = numpyro.sample('r', dist.Delta(rval))
    k = numpyro.sample('k', dist.Delta(kval))
    s = numpyro.sample('s', dist.Delta(sval))
    a = numpyro.sample('a', dist.Delta(aval))
    u = numpyro.sample('u', dist.Delta(uval))
    v = numpyro.sample('v', dist.Delta(vval))

In [8]:
svgd = numpyro.infer.SVGD(model, numpyro.guides.WrappedGuide(transmap_guide), 
                          numpyro.optim.SGD(1e-100), numpyro.infer.ELBO(),
                          numpyro.infer.kernels.RBFKernel(), repulsion_temperature=data[1].shape[0],
                          classic_guide_params_fn=lambda n: n in {'tmapp'})
state, loss = svgd.run(rng_key, 100000, *data)

NameError: name 'transmap_guide' is not defined