# NumPyro SVI (Part 1 & 2)

Based on the Pyro tutorials on stochastic variational inference (SVI):
- Part 1: http://pyro.ai/examples/svi_part_i.html
    - An introduction to SVI
- Part 2: http://pyro.ai/examples/svi_part_ii.html
    - Conditional independence via plates

NumPyro SVI documentation:
- http://num.pyro.ai/en/stable/svi.html

In [None]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental import optimizers

import numpyro
from numpyro.infer import SVI, Trace_ELBO, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
import arviz as az
from tqdm import tqdm_notebook as tqdm

In [None]:
# Global plotting style for the rest of the notebook.
# seaborn's darkgrid style is set first; ArviZ's "arviz-darkgrid" style is
# applied after it, so where both touch the same matplotlib rcParams the
# ArviZ values are the ones left in effect. Keep this order.
sns.set_style('darkgrid')
# Default highest-density-interval probability used by ArviZ summaries/plots.
az.rcParams['stats.hdi_prob'] = 0.90
az.style.use("arviz-darkgrid")

In [None]:
# Run JAX on the CPU and ask XLA for 8 host (CPU) devices.
# NumPyro documents that the host device count only takes effect if it is set
# before JAX initialises its backend, so this cell must run before any JAX
# computation (e.g. building the PRNG key below).
# NOTE(review): nothing visible in this notebook uses more than one device
# (e.g. parallel MCMC chains); the setting appears to be boilerplate.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(8)

In [None]:
# Notebook-level PRNG key, seeded with 42.
rng_key = jax.random.PRNGKey(seed=42)

In [None]:
# Observed coin flips: seven heads (1.0) followed by three tails (0.0).
data = jnp.array([1.0] * 7 + [0.0] * 3)
data

In [None]:
def model(data):
    """Beta-Bernoulli model of a possibly biased coin.

    Draws the coin's fairness from a Beta(10, 10) prior, then scores every
    entry of ``data`` (along its first axis) as an independent Bernoulli
    observation with that fairness.

    Args:
        data: observed flips, 1 for heads and 0 for tails.
    """
    # Symmetric Beta prior centred on a fair coin (alpha = beta = 10).
    prior = dist.Beta(10., 10.)
    fairness = numpyro.sample("latent_fairness", prior)

    # A vectorised plate (not a Python loop): it declares the observations
    # conditionally independent given `fairness` and scores them all at once.
    num_flips = data.shape[0]
    with numpyro.plate("N", num_flips):
        numpyro.sample("obs", dist.Bernoulli(fairness), obs=data)

In [None]:
def guide(data):
    """Variational guide: approximate posterior Beta(alpha_q, beta_q) over fairness.

    ``data`` is unused here, but the guide must accept the same arguments as
    ``model`` because SVI calls both with them.
    """
    positive = dist.constraints.positive
    # Both concentrations start at 15. Because they are constrained positive,
    # NumPyro optimises their unconstrained values (related to these by a
    # log) rather than the parameters themselves.
    concentration1 = numpyro.param("alpha_q", 15., constraint=positive)
    concentration0 = numpyro.param("beta_q", 15., constraint=positive)
    # Must use the same site name as the latent variable in `model`.
    numpyro.sample(
        "latent_fairness", dist.Beta(concentration1, concentration0)
    )

In [None]:
%%time

# Adam optimiser (step size 2e-4) over the guide's unconstrained parameters.
optimizer = numpyro.optim.Adam(step_size=2e-4)

# Stochastic variational inference: fit `guide` to `model` by minimising the
# Trace_ELBO loss (the negative evidence lower bound).
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# 5000 optimisation steps from a fixed seed. `svi_result` carries the fitted
# `params`, the optimiser `state`, and the per-step `losses`.
svi_key = jax.random.PRNGKey(0)
svi_result = svi.run(svi_key, 5000, data=data)

In [None]:
# Trace of the Trace_ELBO loss recorded by `svi.run`, one value per step.
fig = plt.figure(figsize=(7, 3))
ax = fig.add_subplot(1, 1, 1)
ax.set_title("losses")
ax.plot(svi_result.losses)
plt.show()

In [None]:
def beta_mean_std(alpha, beta):
    """Return ``(mean, std)`` of a Beta(alpha, beta) distribution.

    Uses the closed forms
        mean = a / (a + b)
        var  = a * b / ((a + b)**2 * (a + b + 1))
    """
    total = alpha + beta
    mean = alpha / total
    std = jnp.sqrt(alpha * beta / (total ** 2 * (total + 1.0)))
    return mean, std


# Grab the learned variational parameters (the guide's Beta concentrations).
alpha_q = svi_result.params["alpha_q"]
print("alpha_q: ", float(alpha_q))
beta_q = svi_result.params["beta_q"]
print("beta_q: ", float(beta_q))

# Summarise the approximate posterior q(fairness) = Beta(alpha_q, beta_q).
inferred_mean, inferred_std = beta_mean_std(alpha_q, beta_q)
# Cast to Python floats before formatting, as the prints above do.
print("based on the data and our prior belief, the fairness "
      f"of the coin is {float(inferred_mean):.3f} +- {float(inferred_std):.3f}"
)

In [None]:
# Draw 2500 samples of `latent_fairness` from the fitted guide and plot them.
predictive = Predictive(
    guide,
    params=svi_result.params,
    num_samples=2500,
)
# Derive a fresh key from the notebook seed instead of reusing PRNGKey(0),
# which already seeded the SVI run; JAX keys must not be reused for
# independent random draws.
pred_key = jax.random.fold_in(rng_key, 1)
samples = predictive(pred_key, data)
# Convert the JAX array to NumPy so seaborn receives a plain 1-D vector.
sns.histplot(np.asarray(samples['latent_fairness']))
plt.show()