# Normalized vs Non-normalized log_prob: Difference in speed?

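Notebook comparing NumPyro inference speed when the Poisson likelihood uses:
- a normalized `log_prob` (NumPyro's built-in `dist.Poisson`)
- a non-normalized `log_prob` (the custom `PoissonUN` defined below, which omits the `log(k!)` term)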

Related discourse thread: https://forum.pyro.ai/t/unnormalized-densities/3251

In [None]:
# Notebook display setup: render matplotlib figures inline, in SVG format.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
import sys
import warnings
import time

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.util import validate_sample

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az

import tqdm

In [None]:
# Plot styling (seaborn + ArviZ) and ArviZ defaults.
sns.set_style('darkgrid')
# Use 0.90 as the default HDI probability in ArviZ.
az.rcParams['stats.hdi_prob'] = 0.90
az.style.use("arviz-darkgrid")

In [None]:
# Run NumPyro on CPU with a single host device.
# NOTE(review): the full-inference cells below request num_chains=4 while only
# one device is configured here — presumably the chains then run one after
# another rather than in parallel; confirm this is intended.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(1)

In [None]:
# Seed NumPy's global RNG (used below to simulate observations) and create a JAX PRNG key.
np.random.seed(42)
rng_key = jax.random.PRNGKey(42)

In [None]:
# Silence all Python warnings for the rest of the notebook.
# NOTE(review): this is a blanket filter, so warnings emitted during inference
# (e.g. by NumPyro/JAX/ArviZ) are hidden as well.
warnings.filterwarnings('ignore')

In [None]:
# NOTE(review): `k` is not referenced by any other cell visible in this notebook
# (the timing loops use `n_runs`) — presumably a leftover; confirm before removing.
k = 5

## Data

In [None]:
# Simulate `n` Poisson-distributed counts with a known rate; these are the
# observations used by the two full-inference cells below.
n = 10_000
true_rate = 251.34

observations = np.random.poisson(lam=true_rate, size=n)
print(observations.shape)

In [None]:
# Sanity check: the sample mean of Poisson draws should be close to `true_rate`.
observations.mean()

### Rate inference: Normalized

In [None]:
def model_poisson(obs=None):
    """NumPyro model with a Poisson likelihood from the built-in ``dist.Poisson``.

    Args:
        obs: Observed counts passed to the ``obs`` sample site; with the
            default ``None`` the site is left unobserved.
    """
    # Improper flat prior restricted to positive values.
    # (Alternative proper prior that was tried: dist.HalfCauchy(scale=100.0).)
    rate_prior = dist.ImproperUniform(
        support=dist.constraints.positive,
        batch_shape=(),
        event_shape=(),
    )
    rate = numpyro.sample("rate", rate_prior)

    likelihood = dist.Poisson(rate=rate)
    numpyro.sample('obs', likelihood, obs=obs)

In [None]:
# Full NUTS inference with the normalized Poisson model, followed by a summary
# table and trace plot of the posterior.
rng_key = jax.random.PRNGKey(42)

num_warmup, num_samples = 1000, 10000

# Build the NUTS kernel and the MCMC driver (4 chains, progress bar shown).
kernel_poisson = NUTS(model_poisson)
mcmc_poisson = MCMC(
    kernel_poisson,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=4,
    progress_bar=True,
)
# Run inference on the simulated observations (this cell runs it only once).
mcmc_poisson.run(rng_key, obs=observations)

# Show the posterior summary table.
# NOTE(review): "~log_rate" asks ArviZ to exclude a variable named `log_rate`,
# but model_poisson only defines "rate" — presumably a leftover filter; confirm.
display(az.summary(mcmc_poisson, var_names=["~log_rate"], round_to=2))
inference_data_poisson = az.from_numpyro(
    posterior=mcmc_poisson,
)

# Trace plot, with the true simulation rate passed as a reference line for "rate".
az.plot_trace(
    inference_data_poisson,
    compact=True,
    var_names=["~log_rate"],
    lines=[
        ("rate", {}, true_rate),
    ],
)
plt.suptitle('Trace plots', fontsize=18)
plt.show()


### Rate inference: Un-Normalized

- http://sherrytowers.com/2014/07/10/poisson-likelihood/

In [None]:
class PoissonUN(dist.Distribution):
    """Poisson distribution whose ``log_prob`` is left unnormalized.

    ``log_prob`` returns ``value * log(rate) - rate``, i.e. the Poisson
    log-pmf without the ``-log(value!)`` term. That term does not depend on
    ``rate``, so it is a constant with respect to the parameter being
    inferred.

    Args:
        rate: Poisson rate; its shape is used as the batch shape.
        validate_args: Forwarded unchanged to ``dist.Distribution``.
    """

    arg_constraints = {"rate": constraints.positive}
    support = constraints.nonnegative_integer

    def __init__(self, rate, *, validate_args=None):
        self.rate = rate
        super().__init__(jnp.shape(rate), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw Poisson samples of shape ``sample_shape + batch_shape``.

        Args:
            key: JAX PRNG key.
            sample_shape: Leading shape of the requested draws.
        """
        # The original body referenced `is_prng_key` and `random`, neither of
        # which is imported in this notebook, so every call raised NameError.
        # `jax` is imported, so call `jax.random.poisson` directly; the key is
        # handed straight to JAX instead of being checked with an `assert`.
        return jax.random.poisson(
            key, self.rate, shape=sample_shape + self.batch_shape
        )

    @validate_sample
    def log_prob(self, value):
        """Return the unnormalized log-probability ``value * log(rate) - rate``.

        Args:
            value: Observed count(s).
        """
        if self._validate_args:
            self._validate_sample(value)
        # NOTE(review): presumably carried over from NumPyro's built-in
        # Poisson.log_prob; it is not otherwise needed by the formula below.
        value = jax.device_get(value)
        return (jnp.log(self.rate) * value) - self.rate

In [None]:
def model_poisson_unnormalized(obs=None):
    """NumPyro model with a Poisson likelihood from the custom ``PoissonUN``.

    Same structure as the normalized model, except that the likelihood is the
    ``PoissonUN`` distribution.

    Args:
        obs: Observed counts passed to the ``obs`` sample site; with the
            default ``None`` the site is left unobserved.
    """
    # Improper flat prior restricted to positive values.
    # (Alternative proper prior that was tried: dist.HalfCauchy(scale=100.0).)
    rate_prior = dist.ImproperUniform(
        support=dist.constraints.positive,
        batch_shape=(),
        event_shape=(),
    )
    rate = numpyro.sample("rate", rate_prior)

    likelihood = PoissonUN(rate=rate)
    numpyro.sample('obs', likelihood, obs=obs)

In [None]:
# Full NUTS inference with the unnormalized Poisson model, followed by a
# summary table and trace plot of the posterior.
rng_key = jax.random.PRNGKey(42)

num_warmup, num_samples = 1000, 10000

# Build the NUTS kernel and the MCMC driver (4 chains, progress bar shown).
kernel_poisson_un = NUTS(model_poisson_unnormalized)
mcmc_poisson_un  = MCMC(
    kernel_poisson_un,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=4,
    progress_bar=True,
)
# Run inference on the simulated observations.
mcmc_poisson_un.run(rng_key, obs=observations)

# Show the posterior summary table.
# NOTE(review): "~log_rate" asks ArviZ to exclude a variable named `log_rate`,
# but model_poisson_unnormalized only defines "rate" — presumably a leftover
# filter; confirm.
display(az.summary(mcmc_poisson_un, var_names=["~log_rate"], round_to=2))
inference_data_poisson_un = az.from_numpyro(
    posterior=mcmc_poisson_un,
)

# Trace plot, with the true simulation rate passed as a reference line for "rate".
az.plot_trace(
    inference_data_poisson_un,
    compact=True,
    var_names=["~log_rate"],
    lines=[
        ("rate", {}, true_rate),
    ],
)
plt.suptitle('Trace plots', fontsize=18)
plt.show()


# Comparisons

In [None]:
def run_with_timing(model_fn, n_runs, n_warmup, n_samples, obs, true_rate):
    """Time repeated single-chain NUTS runs of ``model_fn`` and score the posterior.

    One untimed run is performed first, then ``n_runs`` timed runs, all with
    the same fixed PRNG key.

    Args:
        model_fn: NumPyro model accepting an ``obs`` keyword argument and
            defining a ``"rate"`` sample site.
        n_runs: Number of timed runs; must be at least 1.
        n_warmup: Number of warmup steps per run.
        n_samples: Number of posterior samples per run.
        obs: Observations passed to the model as ``obs``.
        true_rate: Reference rate the posterior samples are compared against.

    Returns:
        ``((median_time, mad_time), (mean_rate_error, std_rate_error))`` where
        the first pair is the median and the median absolute deviation of the
        per-run wall-clock times in seconds, and the second pair is the mean
        and standard deviation of ``|rate - true_rate|`` over the posterior
        samples of the last timed run.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
    """
    # Without at least one timed run there are neither timings nor posterior
    # samples to report (the original failed later with an unbound local).
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    rng_key = jax.random.PRNGKey(42)
    # Run NUTS.
    kernel = NUTS(model_fn)
    mcmc = MCMC(
        kernel,
        num_warmup=n_warmup,
        num_samples=n_samples,
        num_chains=1,
        progress_bar=False,
    )
    # Untimed first run, so one-off compilation cost stays out of the timings.
    mcmc.run(rng_key, obs=obs)

    # Timed runs.
    times = np.empty(n_runs)
    for run_idx in range(n_runs):
        start_time = time.monotonic()
        mcmc.run(rng_key, obs=obs)
        # JAX dispatches work asynchronously, so copy the samples to the host
        # before stopping the clock; this forces the computation to finish
        # inside the timed region.
        posterior_samples = jax.device_get(mcmc.get_samples())
        times[run_idx] = time.monotonic() - start_time

    median_time = np.median(times)
    mad_time = np.median(np.abs(times - median_time))

    rate_error = np.abs(posterior_samples["rate"] - true_rate)
    mean_rate_error = np.mean(rate_error)
    std_rate_error = np.std(rate_error)
    return (median_time, mad_time), (mean_rate_error, std_rate_error)

## Comparison (varying True rate)

In [None]:
# Poisson rates to sweep over (powers of two) and the fixed number of
# observations simulated for each rate.
true_rates = [1., 2., 4., 8., 16., 32., 64., 128., 256., 512.]

data_size = 5000

In [None]:
# Benchmark both likelihoods over the grid of true rates (fixed data size).
np.random.seed(42)

n_runs = 5
n_warmup = 500
n_samples = 5000

# One float32 slot per swept rate for every recorded statistic.
n_settings = len(true_rates)

median_times_normalized = np.zeros(n_settings, dtype=np.float32)
mad_times_normalized = np.zeros(n_settings, dtype=np.float32)
mean_rate_errors_normalized = np.zeros(n_settings, dtype=np.float32)
std_rate_errors_normalized = np.zeros(n_settings, dtype=np.float32)

median_times_unnormalized = np.zeros(n_settings, dtype=np.float32)
mad_times_unnormalized = np.zeros(n_settings, dtype=np.float32)
mean_rate_errors_unnormalized = np.zeros(n_settings, dtype=np.float32)
std_rate_errors_unnormalized = np.zeros(n_settings, dtype=np.float32)

pbar = tqdm.tqdm(true_rates)
for i, true_rate in enumerate(pbar):
    pbar.set_description(f"#rate = {true_rate}")
    # Fresh data set for this rate, shared by both models.
    observations = np.random.poisson(lam=true_rate, size=data_size)
    timing_kwargs = {
        "n_runs": n_runs,
        "n_warmup": n_warmup,
        "n_samples": n_samples,
        "obs": observations,
        "true_rate": true_rate,
    }

    # Normalized likelihood: store the results directly in slot i.
    (
        (median_times_normalized[i], mad_times_normalized[i]),
        (mean_rate_errors_normalized[i], std_rate_errors_normalized[i]),
    ) = run_with_timing(model_fn=model_poisson, **timing_kwargs)

    # Non-normalized likelihood: same settings, same data.
    (
        (median_times_unnormalized[i], mad_times_unnormalized[i]),
        (mean_rate_errors_unnormalized[i], std_rate_errors_unnormalized[i]),
    ) = run_with_timing(model_fn=model_poisson_unnormalized, **timing_kwargs)

In [None]:
# Figure 1: run time versus true rate (median line with a +/- MAD band).
fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))

timing_series = [
    ("Normalized", "blue", median_times_normalized, mad_times_normalized),
    ("Non-Normalized", "red", median_times_unnormalized, mad_times_unnormalized),
]
for series_label, series_color, center, spread in timing_series:
    ax1.plot(true_rates, center, "o-", color=series_color, label=series_label)
    ax1.fill_between(
        true_rates, center - spread, center + spread,
        color=series_color, alpha=0.15)
ax1.set_xscale("log", base=2)
ax1.set_yscale("log")
ax1.set_xlabel("true-rate")
ax1.set_ylabel("time (seconds)")
ax1.set_title("Inference time")
ax1.legend()

# Figure 2: absolute error on the rate versus true rate (mean line with a +/- std band).
fig, ax2 = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))

error_series = [
    ("Normalized", "blue", mean_rate_errors_normalized, std_rate_errors_normalized),
    ("Non-Normalized", "red", mean_rate_errors_unnormalized, std_rate_errors_unnormalized),
]
for series_label, series_color, center, spread in error_series:
    ax2.plot(true_rates, center, "o-", color=series_color, label=series_label)
    ax2.fill_between(
        true_rates, center - spread, center + spread,
        color=series_color, alpha=0.15)
ax2.set_xscale("log", base=10)

ax2.set_xlabel("true-rate")
ax2.set_ylabel("Error")
ax2.set_title("Inference error on \"Rate\"")
ax2.legend()
plt.show()

## Comparison (varying data size)

In [None]:
# Fixed Poisson rate and the numbers of observations to sweep over.
true_rate = 251.34

data_sizes = [2, 5, 10, 50, 100, 500, 1000, 5000, 10_000, 25_000]

In [None]:
# Benchmark both likelihoods over the grid of data sizes (fixed true rate).
np.random.seed(42)

n_runs = 5
n_warmup = 500
n_samples = 5000

# One float32 slot per swept data size for every recorded statistic.
n_settings = len(data_sizes)

median_times_normalized = np.zeros(n_settings, dtype=np.float32)
mad_times_normalized = np.zeros(n_settings, dtype=np.float32)
mean_rate_errors_normalized = np.zeros(n_settings, dtype=np.float32)
std_rate_errors_normalized = np.zeros(n_settings, dtype=np.float32)

median_times_unnormalized = np.zeros(n_settings, dtype=np.float32)
mad_times_unnormalized = np.zeros(n_settings, dtype=np.float32)
mean_rate_errors_unnormalized = np.zeros(n_settings, dtype=np.float32)
std_rate_errors_unnormalized = np.zeros(n_settings, dtype=np.float32)

pbar = tqdm.tqdm(data_sizes)
for i, n in enumerate(pbar):
    pbar.set_description(f"#samples = {n}")
    # Fresh data set of this size, shared by both models.
    observations = np.random.poisson(lam=true_rate, size=n)
    timing_kwargs = {
        "n_runs": n_runs,
        "n_warmup": n_warmup,
        "n_samples": n_samples,
        "obs": observations,
        "true_rate": true_rate,
    }

    # Normalized likelihood: store the results directly in slot i.
    (
        (median_times_normalized[i], mad_times_normalized[i]),
        (mean_rate_errors_normalized[i], std_rate_errors_normalized[i]),
    ) = run_with_timing(model_fn=model_poisson, **timing_kwargs)

    # Non-normalized likelihood: same settings, same data.
    (
        (median_times_unnormalized[i], mad_times_unnormalized[i]),
        (mean_rate_errors_unnormalized[i], std_rate_errors_unnormalized[i]),
    ) = run_with_timing(model_fn=model_poisson_unnormalized, **timing_kwargs)

In [None]:
# Figure 1: run time versus data size (median line with a +/- MAD band).
fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))

timing_series = [
    ("Normalized", "blue", median_times_normalized, mad_times_normalized),
    ("Non-Normalized", "red", median_times_unnormalized, mad_times_unnormalized),
]
for series_label, series_color, center, spread in timing_series:
    ax1.plot(data_sizes, center, "o-", color=series_color, label=series_label)
    ax1.fill_between(
        data_sizes, center - spread, center + spread,
        color=series_color, alpha=0.15)
ax1.set_xscale("log", base=10)
# The y-axis is left linear here (log scale disabled).
ax1.set_xlabel("#data")
ax1.set_ylabel("time (seconds)")
ax1.set_title("Inference time")
ax1.legend()

# Figure 2: absolute error on the rate versus data size (mean line with a +/- std band).
fig, ax2 = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))

error_series = [
    ("Normalized", "blue", mean_rate_errors_normalized, std_rate_errors_normalized),
    ("Non-Normalized", "red", mean_rate_errors_unnormalized, std_rate_errors_unnormalized),
]
for series_label, series_color, center, spread in error_series:
    ax2.plot(data_sizes, center, "o-", color=series_color, label=series_label)
    ax2.fill_between(
        data_sizes, center - spread, center + spread,
        color=series_color, alpha=0.15)
ax2.set_xscale("log", base=10)

ax2.set_xlabel("#data")
ax2.set_ylabel("Error")
ax2.set_title("Inference error on \"Rate\"")
ax2.legend()
plt.show()