In [209]:
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from jax import vmap
import jax.numpy as jnp
import jax.random as random

import numpyro
from numpyro import handlers
from numpyro.distributions import constraints
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide

from typing import Callable, Optional

In [15]:
# Switch matplotlib to the "nbAgg" notebook backend for the figures created below.
matplotlib.use("nbAgg")  # noqa: E402

In [96]:
class Args:
    """Bare attribute container holding the experiment settings assigned below."""


args = Args()
args.maxiter = 5000        # not referenced by any cell visible in this notebook
args.num_samples = 2000    # posterior / predictive draws
args.num_warmup = 1000     # NUTS warm-up iterations
args.num_chains = 1        # number of MCMC chains
args.num_data = 100        # training points
args.num_hidden = 5        # hidden units per layer
args.device = "cpu"        # platform handed to NumPyro

# Select the platform and request one host device per chain.
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)

# Shorthand sizes used by later cells: data count, input features, hidden width.
N = args.num_data
D_X = 3
D_H = args.num_hidden

In [97]:
# Define toy regression problem
# create artificial regression dataset
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500, seed=0):
    """Generate a standardised 1-d toy regression dataset with polynomial features.

    The inputs are ``N`` equally spaced points on [-1, 1], expanded into the
    features ``x**0, x**1, ..., x**(D_X-1)``. The targets follow
    ``w0 + w1*x + w2*x**2 + ... + 1/2 (1/2 + x)**2 * sin(4x)`` plus Gaussian
    noise, and are then shifted/scaled to zero mean and unit standard deviation.

    Args:
        N: number of training points.
        D_X: number of polynomial features (must be >= 2, the ``x**1`` column
            is used to build the non-linear term).
        sigma_obs: standard deviation of the observation noise added before
            standardisation.
        N_test: number of test inputs, equally spaced on [-1.3, 1.3].
        seed: seed of the private NumPy generator. The default ``0`` yields the
            same random draws as the former ``np.random.seed(0)`` call.

    Returns:
        ``(X, Y, X_test)`` with shapes ``(N, D_X)``, ``(N, 1)`` and
        ``(N_test, D_X)``.
    """
    D_Y = 1  # create 1d outputs
    # Private legacy generator: identical stream to `np.random.seed(seed)` but
    # without clobbering the process-wide NumPy random state.
    rng = np.random.RandomState(seed)
    X = jnp.linspace(-1, 1, N)
    X = jnp.power(X[:, np.newaxis], jnp.arange(D_X))
    W = 0.5 * rng.randn(D_X)
    # y = w0 + w1*x + w2*x**2 + 1/2 (1/2+x)**2 * sin(4x)
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
    Y += sigma_obs * rng.randn(N)
    Y = Y[:, np.newaxis]
    # standardise the targets
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    # test inputs extend slightly beyond the training range
    X_test = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X))

    return X, Y, X_test

In [98]:
# Build the training inputs/targets and the test inputs using the configured N and D_X.
X, Y, X_test = get_data(N=N, D_X=D_X)

