# Bayesian neural networks

We will consider fitting a Bayesian neural network to some toy data using `numpyro`.

Let's start by loading the required packages.

In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
import jax.random as random

from sklearn.preprocessing import StandardScaler

!pip install numpyro
import numpyro
import numpyro.optim as optim
from numpyro import deterministic

import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal, AutoMultivariateNormal, AutoDelta


## Toy data

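We start by generating data in a simple setting with one-dimensional inputs and outputs. The inputs are drawn from a uniform distribution on $[-4, 4]$. The outputs are noisy observations (Gaussian noise with standard deviation $\sigma = 1$) of either the cubic function $y=(x+2)^2(x/2-1)$ (`function_true2`, used below) or the quadratic function $y = x^2$ (`function_true`).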

In [2]:
def function_true2(input):
  """Cubic ground-truth mean function y = (x + 2)^2 * (x/2 - 1).

  Args:
    input: scalar or array of inputs x.

  Returns:
    The noiseless function value(s), evaluated elementwise.
  """
  shifted_square = jnp.power(input + 2, 2.0)
  linear_factor = input / 2 - 1
  return shifted_square * linear_factor

def function_true(input):
  """Quadratic ground-truth mean function y = x^2, evaluated elementwise."""
  squared = jnp.power(input, 2.0)
  return squared


In [None]:
N = 500 # sample size

# Generate inputs uniformly on [-4, 4]
X_train = random.uniform(random.PRNGKey(0), shape=(N,),  minval=-4, maxval=4)

# Generate outputs as noisy observations of the true mean function
sigma = 1  # standard deviation of the Gaussian observation noise
m_Y = function_true2(X_train)
# Alternatively, use the quadratic mean: m_Y = function_true(X_train)
y_train =  m_Y + sigma*random.normal(random.PRNGKey(1), shape=(N,))

# Plot the data (sort by x so the true function is drawn as a line)
x_idx = jnp.argsort(X_train)
fig = plt.figure(figsize=(4, 3))
plt.plot(X_train[x_idx], m_Y[x_idx],  label = 'True function')
plt.scatter(X_train, y_train, marker = '.', linewidths = 0.1, label = 'Data')
plt.fill_between(X_train[x_idx], (m_Y - 2*sigma)[x_idx], (m_Y + 2*sigma)[x_idx], alpha = 0.2, label = r'true $\pm 2 \sigma$')
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()


Let's generate the test data.

In [None]:
N_test = 400 # sample size

# Use fresh PRNG keys. The training cell already used PRNGKey(0) for the inputs
# and PRNGKey(1) for the noise. Reusing them would make the "test" draws
# deterministic functions of the training draws; with recent JAX versions the
# first N_test values can coincide exactly, so the test set would overlap the
# training set.
X_test = random.uniform(random.PRNGKey(2), shape=(N_test,),  minval=-4, maxval=4)

# Generate test outputs
m_Yt = function_true2(X_test)
# Alternatively, use the quadratic mean: m_Yt = function_true(X_test)
y_test =  m_Yt + sigma*random.normal(random.PRNGKey(3), shape=(N_test,))


# Plot the training and test data together
fig, axs = plt.subplots(2,1,figsize=(4, 5), layout="constrained")

axs[0].scatter(x=X_train, y=y_train,  marker = '.', linewidths = 0.1, label="train")
axs[0].legend(loc="upper left")
axs[0].set(xlabel="x", ylabel="y")
axs[1].scatter(x=X_test, y=y_test,  marker = '.', linewidths = 0.1, label="test")
axs[1].legend(loc="upper left")
axs[1].set(xlabel="x", ylabel="y")

Let's rescale the data before fitting.

In [None]:
# Standardize the inputs with statistics estimated on the training inputs only,
# then apply the same transformation to the test inputs. Outputs stay unscaled.
x_scaler = StandardScaler()
x_train_scaled = jnp.array(x_scaler.fit_transform(X_train[:, None]))
x_test_scaled = jnp.array(x_scaler.transform(X_test[:, None]))

# One panel per split: scaled inputs against (unscaled) outputs
fig, axs = plt.subplots(2, 1, figsize=(4, 5), layout="constrained")
panels = zip(axs, (x_train_scaled, x_test_scaled), (y_train, y_test), ("train", "test"))
for panel_ax, x_vals, y_vals, split_name in panels:
    panel_ax.scatter(x=x_vals, y=y_vals, marker='.', linewidths=0.1, label=split_name)
    panel_ax.legend(loc="upper left")
    panel_ax.set(xlabel="x scaled", ylabel="y")


## BNN model

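Let's build a simple BNN with one hidden layer of ReLU units, following the probabilistic backpropagation [paper](https://jmhl.org/wp-content/uploads/2015/05/pbp-icml2015.pdf) (Hernández-Lobato & Adams, 2015). The code is mostly adapted from the [numpyro documentation](https://num.pyro.ai/en/stable/examples/stein_bnn.html).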

In [7]:
def model(x, y=None, hidden_dim=10):
    """One-hidden-layer ReLU Bayesian neural network for regression.

    Sample sites, in the order they are drawn:
      * ``prec_nn`` ~ Gamma(6, 6): shared precision hyperprior for all weights
        and biases.
      * ``nn_b1`` (shape (hidden_dim,)) and ``nn_w1`` (shape (m, hidden_dim)):
        first-layer bias and weights, Normal(0, 1/sqrt(prec_nn * (m + 1))),
        where m is the number of input features.
      * ``nn_w2`` (shape (hidden_dim,)) and ``nn_b2`` (scalar): output-layer
        weights and bias, Normal(0, 1/sqrt(prec_nn * (hidden_dim + 1))).
      * ``prec_obs`` ~ Gamma(6, 6): precision of the observation noise.
      * ``y`` ~ Normal(y_pred, 1/sqrt(prec_obs)), one per row of ``x``.

    The network output relu(x @ w1 + b1) @ w2 + b2 is recorded as the
    deterministic site ``y_pred``.

    Args:
        x: input matrix of shape (n, m).
        y: observed targets of shape (n,), or None to sample them (prediction).
        hidden_dim: number of hidden units.
    """
    # Hyperprior on the precision of all network weights and biases
    prec_nn = numpyro.sample("prec_nn", dist.Gamma(6.0, 6.0))

    num_obs, num_feat = x.shape
    # Prior std of each layer shrinks with its fan-in (+1 for the bias term)
    scale_layer1 = 1.0 / jnp.sqrt(prec_nn * (num_feat + 1))
    scale_layer2 = 1.0 / jnp.sqrt(prec_nn * (hidden_dim + 1))

    with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
        b1 = numpyro.sample("nn_b1", dist.Normal(0.0, scale_layer1))
        assert b1.shape == (hidden_dim,)

        with numpyro.plate("l1_feat", num_feat, dim=-2):
            w1 = numpyro.sample("nn_w1", dist.Normal(0.0, scale_layer1))
            assert w1.shape == (num_feat, hidden_dim)

    with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
        w2 = numpyro.sample("nn_w2", dist.Normal(0.0, scale_layer2))

    b2 = numpyro.sample("nn_b2", dist.Normal(0.0, scale_layer2))

    # Precision prior on the observations
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(6, 6))

    hidden = jnp.maximum(x @ w1 + b1, 0)  # ReLU activations, shape (n, hidden_dim)
    loc_y = deterministic("y_pred", hidden @ w2 + b2)

    with numpyro.plate("data", num_obs, dim=-1):
        numpyro.sample("y", dist.Normal(loc_y, 1.0 / jnp.sqrt(prec_obs)), obs=y)

In [None]:
# Draw the model's graphical structure (plates, sites, distributions); needs graphviz.
numpyro.render_model(
    model,
    model_kwargs=dict(x=x_train_scaled, y=y_train, hidden_dim=20),
    render_distributions=True,
    render_params=True,
)

## SVI

Now, we will use stochastic variational inference (SVI) to approximate the posterior. In `numpyro`, the variational posterior is called the `guide`. For details, see http://pyro.ai/examples/svi_part_i.html. Here we use a mean-field (diagonal-covariance) normal variational posterior via `AutoNormal`. For other automatic guide options, see https://num.pyro.ai/en/stable/autoguide.html.

In [9]:
# Request the CPU backend.
# NOTE(review): numpyro documents set_platform as only taking effect at the start
# of the program; JAX has already run in earlier cells, so this may be a no-op.
# Consider moving it to the imports cell.
numpyro.set_platform('cpu')

# Derive two keys from one seed: one for fitting and one for predictive sampling.
rng_key, rng_key_predict = random.split(random.PRNGKey(0), 2)

In [10]:
def do_SVI(model, rng_key, numsteps, x, y, hidden_dim, guide_class=AutoNormal, learning_rate=0.01):
    """Fit a variational approximation to the posterior of ``model`` with SVI.

    Args:
        model: numpyro model accepting keyword arguments ``x``, ``y`` and ``hidden_dim``.
        rng_key: JAX PRNG key passed to ``SVI.run``.
        numsteps: number of optimization steps.
        x: training inputs passed to the model.
        y: training targets passed to the model.
        hidden_dim: number of hidden units passed to the model.
        guide_class: autoguide class used to build the variational posterior.
            The default is ``AutoNormal`` (mean-field Gaussian). Pass e.g.
            ``AutoMultivariateNormal`` for a full-covariance Gaussian.
        learning_rate: step size of the Adam optimizer.

    Returns:
        Tuple ``(losses, params, guide)``: the per-step ``Trace_ELBO`` loss
        values (negative ELBO), the optimized guide parameters, and the guide.
    """
    guide = guide_class(model)
    svi = SVI(model, guide, optim.Adam(learning_rate), Trace_ELBO())
    svi_results = svi.run(rng_key=rng_key, num_steps=numsteps, x=x, y=y, hidden_dim=hidden_dim)
    return svi_results.losses, svi_results.params, guide



In [11]:
# SVI settings: hidden units, optimization steps, posterior-predictive draws
D_h, numsteps, num_samples = 20, 1000, 2000

In [None]:
svi_losses, svi_params, svi_guide = do_SVI(
    model=model,
    rng_key=rng_key,
    numsteps=numsteps,
    x=x_train_scaled,
    y=y_train,
    hidden_dim=D_h,
)

In [13]:
# Draw from the fitted guide and push the draws through the model at the test
# inputs; y=None so the likelihood site "y" is sampled rather than observed.
posterior_predictive = Predictive(
    model=model,
    guide=svi_guide,
    params=svi_params,
    num_samples=num_samples,
)
svi_predictions = posterior_predictive(
    rng_key=rng_key_predict, x=x_test_scaled, y=None, hidden_dim=D_h
)

Let's visualize the predictions and uncertainty.

In [None]:
# Extract the mean and standard deviation for the predictions on the test data
mean_pred = jnp.mean(svi_predictions['y_pred'], axis = 0)
std_pred = jnp.std(svi_predictions['y_pred'], axis = 0)
x_idx = jnp.argsort(x_test_scaled[:, 0])

# make plots
fig, ax = plt.subplots(1, figsize=(6, 6), constrained_layout=True)

# plot training data
ax.plot(x_test_scaled[:, 0], y_test, '.', label = 'Samples for testing')
# plot mean prediction
ax.plot(x_test_scaled[:, 0][x_idx], mean_pred[x_idx], color = "purple", ls="solid", lw=2.0, label = 'Prediction mean')
# plot true mean
ax.plot(x_test_scaled[:, 0][x_idx], m_Yt[x_idx], color = "red", ls="solid", lw=2.0, label = 'True mean')
#plot uncertainty
ax.fill_between(x_test_scaled[:, 0][x_idx], mean_pred[x_idx] -2*std_pred[x_idx], mean_pred[x_idx] + 2*std_pred[x_idx], alpha = 0.2, label = 'Uncertainty')
ax.set(xlabel="X", ylabel="Y", title="SVI")
fig.legend()

In [None]:
# svi_losses holds the per-step loss (the negative ELBO); negate it to plot the ELBO.
elbo_trace = -svi_losses
plt.plot(jnp.arange(elbo_trace.shape[0]), elbo_trace, linewidth=0.5, label='ELBO')
plt.legend()

Let's evaluate the predictive accuracy. Then let's assess the uncertainty by computing, for a range of credible levels, the fraction of test inputs at which the true function lies within the credible interval (the empirical coverage).

In [None]:
# MSE and R2
from sklearn.metrics import mean_squared_error, r2_score

print('MSE for the outputs is ', mean_squared_error(y_test,mean_pred))
print('MSE for the function is ', mean_squared_error(m_Yt,mean_pred))
print('R2 for the outputs is ', r2_score(y_test,mean_pred))

In [None]:
# Empirical coverage: for each credible level q, the fraction of test inputs at
# which the true function lies strictly inside the central q-credible interval
# of y_pred.
q = np.arange(0.5,1,0.01)
ec = np.zeros(q.shape)

svi_draws = svi_predictions['y_pred']
for i, level in enumerate(q):
    tail = (1 - level)/2
    lower_pred = np.quantile(svi_draws, q=tail, axis=0)
    upper_pred = np.quantile(svi_draws, q=level + tail, axis=0)
    ec[i] = np.sum((m_Yt > lower_pred) & (m_Yt < upper_pred))/N_test

# A calibrated model follows the dashed diagonal (coverage == nominal level).
plt.plot(q, ec, linewidth=2, label='Empirical Coverage')
plt.plot(q, q, linestyle='dashed', linewidth=2)
plt.legend()
plt.show()

## Exercises
1. See how the results change by altering different settings:
  *   increasing the number of hidden units
  *   change the prior to Laplace (`dist.Laplace`)
  *   change the hyperparameters of the Gamma priors.
2. Try changing the variational posterior, e.g. to `AutoMultivariateNormal` to remove the mean-field assumption.
3. Try adding another hidden layer.

## MCMC

Now, let's compare with MCMC. Here we will use the No-U-Turn Sampler (NUTS). For other MCMC algorithms and options, see https://num.pyro.ai/en/stable/mcmc.html.


In [18]:

def run_mcmc(model, num_chains, num_samples, num_warmup, rng_key, x, y, hidden_dim):
    """Sample the posterior of ``model`` with NUTS and return the draws.

    Prints the wall-clock time spent building and running the sampler.

    Args:
        model: numpyro model accepting keyword arguments ``x``, ``y`` and ``hidden_dim``.
        num_chains: number of MCMC chains.
        num_samples: number of post-warmup draws per chain.
        num_warmup: number of warmup (adaptation) iterations per chain.
        rng_key: JAX PRNG key passed to ``MCMC.run``.
        x: training inputs passed to the model.
        y: training targets passed to the model.
        hidden_dim: number of hidden units passed to the model.

    Returns:
        The posterior samples as returned by ``mcmc.get_samples()``.
    """
    t_start = time.time()
    # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set
    # (presumably a docs build; carried over from the numpyro example).
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    mcmc = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(rng_key=rng_key, x=x, y=y, hidden_dim=hidden_dim)
    print("\nMCMC  time:", time.time() - t_start)
    return mcmc.get_samples()

In [19]:
# NUTS settings: warmup iterations, draws per chain, number of chains
num_warmup, num_samples, num_chains = 2000, 5000, 1
# NOTE(review): numpyro documents set_host_device_count as only taking effect at
# the start of the program (before JAX initializes). JAX has already run above,
# so with num_chains > 1 the chains may not run in parallel.
numpyro.set_host_device_count(num_chains)
D_h = 20  # hidden units

In [None]:
posterior_samples = run_mcmc(
    model=model,
    num_chains=num_chains,
    num_samples=num_samples,
    num_warmup=num_warmup,
    rng_key=rng_key,
    x=x_train_scaled,
    y=y_train,
    hidden_dim=D_h,
)

In [21]:
# Push the NUTS posterior draws through the model at the test inputs
predictive = Predictive(model=model, posterior_samples=posterior_samples)
predictions = predictive(
    rng_key=rng_key_predict,
    x=x_test_scaled,
    y=None,
    hidden_dim=D_h,
)

In [None]:

# Posterior mean and central 95% credible band (2.5% and 97.5% quantiles) of
# the network output ``y_pred`` (excluding observation noise) at each test input
mean_pred = jnp.mean(predictions['y_pred'], axis = 0)
lower_pred = jnp.quantile(predictions['y_pred'], q = 0.025, axis = 0)
upper_pred = jnp.quantile(predictions['y_pred'], q = 0.975, axis = 0)
x_idx = jnp.argsort(x_test_scaled[:, 0])  # sort so lines are drawn left to right
x_sorted = x_test_scaled[:, 0][x_idx]

# make plots
fig, ax = plt.subplots(1, figsize=(6, 6), constrained_layout=True)

# plot test data
ax.plot(x_test_scaled[:, 0], y_test, '.', label = 'Samples for testing')
# plot mean prediction
ax.plot(x_sorted, mean_pred[x_idx], color = "purple", ls="solid", lw=2.0, label = 'Prediction mean')
# plot true mean
ax.plot(x_sorted, m_Yt[x_idx], color = "red", ls="solid", lw=2.0, label = 'True mean')
# plot uncertainty: central 95% credible band
ax.fill_between(x_sorted, lower_pred[x_idx], upper_pred[x_idx], alpha = 0.2, label = 'Uncertainty')
ax.set(xlabel="X", ylabel="Y", title="MCMC")
fig.legend()

Let's compare the predictive accuracy and uncertainty with variational Bayes (VB, i.e. the SVI fit above).

In [None]:
# MSE and R2
print('MSE for the outputs is ', mean_squared_error(y_test,mean_pred))
print('MSE for the function is ', mean_squared_error(m_Yt,mean_pred))
print('R2 for the outputs is ', r2_score(y_test,mean_pred))

In [None]:
# Empirical coverage of the MCMC credible intervals for the true function,
# computed exactly as for SVI, then plotted alongside the SVI (VB) curve.
q = np.arange(0.5,1,0.01)
ec_mcmc = np.zeros(q.shape)

mcmc_draws = predictions['y_pred']
for i, level in enumerate(q):
    tail = (1 - level)/2
    lower_pred = np.quantile(mcmc_draws, q=tail, axis=0)
    upper_pred = np.quantile(mcmc_draws, q=level + tail, axis=0)
    ec_mcmc[i] = np.sum((m_Yt > lower_pred) & (m_Yt < upper_pred))/N_test

# The dashed diagonal marks perfect calibration (coverage == nominal level).
plt.plot(q, ec, linewidth=2, label='VB')
plt.plot(q, q, linestyle='dashed', linewidth=2)
plt.plot(q, ec_mcmc, linewidth=2, label='MCMC')
plt.legend()
plt.show()

## MAP

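Now, let's try maximum a posteriori (MAP) estimation by changing the variational posterior to `AutoDelta`. This guide places a point mass on the latent variables, so optimizing the ELBO yields a single point estimate instead of a distribution.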

In [25]:
def do_MAP(model, rng_key, numsteps, x, y, hidden_dim, learning_rate=0.01):
    """Compute a MAP point estimate for ``model`` via SVI with an ``AutoDelta`` guide.

    ``AutoDelta`` uses point-mass (Delta) distributions for the latent
    variables, so optimizing the ELBO yields a single point estimate rather
    than a distribution.

    Args:
        model: numpyro model accepting keyword arguments ``x``, ``y`` and ``hidden_dim``.
        rng_key: JAX PRNG key passed to ``SVI.run``.
        numsteps: number of optimization steps.
        x: training inputs passed to the model.
        y: training targets passed to the model.
        hidden_dim: number of hidden units passed to the model.
        learning_rate: step size of the Adam optimizer.

    Returns:
        Tuple ``(losses, params, guide)``: the per-step ``Trace_ELBO`` loss
        values, the optimized guide parameters, and the guide.
    """
    guide = AutoDelta(model)
    svi = SVI(model, guide, optim.Adam(learning_rate), Trace_ELBO())
    svi_results = svi.run(rng_key=rng_key, num_steps=numsteps, x=x, y=y, hidden_dim=hidden_dim)
    return svi_results.losses, svi_results.params, guide

In [26]:
# MAP settings; this also resets num_samples, which the MCMC section set to 5000
D_h, numsteps, num_samples = 20, 1000, 2000

In [None]:
map_losses, map_params, map_guide = do_MAP(
    model=model,
    rng_key=rng_key,
    numsteps=numsteps,
    x=x_train_scaled,
    y=y_train,
    hidden_dim=D_h,
)

In [28]:
# Predict at the test inputs using the MAP (point-mass) guide
posterior_predictive = Predictive(
    model=model, guide=map_guide, params=map_params, num_samples=num_samples
)
map_predictions = posterior_predictive(
    rng_key=rng_key_predict, x=x_test_scaled, y=None, hidden_dim=D_h
)

In [None]:
# MAP predictions at the test inputs. The band uses the sampled outputs ``y``
# (which include observation noise). With the point-mass AutoDelta guide every
# draw of ``y_pred`` is identical, so ``y_pred`` alone has no spread.
mean_pred = jnp.mean(map_predictions['y_pred'], axis = 0)
lower_pred = jnp.quantile(map_predictions['y'], q = 0.025, axis = 0)
upper_pred = jnp.quantile(map_predictions['y'], q = 0.975, axis = 0)
x_idx = jnp.argsort(x_test_scaled[:, 0])  # sort so lines are drawn left to right
x_sorted = x_test_scaled[:, 0][x_idx]

# make plots
fig, ax = plt.subplots(1, figsize=(6, 6), constrained_layout=True)

# plot test data
ax.plot(x_test_scaled[:, 0], y_test, '.', label = 'Samples for testing')
# plot mean prediction
ax.plot(x_sorted, mean_pred[x_idx], color = "purple", ls="solid", lw=2.0, label = 'Prediction mean')
# plot true mean
ax.plot(x_sorted, m_Yt[x_idx], color = "red", ls="solid", lw=2.0, label = 'True mean')
# plot uncertainty: central 95% predictive band
ax.fill_between(x_sorted, lower_pred[x_idx], upper_pred[x_idx], alpha = 0.2, label = 'Uncertainty')
ax.set(xlabel="X", ylabel="Y", title="MAP")
fig.legend()


Now, let's add some out-of-sample test points, with inputs drawn from $[-5, 5]$ (wider than the training range $[-4, 4]$), and see what happens.

In [None]:
N_test2 = 500 # sample size

# Generate test inputs on [-5, 5], extending beyond the training range [-4, 4].
# Fresh PRNG keys are used. With the training keys PRNGKey(0)/PRNGKey(1) and the
# same shape (N_test2 == N), X_test2 would just be a rescaled copy of X_train,
# and the noise would repeat the training noise.
X_test2 = random.uniform(random.PRNGKey(4), shape=(N_test2,),  minval=-5, maxval=5)

# Generate test outputs
m_Yt2 = function_true2(X_test2)
# Alternatively, use the quadratic mean: m_Yt2 = function_true(X_test2)
y_test2 =  m_Yt2 + sigma*random.normal(random.PRNGKey(5), shape=(N_test2,))


# Plot the training data and the wider-range test data together
fig, ax = plt.subplots(1,1,figsize=(6, 5), layout="constrained")

ax.scatter(x=X_train, y=y_train,  marker = '.', linewidths = 0.1, label="train")
ax.scatter(x=X_test2, y=y_test2,  marker = '.', linewidths = 0.1, label="test")
ax.legend(loc="upper left")
ax.set(xlabel="x", ylabel="y")

In [45]:
# Scale the data
x_test2_scaled = x_scaler.transform(X_test2[:, None])
x_test2_scaled = jnp.array(x_test2_scaled)

In [46]:
# MAP predictions at the wider-range (partly out-of-sample) inputs
posterior_predictive = Predictive(
    model=model, guide=map_guide, params=map_params, num_samples=num_samples
)
map_predictions2 = posterior_predictive(
    rng_key=rng_key_predict, x=x_test2_scaled, y=None, hidden_dim=D_h
)

In [None]:
# MAP predictions at the wider-range inputs. The band uses the sampled outputs
# ``y`` (which include observation noise). With the point-mass AutoDelta guide
# every draw of ``y_pred`` is identical, so ``y_pred`` alone has no spread.
mean_pred = jnp.mean(map_predictions2['y_pred'], axis = 0)
lower_pred = jnp.quantile(map_predictions2['y'], q = 0.025, axis = 0)
upper_pred = jnp.quantile(map_predictions2['y'], q = 0.975, axis = 0)
x_idx = jnp.argsort(x_test2_scaled[:, 0])  # sort so lines are drawn left to right
x_sorted = x_test2_scaled[:, 0][x_idx]

# make plots
fig, ax = plt.subplots(1, figsize=(6, 6), constrained_layout=True)

# plot test data
ax.plot(x_test2_scaled[:, 0], y_test2, '.', label = 'Samples for testing')
# plot mean prediction
ax.plot(x_sorted, mean_pred[x_idx], color = "purple", ls="solid", lw=2.0, label = 'Prediction mean')
# plot true mean
ax.plot(x_sorted, m_Yt2[x_idx], color = "red", ls="solid", lw=2.0, label = 'True mean')
# plot uncertainty: central 95% predictive band
ax.fill_between(x_sorted, lower_pred[x_idx], upper_pred[x_idx], alpha = 0.2, label = 'Uncertainty')
ax.set(xlabel="X", ylabel="Y", title="MAP")
fig.legend()

### Exercises

Notice how MAP inference provides poor, overconfident out-of-sample predictions. How do MCMC and VI compare?