add bayesian neural network example (#166)
* initial commit

* address comments

* fix test examples
martinjankowiak authored and neerajprad committed May 24, 2019
1 parent d61e614 commit 8a633e3
Showing 2 changed files with 133 additions and 0 deletions.
132 changes: 132 additions & 0 deletions examples/bnn.py
@@ -0,0 +1,132 @@
import matplotlib
matplotlib.use('Agg') # noqa: E402
import matplotlib.pyplot as plt

import argparse

import numpy as onp
from jax import vmap
import jax.numpy as np
import jax.random as random

import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc


"""
We demonstrate how to use NUTS to do inference on a simple (small)
Bayesian neural network with two hidden layers.
"""


# the non-linearity we use in our neural network
def nonlin(x):
    return np.tanh(x)


# a two-layer Bayesian neural network with computational flow
# given by D_X => D_H => D_H => D_Y, where D_H is the number of
# hidden units. (note we indicate tensor dimensions in the comments)
def model(X, Y, D_H):
    D_X, D_Y = X.shape[1], 1

    # sample first layer (we put unit normal priors on all weights)
    w1 = sample("w1", dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))  # D_X D_H
    z1 = nonlin(np.matmul(X, w1))  # N D_H  <= first layer of activations

    # sample second layer
    w2 = sample("w2", dist.Normal(np.zeros((D_H, D_H)), np.ones((D_H, D_H))))  # D_H D_H
    z2 = nonlin(np.matmul(z1, w2))  # N D_H  <= second layer of activations

    # sample final layer of weights and neural network output
    w3 = sample("w3", dist.Normal(np.zeros((D_H, D_Y)), np.ones((D_H, D_Y))))  # D_H D_Y
    z3 = np.matmul(z2, w3)  # N D_Y  <= output of the neural network

    # we put a prior on the observation noise
    prec_obs = sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / np.sqrt(prec_obs)

    # observe data
    sample("Y", dist.Normal(z3, sigma_obs), obs=Y)


# helper function for HMC inference
def run_inference(model, args, rng, X, Y, D_H):
    init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y, D_H)
    samples = mcmc(args.num_warmup, args.num_samples, init_params,
                   sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn)
    return samples


# helper function for prediction
def predict(model, rng, samples, X, D_H):
    model = substitute(seed(model, rng), samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = trace(model).get_trace(X=X, Y=None, D_H=D_H)
    return model_trace['Y']['value']
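
# for a single posterior predictive draw one could do, schematically
# (`one_sample` is a hypothetical dict with one value per latent site):
#   y = predict(model, random.PRNGKey(1), one_sample, X_test, D_H)
# main() below uses vmap to do this for every posterior sample at once.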


# create artificial regression dataset
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    D_Y = 1  # create 1d outputs
    onp.random.seed(0)
    X = np.linspace(-1, 1, N)
    X = np.power(X[:, onp.newaxis], np.arange(D_X))
    W = 0.5 * onp.random.randn(D_X)
    Y = np.dot(X, W) + 0.5 * np.power(0.5 + X[:, 1], 2.0) * np.sin(4.0 * X[:, 1])
    Y += sigma_obs * onp.random.randn(N)
    Y = Y[:, onp.newaxis]
    Y -= np.mean(Y)
    Y /= np.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = np.linspace(-1.3, 1.3, N_test)
    X_test = np.power(X_test[:, onp.newaxis], np.arange(D_X))

    return X, Y, X_test
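
# note that get_data builds polynomial features, so for D_X=3 the columns
# of X are [1, x, x**2]; X[:, 1] in the plotting code below is therefore
# the raw 1d input.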


def main(args):
    N, D_X, D_H = args.num_data, 3, args.num_hidden
    X, Y, X_test = get_data(N=N, D_X=D_X)

    # do inference
    rng, rng_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng, X, Y, D_H)

    # predict Y_test at inputs X_test
    vmap_args = (samples, random.split(rng_predict, args.num_samples))
    predictions = vmap(lambda samples, rng: predict(model, rng, samples, X_test, D_H))(*vmap_args)
    predictions = predictions[..., 0]

    # compute mean prediction and a 90% interval (5th to 95th percentile)
    mean_prediction = np.mean(predictions, axis=0)
    percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(1, 1)

    # plot training data
    ax.plot(X[:, 1], Y[:, 0], 'kx')
    # plot 90% confidence band of predictions
    ax.fill_between(X_test[:, 1], percentiles[0, :], percentiles[1, :], color='lightblue')
    # plot mean prediction
    ax.plot(X_test[:, 1], mean_prediction, 'blue', ls='solid', lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    plt.savefig('bnn_plot.pdf')
    plt.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Stochastic network")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
parser.add_argument("--num-data", nargs='?', default=100, type=int)
parser.add_argument("--num-hidden", nargs='?', default=5, type=int)
args = parser.parse_args()
main(args)
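
# example invocation (a sketch; the defaults defined above shown explicitly):
#   python bnn.py --num-samples 2000 --num-warmup 1000 --num-data 100 --num-hidden 5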
1 change: 1 addition & 0 deletions test/test_examples.py
@@ -12,6 +12,7 @@
    'baseball.py --num-samples 100 --num-warmup 100',
    'covtype.py --algo hmc --num-samples 10',
    'hmm.py --num-samples 100 --num-warmup 100',
    'bnn.py --num-samples 10 --num-warmup 10 --num-data 7',
    'minipyro.py',
    'stochastic_volatility.py --num-samples 100 --num-warmup 100',
    'ucbadmit.py',