In [210]:
class BayesianNeuralNetwork:
    """Bias-free fully connected network usable as a NumPyro model.

    All weight matrices live in a single flat vector sampled at the site
    ``"w"``. The vector is laid out as the ``(D_X, D_H)`` input matrix, then
    ``L - 2`` hidden ``(D_H, D_H)`` matrices, then the ``(D_H, D_Y)`` output
    matrix. The observation precision is sampled at ``"prec_obs"`` and the
    targets are observed at ``"Y"``.
    """

    def __init__(self,
                 nonlin: Callable[[jnp.ndarray], jnp.ndarray],
                 D_X: int,
                 D_Y: int,
                 D_H: int,
                 L: int,
                 biases: bool):
        """Store the architecture and install the default priors.

        ``L`` is the number of weight matrices (at least 2); ``biases`` must
        be ``False`` because bias terms are not implemented.
        """
        assert L >= 2
        assert not biases
        self._nonlin = nonlin
        self.D_X = D_X
        self.D_Y = D_Y
        self.D_H = D_H
        self.L = L
        self._biases = biases
        # Initialise priors to independent standard normals
        dim = self._get_weight_dim()
        self._prior_w = dist.MultivariateNormal(jnp.zeros(dim), jnp.eye(dim))
        self._prior_prec_obs = dist.Gamma(3.0, 1.0)

    def _get_weight_dim(self) -> int:
        """Return the length of the flat vector holding every weight matrix."""
        assert not self._biases
        input_block = self.D_X * self.D_H
        hidden_blocks = (self.L - 2) * self.D_H * self.D_H
        output_block = self.D_H * self.D_Y
        return input_block + hidden_blocks + output_block

    def _wi_from_flat(self, a: jnp.ndarray, depth: int) -> jnp.ndarray:
        """Cut the weight matrix of layer ``depth`` out of the flat vector ``a``."""
        assert a.shape[0] == self._get_weight_dim()
        assert 0 <= depth < self.L
        assert not self._biases
        first_size = self.D_X * self.D_H
        last_size = self.D_H * self.D_Y
        if depth == 0:
            # input layer sits at the front of the vector
            chunk, shape = a[:first_size], (self.D_X, self.D_H)
        elif depth == self.L - 1:
            # output layer sits at the very end
            chunk, shape = a[-last_size:], (self.D_H, self.D_Y)
        else:
            # hidden layers are stored back to back between the two
            hidden_size = self.D_H * self.D_H
            hidden = a[first_size:-last_size]
            begin = (depth - 1) * hidden_size
            chunk, shape = hidden[begin:begin + hidden_size], (self.D_H, self.D_H)
        return chunk.reshape(shape)

    # def _flat_from_wis(self, layers: list[jnp.ndarray]) -> jnp.ndarray:
    #     return jnp.concatenate([elem.reshape(-1) for elem in layers])

    #noinspection PyPep8Naming
    def __call__(self, X: jnp.ndarray, Y: Optional[jnp.ndarray] = None):
        """NumPyro model: sample weights and noise precision, then observe ``Y``.

        ``X`` must be ``(N, D_X)``; when ``Y`` is given its shape must equal the
        network output shape. With ``Y=None`` the site ``"Y"`` is sampled.
        """
        num_rows, num_features = X.shape
        assert num_features == self.D_X

        # sample weights from prior
        flat_weights = numpyro.sample("w", self._prior_w)

        # forward pass: the non-linearity is applied between consecutive layers only
        layer_weights = [self._wi_from_flat(flat_weights, depth) for depth in range(self.L)]
        output = jnp.matmul(X, layer_weights[0])
        for weights in layer_weights[1:]:
            output = jnp.matmul(self._nonlin(output), weights)

        if Y is not None:
            assert output.shape == Y.shape

        # we put a prior on the observation noise
        prec_obs = numpyro.sample("prec_obs", self._prior_prec_obs)
        sigma_obs = numpyro.deterministic("sigma_obs", 1.0 / jnp.sqrt(prec_obs))

        # observe data
        noise_scale = jnp.full((num_rows, self.D_Y), sigma_obs)
        numpyro.sample("Y", dist.Normal(output, noise_scale), obs=Y)

    def set_prior(self, prior_w: dist.Distribution, prior_prec_obs: dist.Distribution):
        """Replace the priors over the flat weight vector and the noise precision."""
        self._prior_prec_obs = prior_prec_obs
        self._prior_w = prior_w


In [211]:
# Instantiate the network used for both HMC and VI below: tanh activations,
# three weight matrices (two hidden layers) and no bias terms.
bnn = BayesianNeuralNetwork(
    nonlin=jnp.tanh,
    D_X=X.shape[1],
    D_Y=1,
    # take the hidden width from the configuration cell (D_H = args.num_hidden)
    # instead of repeating the literal 5, so changing the config takes effect here
    D_H=D_H,
    L=3,
    biases=False,
)

In [173]:
# Utils for flattening/de-flattening NN weights
# This is so we can sample from the joint
def layers_from_flattened_weights(a: jnp.ndarray) -> jnp.ndarray:
    """Placeholder for splitting a flat weight vector into per-layer weights.

    Not implemented: the argument is ignored and ``None`` is returned, despite
    the ``jnp.ndarray`` return annotation.
    """
    return None


In [103]:
# activation function shared by the layers of `model`
def nonlin(x):
    """Apply the element-wise hyperbolic tangent to ``x``."""
    activated = jnp.tanh(x)
    return activated

