# Benchmarking NumPyro on a large dataset

This notebook uses `numpyro` to replicate the experiments in reference [1], which evaluates the performance of NUTS across several frameworks. The benchmark is run with CUDA 10.0 on an NVIDIA RTX 2070.

In [1]:
import time

import numpy as onp

import jax.numpy as np
from jax import random
# NB: replace "gpu" with "cpu" to run this notebook on CPU
from jax.config import config; config.update("jax_platform_name", "gpu")

import numpyro.distributions as dist
from numpyro.diagnostics import summary
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc, mcmc
from numpyro.util import fori_collect

We follow the preprocessing steps in the [source code](https://github.com/google-research/google-research/blob/master/simple_probabilistic_programming/no_u_turn_sampler/logistic_regression.py) of reference [1]:

In [2]:
_, fetch = load_dataset(COVTYPE, shuffle=False)
features, labels = fetch()

# normalize features and add intercept
features = (features - features.mean(0)) / features.std(0)
features = np.hstack([features, np.ones((features.shape[0], 1))])

# make binary labels: 1 for the category selected below, 0 otherwise
_, counts = onp.unique(labels, return_counts=True)
specific_category = np.argmax(counts)
labels = (labels == specific_category)

N, dim = features.shape
print("Data shape:", features.shape)
print("Label distribution: {} has label 1, {} has label 0"
      .format(labels.sum(), N - labels.sum()))

Data shape: (581012, 55)
Label distribution: 211840 has label 1, 369172 has label 0
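
The standardization step above can be checked on a toy array. The following is a minimal pure-Python sketch, not part of the original notebook: the helper name `standardize_with_intercept` and the toy data are hypothetical, and it assumes the population standard deviation, which is what `features.std(0)` computes by default.

```python
import statistics


def standardize_with_intercept(rows):
    """Scale each column to zero mean and unit population std, then append a constant 1.0."""
    columns = list(zip(*rows))
    means = [statistics.fmean(col) for col in columns]
    stds = [statistics.pstdev(col) for col in columns]  # population std (ddof=0)
    return [
        [(x - m) / s for x, m, s in zip(row, means, stds)] + [1.0]
        for row in rows
    ]


# Hypothetical toy data: 3 rows, 2 features.
toy = [[1.0, 10.0], [2.0, 20.0], [3.0, 60.0]]
scaled = standardize_with_intercept(toy)
for row in scaled:
    print(row)
```

The trailing column of ones plays the role of the intercept, which is why the data shape reported above has 55 columns rather than 54.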


Now, we construct the model:

In [3]:
def model(data, labels):
    coefs = sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
    logits = np.dot(data, coefs)
    return sample('obs', dist.Bernoulli(logits=logits), obs=labels)
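
To make explicit what the samplers below evaluate, here is a hedged pure-Python sketch of the potential energy for this model, assuming the potential is the negative log joint density of the standard normal prior and the Bernoulli likelihood. The function names `softplus` and `potential` and the toy inputs are hypothetical; NumPyro derives its own potential function from the model, and this sketch is not that implementation.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def potential(coefs, data, labels):
    """Negative log joint of a logistic regression with a Normal(0, 1) prior on each coefficient."""
    log_prior = sum(-0.5 * w * w - 0.5 * math.log(2.0 * math.pi) for w in coefs)
    log_lik = 0.0
    for row, y in zip(data, labels):
        logit = sum(x * w for x, w in zip(row, coefs))
        # Bernoulli log-probability parameterized by logits: y * logit - log(1 + exp(logit))
        log_lik += y * logit - softplus(logit)
    return -(log_prior + log_lik)


# Hypothetical toy inputs: 3 rows, 2 columns (the last column is the intercept).
toy_data = [[0.5, 1.0], [-1.2, 1.0], [0.3, 1.0]]
toy_labels = [1, 0, 1]
print(potential([0.0, 0.0], toy_data, toy_labels))
```

Each evaluation touches every row of the data once, so with N = 581012 rows the cost of the potential and its gradient dominates each leapfrog step.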

In [4]:
step_size = np.sqrt(0.5 / N)
init_params = {'coefs': np.zeros(dim)}
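
As a quick sanity check on the step size, the arithmetic can be reproduced with the standard library alone. This sketch only assumes the value of N from the "Data shape" output above; the variable `num_steps` is introduced here for illustration.

```python
import math

N = 581012  # number of rows, from the "Data shape" output above
step_size = math.sqrt(0.5 / N)

# The HMC benchmark below uses a trajectory length of 10 * step_size,
# which corresponds to 10 leapfrog steps per sample.
trajectory_length = 10 * step_size
num_steps = round(trajectory_length / step_size)

print(f"step size: {step_size:.2e}")
print("leapfrog steps per sample:", num_steps)
```

The printed step size agrees with the `9.28e-04` shown in the HMC progress bar.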

## Benchmark HMC

In [5]:
_, potential_fn, _ = initialize_model(random.PRNGKey(1), model, features, labels)

In [6]:
t0 = time.time()
samples = mcmc(num_warmup=0, num_samples=1000, init_params=init_params, potential_fn=potential_fn,
               algo='HMC', step_size=step_size, trajectory_length=(10 * step_size),
               adapt_step_size=False)
t1 = time.time()
num_leapfrogs = 1000 * 10  # 1000 samples, 10 leapfrog steps each
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (t1 - t0) / num_leapfrogs)

warmup: 0it [00:00, ?it/s]
sample: 100%|██████████| 1000/1000 [00:40<00:00, 26.43it/s, 10 steps of size 9.28e-04. acc. prob=0.93]



                           mean         sd       5.5%      94.5%      n_eff       Rhat
            coefs[0]       1.97       0.01       1.96       1.99     327.37       1.02
            coefs[1]      -0.04       0.01      -0.05      -0.03    1003.13       1.00
            coefs[2]      -0.07       0.02      -0.09      -0.04       5.77       1.31
            coefs[3]      -0.30       0.01      -0.31      -0.29    3253.70       1.00
            coefs[4]      -0.09       0.01      -0.10      -0.09    2783.86       1.00
            coefs[5]      -0.14       0.01      -0.15      -0.14    1898.41       1.00
            coefs[6]       0.19       0.09       0.11       0.31       5.43       1.32
            coefs[7]      -0.63       0.06      -0.70      -0.58       5.52       1.31
            coefs[8]       0.52       0.11       0.43       0.66       5.44       1.32
            coefs[9]      -0.01       0.01      -0.02      -0.00   -4414.57       1.00
           coefs[10]       0.37       0.0




On CPU, we get `avg. time for each step : 0.03028029832839966`.

## Benchmark NUTS

For a fair NUTS benchmark, we need to record the number of leapfrog steps taken during sampling, because NUTS chooses it adaptively for each sample. Hence we use the lower-level [hmc](https://numpyro.readthedocs.io/en/latest/mcmc.html#numpyro.mcmc.hmc) API together with `fori_collect`.
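
For readers unfamiliar with the unit being timed, a leapfrog (velocity Verlet) integrator can be sketched as follows. This is an illustrative pure-Python version with an identity mass matrix and a toy standard normal target, not NumPyro's implementation; the names `leapfrog`, `grad_potential`, and `energy` are hypothetical.

```python
import math


def leapfrog(z, r, step_size, num_steps, grad_potential):
    """Run `num_steps` leapfrog steps from position z and momentum r (lists of floats).

    Returns the new position, the new momentum, and the number of gradient evaluations.
    """
    z, r = list(z), list(r)
    grad = grad_potential(z)
    num_grad_evals = 1
    for _ in range(num_steps):
        # Half step for the momentum.
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, grad)]
        # Full step for the position.
        z = [zi + step_size * ri for zi, ri in zip(z, r)]
        # One new gradient evaluation per leapfrog step.
        grad = grad_potential(z)
        num_grad_evals += 1
        # Second half step for the momentum.
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, grad)]
    return z, r, num_grad_evals


def grad_potential(z):
    """Gradient of the toy potential U(z) = 0.5 * sum(z_i ** 2)."""
    return list(z)


def energy(z, r):
    """Total energy: toy potential plus kinetic energy."""
    return 0.5 * sum(zi * zi for zi in z) + 0.5 * sum(ri * ri for ri in r)


z0, r0 = [1.0], [0.0]
z1, r1, num_grad_evals = leapfrog(z0, r0, step_size=0.1, num_steps=10,
                                  grad_potential=grad_potential)
print("position:", z1[0], "exact:", math.cos(1.0))
print("energy change:", energy(z1, r1) - energy(z0, r0))
print("gradient evaluations:", num_grad_evals)
```

Each step costs one gradient evaluation of the potential, which is why dividing wall-clock time by the number of leapfrog steps gives a measure that is comparable between HMC and NUTS.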

In [7]:
_, potential_fn, _ = initialize_model(random.PRNGKey(1), model, features, labels)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, num_warmup=0, step_size=step_size, adapt_step_size=False)

warmup: 0it [00:00, ?it/s]


In [8]:
t0 = time.time()
hmc_states = fori_collect(100, sample_kernel, hmc_state,
                          transform=lambda state: {'coefs': state.z['coefs'],
                                                   'num_steps': state.num_steps})
t1 = time.time()
summary(hmc_states)
num_leapfrogs = np.sum(hmc_states['num_steps']).copy()
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (t1 - t0) / num_leapfrogs)

100%|██████████| 100/100 [05:00<00:00,  3.31s/it]




                           mean         sd       5.5%      94.5%      n_eff       Rhat
            coefs[0]       1.86       0.37       1.72       2.00      10.49       1.10
            coefs[1]      -0.04       0.01      -0.06      -0.03      21.48       1.00
            coefs[2]      -0.07       0.03      -0.12      -0.04       9.70       1.11
            coefs[3]      -0.28       0.07      -0.31      -0.25      10.50       1.10
            coefs[4]      -0.09       0.01      -0.10      -0.08      27.55       1.00
            coefs[5]      -0.13       0.06      -0.15      -0.11      10.55       1.10
            coefs[6]       0.17       0.16      -0.18       0.29       5.04       1.28
            coefs[7]      -0.58       0.17      -0.72      -0.34       5.29       1.24
            coefs[8]       0.47       0.22       0.06       0.69       4.68       1.30
            coefs[9]      -0.01       0.01      -0.02      -0.01      45.16       1.00
           coefs[10]       0.81       0.4

On CPU, we get `avg. time for each step : 0.029775922484157266`.

### Average time for each leapfrog (Verlet) step

|               |    HMC    |    NUTS   |
| ------------- |:---------:|:---------:|
| Edward2 (CPU) |           |  68.4 ms  |
| Edward2 (GPU) |           |   9.7 ms  |
| NumPyro (CPU) |  30.3 ms  |  29.8 ms  |
| NumPyro (GPU) |   4.3 ms  |   4.7 ms  |

*Note:* Edward2 results are taken from reference [1], which was run in a different environment.

**Some takeaways:**
+ The overhead of iterative NUTS is small: its time per leapfrog step is close to that of plain HMC, so most of the computation time is indeed spent evaluating the potential function and its gradient.
+ GPU outperforms CPU by a large margin. The dataset is large, so evaluating the potential function on GPU is clearly faster than on CPU.
+ Iterative NUTS is roughly 2.2x faster (on both GPU and CPU) than the speed reported in reference [1]. This illustrates the benefit of graph mode (using an iterative algorithm) over eager mode (using a recursive algorithm).
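
The ratios behind these takeaways can be recomputed directly from the table. The sketch below only copies the reported timings; the dictionary layout and the `ratio` helper are hypothetical conveniences.

```python
# Time per leapfrog step in milliseconds, copied from the table above.
timings_ms = {
    ("Edward2", "CPU", "NUTS"): 68.4,
    ("Edward2", "GPU", "NUTS"): 9.7,
    ("NumPyro", "CPU", "HMC"): 30.3,
    ("NumPyro", "CPU", "NUTS"): 29.8,
    ("NumPyro", "GPU", "HMC"): 4.3,
    ("NumPyro", "GPU", "NUTS"): 4.7,
}


def ratio(slow, fast):
    """How many times longer one leapfrog step takes in `slow` than in `fast`."""
    return timings_ms[slow] / timings_ms[fast]


nuts_cpu = ratio(("Edward2", "CPU", "NUTS"), ("NumPyro", "CPU", "NUTS"))
nuts_gpu = ratio(("Edward2", "GPU", "NUTS"), ("NumPyro", "GPU", "NUTS"))
gpu_gain_hmc = ratio(("NumPyro", "CPU", "HMC"), ("NumPyro", "GPU", "HMC"))
gpu_gain_nuts = ratio(("NumPyro", "CPU", "NUTS"), ("NumPyro", "GPU", "NUTS"))

print(f"NumPyro vs Edward2, NUTS on CPU: {nuts_cpu:.1f}x")
print(f"NumPyro vs Edward2, NUTS on GPU: {nuts_gpu:.1f}x")
print(f"GPU vs CPU, NumPyro HMC: {gpu_gain_hmc:.1f}x")
print(f"GPU vs CPU, NumPyro NUTS: {gpu_gain_nuts:.1f}x")
```

Because the Edward2 numbers come from a different environment, these ratios are indicative rather than a controlled comparison.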

## References

1. *Simple, Distributed, and Accelerated Probabilistic Programming*, [arXiv](https://arxiv.org/abs/1811.02091)<br>
Dustin Tran, Matthew D. Hoffman, Dave Moore, Christopher Suter, Srinivas Vasudevan, Alexey Radul, Matthew Johnson, Rif A. Saurous