In [0]:
from functools import partial

import jax
import jax.numpy as np
from jax.scipy.stats import norm
from jax.scipy.special import logsumexp

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

import numpy as onp

# don't do this at home, warnings are there for a reason
import warnings
warnings.filterwarnings('ignore')

This notebook accompanies a [blog post](https://rlouf.github.io/post/jax-random-walk-metropolis/) I wrote on the performance of vectorized sampling with the Random Walk Metropolis algorithm. I compared the performance of Numpy, JAX and Tensorflow Probability on CPU. The response was overwhelming, and [Matthew Johnson](https://twitter.com/SingularMattrix) and [Hector Yee](https://twitter.com/eigenhector) were kind enough to point out that I fell into JAX's pseudo-random generator trap, and thus greatly overestimated JAX's performance. Don't make the same mistake, and read the [doc](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers) carefully before playing with random numbers with JAX.

Finally, [Erwin Coumans](https://twitter.com/erwincoumans) suggested on Twitter that I turn part of this benchmark into a Colab notebook to give people a point of comparison with GPUs and TPUs. This is an excellent idea, so here it is.

# The setup

The basic requirements to be able to generate samples are the transition kernel of the random walk and a log-probability density function to sample from. I chose a completely arbitrary Gaussian mixture with 4 components.

There are a couple of interesting things going on already:

- The kernel is written for a single chain. We will use JAX's `vmap` function to vectorize the computation. You can find the doc [here](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap).
- `jax.numpy` acts as a drop-in replacement to `numpy`.
- `jax.random.split` is, roughly speaking, the function that allows you to "advance" the random number generator. If you don't call this function you will always be drawing the same numbers. This is what I did not understand the first time.
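- The `jax.jit` decorator tells JAX that you want the function to be JIT compiled. `static_argnums` tells the compiler which positional arguments should be treated as compile-time constants (here, the log-probability function); calling the function with a different value for one of them triggers a recompilation.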

In [0]:
@partial(jax.jit, static_argnums=(1,))
def rw_metropolis_kernel(rng_key, logpdf, position, log_prob):
    """Advance one chain by a single Random Walk Metropolis step.

    A Gaussian perturbation (standard normal scaled by 0.1) is added to the
    current position; the move is accepted when the log of a uniform draw is
    smaller than the difference in log-probability.

    Parameters
    ----------
    rng_key: jax.random.PRNGKey
      Key for the pseudo random number generator. It is split once: one half
      drives the proposal, the other the accept/reject draw.
    logpdf: function
      Returns the log-probability of the model given a position.
    position: np.ndarray, shape (n_dims,)
      The current position of the chain.
    log_prob: float
      The log probability at the current position.

    Returns
    -------
    Tuple
        The next position of the chain along with its log probability.
    """
    proposal_key, accept_key = jax.random.split(rng_key)

    # Candidate move: current position plus scaled Gaussian noise.
    step = jax.random.normal(proposal_key, shape=position.shape) * 0.1
    candidate = position + step
    candidate_log_prob = logpdf(candidate)

    # Metropolis test carried out in log space.
    log_ratio = candidate_log_prob - log_prob
    accepted = np.log(jax.random.uniform(accept_key)) < log_ratio

    return (
        np.where(accepted, candidate, position),
        np.where(accepted, candidate_log_prob, log_prob),
    )

`jax.scipy` acts as a drop-in replacement to `scipy` functions:

In [0]:
def mixture_logpdf(x):
    """Log probability density function of a 4-component Gaussian mixture.

    Every coordinate of `x` is treated as an independent draw from the same
    one-dimensional mixture, so the joint log-density is the sum of the
    per-coordinate mixture log-densities.

    Parameters
    ----------
    x: np.ndarray (n_dims,)
        Position at which to evaluate the probability density function.

    Returns
    -------
    float
        The value of the log probability density function at x.
    """
    locs = np.array([-2.0, 0.0, 3.2, 2.5])
    scales = np.array([1.2, 1.0, 5.0, 2.8])
    weights = np.array([0.2, 0.3, 0.1, 0.4])

    # Add a trailing axis so that the mixture components live on the last
    # axis: the result has shape (n_dims, n_components).
    component_log_probs = norm.logpdf(np.asarray(x)[..., None], loc=locs, scale=scales)

    # log sum_k w_k * N(x_d | loc_k, scale_k), reduced over the component axis
    # only. The weights now line up with the components (not with the
    # dimensions of `x`), and the result is the log-density itself rather
    # than its negation.
    per_dim_log_prob = logsumexp(np.log(weights) + component_log_probs, axis=-1)
    return np.sum(per_dim_log_prob)

In [0]:
dtype = onp.float32

# TFP counterpart of the Gaussian mixture: four Normal components, listed as
# (loc, scale) pairs in the same order as the categorical weights.
target = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.2, 0.3, 0.1, 0.4]),
    components=[
        tfd.Normal(loc=dtype(loc), scale=dtype(scale))
        for loc, scale in ((-2.0, 1.2), (0.0, 1.0), (3.2, 5.0), (2.5, 2.8))
    ],
)

# Bound method used as the log-probability function by the TFP benchmark cells.
mixture_logpdf_tfp = target.log_prob

print(mixture_logpdf_tfp)

<bound method Distribution.log_prob of <tfp.distributions.Mixture 'Mixture_1/' batch_shape=[] event_shape=[] dtype=float32>>


# Drawing a fixed number of samples

Here I start with a common setting in probabilistic programming libraries where the number of samples that we want is known in advance and is not meant to be updated between iterations of the sampler.

This post was originally a simple sanity check before working on a larger project where I am only interested in the last sample that was obtained for each chain. I therefore use the `lax.fori_loop` construct. If you want to work with all the samples that were produced, use the `lax.scan` construct instead.

In [0]:
@partial(jax.jit, static_argnums=(1, 2))
def rw_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
    """Run one Random Walk Metropolis chain and return its final position.

    Parameters
    ----------
    rng_key: jax.random.PRNGKey
        Key for the pseudo random number generator.
    n_samples: int
        Number of steps to take with the chain.
    logpdf: function
      Returns the log-probability of the model given a position.
    initial_position: np.ndarray (n_dims,)
      The starting position of the chain.

    Returns
    -------
    np.ndarray (n_dims,)
        The position of the chain after `n_samples` steps. Intermediate
        positions are not kept.
    """

    def sampler_step(i, state):
        key, position, log_prob = state
        # Advance the PRNG at every iteration. Passing the same key to the
        # kernel on each step would reproduce the same proposal and the same
        # uniform draw over and over again.
        key, step_key = jax.random.split(key)
        new_position, new_log_prob = rw_metropolis_kernel(step_key, logpdf, position, log_prob)
        return (key, new_position, new_log_prob)

    logp = logpdf(initial_position)
    _, position, _ = jax.lax.fori_loop(0, n_samples, sampler_step, (rng_key, initial_position, logp))
    return position

And we define the function that will initialize and run the sampler:

In [0]:
def sample_jax(rng_key, logpdf, n_dim, n_samples, n_chains):
    """Run `n_chains` Random Walk Metropolis chains in parallel with JAX.

    Parameters
    ----------
    rng_key: jax.random.PRNGKey
        Key for the pseudo random number generator; split into one key per chain.
    logpdf: function
        Returns the log-probability of the model given a position.
    n_dim: int
        Number of dimensions of each chain's position.
    n_samples: int
        Number of steps taken by each chain.
    n_chains: int
        Number of chains to run.

    Returns
    -------
    np.ndarray (n_dim, n_chains)
        The last position of every chain.
    """
    rng_keys = jax.random.split(rng_key, n_chains)  # one key per chain
    initial_position = np.zeros((n_dim, n_chains))  # (n_dim, n_chains)
    # Map the single-chain sampler over the keys (axis 0) and over the chain
    # axis of the positions (axis 1); n_samples and logpdf are shared.
    run_mcmc = jax.vmap(rw_metropolis_sampler, in_axes=(0, None, None, 1),
                        out_axes=1)
    # Wait for the computation to finish so that timings measure the actual
    # work and not only the dispatch.
    positions = run_mcmc(rng_keys, n_samples, logpdf, initial_position).block_until_ready()
    assert positions.shape == (n_dim, n_chains)
    return positions

In [0]:
def sample_tfp(logpdf, n_dim, n_samples, n_chains):
    """Draw samples with TFP's Random Walk Metropolis kernel and return them.

    Parameters
    ----------
    logpdf: function
        Not used: the target distribution is rebuilt inside the device scope
        below. The parameter is kept so existing calls keep working.
    n_dim: int
        First dimension of the chain state.
    n_samples: int
        Number of results requested from `sample_chain`.
    n_chains: int
        Second dimension of the chain state.

    Returns
    -------
    The evaluated `samples` returned by `sess.run`. `sample_chain` keeps every
    result, so the output grows with n_samples * n_dim * n_chains.
    """
    with tf.device("device:GPU:0"):
        dtype = onp.float32
        target = tfd.Mixture(
            cat=tfd.Categorical(probs=[0.2, 0.3, 0.1, 0.4]),
            components=[
                    tfd.Normal(loc=dtype(-2.0), scale=dtype(1.2)),
                    tfd.Normal(loc=dtype(0.0), scale=dtype(1.0)),
                    tfd.Normal(loc=dtype(3.2), scale=dtype(5.0)),
                    tfd.Normal(loc=dtype(2.5), scale=dtype(2.8)),
            ],
        )
        samples, _ = tfp.mcmc.sample_chain(
                num_results=n_samples,
                current_state=tf.random.normal((n_dim, n_chains)),
                kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob, seed=42),
                num_burnin_steps=0,
                parallel_iterations=1,
        )

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=True)
    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        # In graph mode the code above only builds the ops. The chain is
        # executed when the `samples` tensor is evaluated; running just the
        # initializer would time graph construction, not sampling.
        return sess.run(samples)

## Drawing 1,000 samples for an increasing number of chains 

In [0]:
# Benchmark settings: fixed number of samples, increasing number of chains.
n_dim = 4
n_samples = 1_000
rng_key = jax.random.PRNGKey(42)

# Number of chains to benchmark: 10, 100, ..., 10,000,000.
chain_lengths = onp.logspace(start=1, stop=7, num=7)

# One list of average execution times (in seconds) per benchmarked backend.
chains_exec_times = {
    backend: []
    for backend in ("JAX", "TFP", "TFP w/ XLA compilation")
}

### JAX

In [0]:
for n_chains in chain_lengths:
    n_chains = int(n_chains)
    t = %timeit -o sample_jax(rng_key, mixture_logpdf, n_dim, n_samples, n_chains)
    avg = sum(t.all_runs)/(len(t.all_runs))
    chains_exec_times["JAX"].append(avg)
    print("{:,} chains: {:.3} s".format(n_chains, avg))

10 loops, best of 3: 22.6 ms per loop
10 chains: 0.229 s
10 loops, best of 3: 22.7 ms per loop
100 chains: 0.234 s
10 loops, best of 3: 27.8 ms per loop
1,000 chains: 0.278 s
10 loops, best of 3: 31.6 ms per loop
10,000 chains: 0.316 s
10 loops, best of 3: 116 ms per loop
100,000 chains: 1.16 s
1 loop, best of 3: 1.03 s per loop
1,000,000 chains: 1.03 s
1 loop, best of 3: 10.1 s per loop
10,000,000 chains: 10.1 s


### Tensorflow Probability

In [0]:
for n_chains in chain_lengths:
        n_chains = int(n_chains)
        t = %timeit -o sample_tfp(mixture_logpdf_tfp, n_dim, n_samples, n_chains)
        avg = sum(t.all_runs)/(len(t.all_runs))
        chains_exec_times["TFP"].append(avg)
        print("{:,} chains: {:.3} s".format(n_chains, avg))

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:

### Tensorflow Probability with XLA compilation

In [0]:
for n_chains in chain_lengths:
    n_chains = int(n_chains)
    run_tfp = partial(sample_tfp, mixture_logpdf_tfp, n_dim, n_samples, n_chains)
    t = %timeit -o tf.xla.experimental.compile(run_tfp)
    avg = sum(t.all_runs)/(len(t.all_runs))
    chains_exec_times["TFP w/ XLA compilation"].append(avg)
    print("{:,} chains: {:.3} s".format(n_chains, avg))

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7



InaccessibleTensorError: ignored

## Drawing an increasing number of samples for 1,000 chains

In [0]:
# Benchmark settings: fixed number of chains, increasing number of samples.
n_dim = 4
n_chains = 1_000
rng_key = jax.random.PRNGKey(42)

# Number of samples to benchmark: 10, 100, ..., 1,000,000.
samples_num = onp.logspace(start=1, stop=6, num=6)

# One list of average execution times (in seconds) per benchmarked backend.
samples_exec_times = {
    backend: []
    for backend in ("JAX", "TFP", "TFP w/ XLA compilation")
}

### JAX

In [0]:
for n_samples in samples_num:
    n_samples = int(n_samples)
    t = %timeit -o sample_jax(rng_key, mixture_logpdf, n_dim, n_samples, n_chains)
    avg = sum(t.all_runs)/(len(t.all_runs))
    samples_exec_times["JAX"].append(avg)
    print("{:,} samples: {:.3} s".format(n_samples, avg))

The slowest run took 749.86 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 1.43 ms per loop
10 samples: 0.00164 s
The slowest run took 290.76 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 3.73 ms per loop
100 samples: 0.00386 s
10 loops, best of 3: 27.6 ms per loop
1,000 samples: 0.277 s


KeyboardInterrupt: ignored

### Tensorflow Probability

In [0]:
for n_samples in samples_num:
    n_samples = int(n_samples)
    run_tfp = partial(sample_tfp, mixture_logpdf_tfp, n_dim, n_samples, n_chains)
    t = %timeit -o run_tfp()
    avg = sum(t.all_runs)/(len(t.all_runs))
    samples_exec_times["TFP"].append(avg)
    print("{:,} samples: {:.3} s".format(n_samples, avg))

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:

### Tensorflow Probability with XLA compilation

In [0]:
for n_samples in samples_num:
    n_samples = int(n_samples)
    run_tfp = partial(sample_tfp, mixture_logpdf_tfp, n_dim, n_samples, n_chains)
    t = %timeit -o tf.xla.experimental.compile(run_tfp)
    avg = sum(t.all_runs)/(len(t.all_runs))
    samples_exec_times["TFP w/ XLA compilation"].append(avg)
    print("{:,} samples: {:.3} s".format(n_samples, avg))