# a Bayesian neural network with two hidden layers and computational flow
# D_X => D_H => D_H => D_Y, where D_H is the number of hidden units
# (tensor shapes are checked by the assertions in the body)
def model(X, Y, D_H, D_Y=1):
    """NumPyro model with one sample site per weight matrix.

    Sample sites: ``w1`` (D_X, D_H), ``w2`` (D_H, D_H), ``w3`` (D_H, D_Y),
    ``prec_obs`` and ``Y``. Deterministic sites: ``z1``, ``z2``, ``z3`` and
    ``sigma_obs``. ``Y`` may be ``None``, in which case the site ``"Y"`` is
    sampled instead of observed.
    """
    N, D_X = X.shape

    def _unit_normal_weights(site, shape):
        # unit normal prior on every entry of the weight matrix
        weights = numpyro.sample(site, dist.Normal(jnp.zeros(shape), jnp.ones(shape)))
        assert weights.shape == shape
        return weights

    # first hidden layer
    w1 = _unit_normal_weights("w1", (D_X, D_H))
    z1 = numpyro.deterministic("z1", nonlin(jnp.matmul(X, w1)))
    assert z1.shape == (N, D_H)

    # second hidden layer
    w2 = _unit_normal_weights("w2", (D_H, D_H))
    z2 = numpyro.deterministic("z2", nonlin(jnp.matmul(z1, w2)))
    assert z2.shape == (N, D_H)

    # linear read-out: z3 is the output of the neural network
    w3 = _unit_normal_weights("w3", (D_H, D_Y))
    z3 = numpyro.deterministic("z3", jnp.matmul(z2, w3))
    assert z3.shape == (N, D_Y)

    if Y is not None:
        assert z3.shape == Y.shape

    # prior on the observation precision; sigma_obs is the implied noise scale
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = numpyro.deterministic("sigma_obs", 1.0 / jnp.sqrt(prec_obs))

    # observe data
    noise_scale = sigma_obs * jnp.ones((N, D_Y))
    numpyro.sample("Y", dist.Normal(z3, noise_scale), obs=Y)
    # Alternative kept for reference (plate over data points):
    # with numpyro.plate("data", N):
        # note we use to_event(1) because each observation has shape (1,)
        # numpyro.sample("Y", dist.Normal(z3, sigma_obs).to_event(1), obs=Y)

In [212]:
# Posterior inference for `bnn` with the NUTS sampler
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
start = time.time()
kernel = NUTS(bnn)
mcmc = MCMC(
    kernel,
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    # progress bar is switched off when NUMPYRO_SPHINXBUILD is set in the environment
    progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
)
mcmc.run(rng_key, X, Y)
mcmc.print_summary()
print("\nMCMC elapsed time:", time.time() - start)
# posterior draws keyed by sample-site name, reused by the prediction cell
samples = mcmc.get_samples()

sample: 100%|██████████| 3000/3000 [01:39<00:00, 30.23it/s, 1023 steps of size 3.55e-03. acc. prob=0.94]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  prec_obs     18.14      2.69     18.08     13.83     22.51   1994.50      1.00
      w[0]     -0.07      1.13     -0.12     -1.89      1.79    331.30      1.01
      w[1]     -0.02      1.21      0.06     -2.07      1.84    183.74      1.00
      w[2]      0.05      1.18     -0.00     -1.77      2.07    133.95      1.01
      w[3]      0.08      1.14      0.10     -1.63      2.14    226.14      1.00
      w[4]      0.06      1.17      0.10     -1.82      2.03    225.54      1.03
      w[5]     -0.03      1.13     -0.07     -1.81      1.70    229.72      1.00
      w[6]     -0.14      1.14     -0.21     -1.86      1.68    141.88      1.01
      w[7]     -0.02      1.19     -0.05     -1.86      1.78    171.68      1.00
      w[8]     -0.08      1.18     -0.13     -1.82      1.75    157.54      1.00
      w[9]     -0.07      1.18     -0.06     -1.87      1.75    201.13      1.00
     w[10]     -0.01      1

In [213]:
# Posterior-predictive draws of Y at the test inputs, using the HMC samples
hmc_predictions = Predictive(bnn, samples)(
    rng_key_predict, X=X_test, Y=None
)["Y"][..., 0]
# point-wise mean and 5th/95th percentiles across the draws
hmc_mean_prediction = hmc_predictions.mean(axis=0)
hmc_percentiles = np.percentile(hmc_predictions, [5.0, 95.0], axis=0)

In [214]:
# Plot the HMC posterior-predictive mean with its 5%-95% percentile band.
# X_test[:, 1] is the x**1 feature column, i.e. the raw input location.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(X_test[:, 1], hmc_mean_prediction, color="blue", label="HMC mean prediction")
ax.fill_between(X_test[:, 1], *hmc_percentiles, color="lightblue", label="HMC 5%-95% band")
# label the figure so it can be read on its own
ax.set(xlabel="x", ylabel="y (standardised)", title="HMC posterior predictive")
ax.legend()
plt.show()

<IPython.core.display.Javascript object>

In [144]:
# Mean-field variational guide mirroring the sample sites of `model`.
# NOTE: a later cell defines another `mean_field_guide` (for `bnn`) with a
# different signature; that definition replaces this one once it has run.
def mean_field_guide(X, Y, D_H, D_Y=1):
    """Guide for ``model``: independent Gaussians over w1/w2/w3, point mass on prec_obs.

    Each weight matrix gets a ``<site>_loc`` parameter (initialised from a
    standard normal draw) and a positive ``<site>_scale`` parameter
    (initialised to 0.1). The precision is a ``Delta`` at the positive
    parameter ``prec_obs_loc``. ``Y`` is accepted for signature compatibility
    with the model and is not used; the observation site is not part of the guide.
    """
    _, D_X = X.shape

    def _gaussian_weights(site, shape):
        # variational mean and (positive) standard deviation for one weight matrix
        loc = numpyro.param(f"{site}_loc", lambda rng_key: dist.Normal().sample(rng_key, shape))
        scale = numpyro.param(f"{site}_scale", 0.1 * jnp.ones(shape), constraint=constraints.positive)
        weights = numpyro.sample(site, dist.Normal(loc, scale))
        assert weights.shape == shape
        return weights

    # same order as the sample sites in `model`
    _gaussian_weights("w1", (D_X, D_H))
    _gaussian_weights("w2", (D_H, D_H))
    _gaussian_weights("w3", (D_H, D_Y))

    # point estimate of the observation precision (a Gamma guide was tried and
    # left disabled in the original cell)
    prec_obs_loc = numpyro.param("prec_obs_loc", 1.0, constraint=constraints.positive)
    prec_obs = numpyro.sample("prec_obs", dist.Delta(prec_obs_loc))
    numpyro.deterministic("sigma_obs", 1.0 / jnp.sqrt(prec_obs))

In [215]:
# Mean-field guide for the class-based `bnn` model
def mean_field_guide(X, Y=None):
    """Guide for ``bnn``: diagonal Gaussian over the flat weights, point mass on prec_obs.

    Reads the weight-vector length from the module-level ``bnn``. ``X`` and
    ``Y`` are accepted to match the model's call signature and are not used.
    """
    weight_shape = (bnn._get_weight_dim(),)

    # variational mean (random normal init) and positive scale (init 0.1)
    w_loc = numpyro.param("w_loc", lambda rng_key: dist.Normal().sample(rng_key, weight_shape))
    w_scale = numpyro.param("w_scale", jnp.full(weight_shape, 0.1), constraint=constraints.positive)
    # to_event(1) makes the vector a single event, matching the model's "w" site
    numpyro.sample("w", dist.Normal(w_loc, w_scale).to_event(1))

    # point estimate of the observation precision
    prec_obs_loc = numpyro.param("prec_obs_loc", 1.0, constraint=constraints.positive)
    numpyro.sample("prec_obs", dist.Delta(prec_obs_loc))

In [216]:
# Fit the mean-field guide to `bnn` by stochastic variational inference (Adam + Trace_ELBO).
# NOTE(review): this reuses the `rng_key` already passed to MCMC, and the step count is
# the literal 10_000 rather than `args.maxiter` -- confirm both are intended.
optimizer = numpyro.optim.Adam(0.005)
svi = SVI(
    model=bnn,
    guide=mean_field_guide,
    optim=optimizer,
    loss=Trace_ELBO(),
)
svi_results = svi.run(rng_key, 10_000, X=X, Y=Y)
# optimised variational parameters, consumed by the predictive cell
mean_field_params = svi_results.params

100%|██████████| 10000/10000 [00:14<00:00, 669.34it/s, init loss: 414.6161, avg. loss [9501-10000]: 117.6498]


In [146]:
# Display the names of the fitted variational parameters.
mean_field_params.keys()

dict_keys(['prec_obs_loc', 'w1_loc', 'w1_scale', 'w2_loc', 'w2_scale', 'w3_loc', 'w3_scale'])

In [147]:
# Predictive(model, guide=mean_field_guide, params=mean_field_params, num_samples=args.num_samples)(rng_key_predict, X=X_test, Y=None, D_H=D_H).keys()

In [217]:
# Posterior-predictive draws of Y at the test inputs under the fitted variational guide
predictions = Predictive(
    bnn,
    guide=mean_field_guide,
    params=mean_field_params,
    num_samples=args.num_samples,
)(rng_key_predict, X=X_test, Y=None)['Y'][..., 0]
# point-wise mean and 5th/95th percentiles across the draws (a 90% predictive band)
mean_prediction = jnp.mean(predictions, axis=0)
percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

# Compare the VI predictive against the HMC mean; X_test[:, 1] is the raw input x.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(X_test[:, 1], hmc_mean_prediction, color="red", label="HMC mean predictions")
ax.plot(X_test[:, 1], mean_prediction, color="blue", label="VI mean")
ax.fill_between(X_test[:, 1], *percentiles, color="lightblue", label="VI 5%-95% band")
# label the figure so it can be read on its own
ax.set(xlabel="x", ylabel="y (standardised)", title="Mean-field VI vs. HMC predictions")
ax.legend()
plt.show()

<IPython.core.display.Javascript object>

In [208]:
# Close every open matplotlib figure.
plt.close('all')