In [1]:
import os
import pickle

import jax.nn
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro

In [3]:
%load_ext autoreload
%autoreload 2

from experiments.src.experiment import *
from experiments.src.data import ToyData1
from experiments.src.model import BNNRegressor

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Global matplotlib style applied to every figure in this notebook.
plt.rcParams.update({
    "axes.grid": True,       # show grid by default
    "font.weight": "bold",   # bold fonts
    "xtick.labelsize": 15,   # large tick labels
    "ytick.labelsize": 15,   # large tick labels
    "lines.linewidth": 1,    # thin lines (matplotlib's default is 1.5)
    # NOTE: lines.color does not recolor ax.plot() lines; those take their
    # colors from axes.prop_cycle. Pass color=... explicitly where it matters.
    "lines.color": "k",
    "grid.linestyle": "-",   # solid gridlines
    "grid.linewidth": 0.1,   # thin gridlines
    "savefig.dpi": 300,      # higher-resolution saved figures
})

In [None]:
# Compute backend for JAX/NumPyro. Defaults to "gpu". Set the NUMPYRO_PLATFORM
# environment variable (e.g. to "cpu") to run on a machine without a GPU
# instead of editing this cell.
DEVICE = os.environ.get("NUMPYRO_PLATFORM", "gpu")
numpyro.set_platform(DEVICE)

D_X = 3                  # input dimensionality, passed to both ToyData1 and BNNRegressor
VI_MAX_ITER = 100_000    # NOTE(review): not referenced by the HMC cells here; presumably for a VI run
BNN_SIZE = [32, 32, 16]  # passed as D_H to BNNRegressor (presumably hidden-layer widths)

In [None]:
# Toy regression dataset with input dimensionality D_X and 100 training points.
data = ToyData1(
    train_size=100,
    D_X=D_X,
)

In [None]:
# Bayesian neural-network regressor: D_X inputs, one output, SiLU activations.
# D_H=BNN_SIZE presumably gives the hidden-layer widths (verify in BNNRegressor).
bnn = BNNRegressor(
    D_X=D_X,
    D_H=BNN_SIZE,
    D_Y=1,
    nonlin=jax.nn.silu,
    biases=True,
    prior_scale=1,
)

In [None]:
# Size reported by get_weight_dim(), presumably the total number of
# network parameters that HMC samples over (verify in BNNRegressor).
weight_dim = bnn.get_weight_dim()
weight_dim

## Experiment

In [None]:
# HMC run: 4 chains, each with 500 warmup and 1000 kept samples.
# group_by_chain=True presumably keeps a leading chain axis in the outputs.
experiment = BasicHMCExperiment(
    bnn,
    data,
    num_warmup=500,
    num_samples=1000,
    num_chains=4,
    group_by_chain=True,
)

In [None]:
# Fixed seed so the sampling run is reproducible.
train_rng_key = random.PRNGKey(0)
experiment.train(train_rng_key)

In [None]:
# Persist the posterior samples (presumably so the HMC run need not be repeated).
# NOTE(review): reads the private `_samples` attribute; only unpickle this file
# from a trusted source.
with open("hmc-samples-first-half.pkl", mode="wb") as samples_fh:
    pickle.dump(experiment._samples, samples_fh)

In [None]:
# Separate fixed seed for predictive sampling, distinct from the training key.
predict_rng_key = random.PRNGKey(1)
experiment.make_predictions(predict_rng_key)

In [None]:
# Persist the predictions dict (it contains at least "Y_mean" and "Y").
# NOTE(review): reads the private `_predictions` attribute; only unpickle this
# file from a trusted source.
with open("hmc-preds-first-half.pkl", mode="wb") as preds_fh:
    pickle.dump(experiment._predictions, preds_fh)

In [None]:
# Per-chain predictive plot: one panel per HMC chain, two panels per row.
# Assumes the leading axis of the "Y_mean"/"Y" predictions indexes chains
# (group_by_chain=True) -- TODO confirm against BasicHMCExperiment.
y_mean_by_chain = experiment._predictions["Y_mean"][..., 0]
y_by_chain = experiment._predictions["Y"][..., 0]
# NOTE(review): only input column 1 is used as the x-axis, and fill_between
# assumes these x values are sorted -- confirm for ToyData1.
x_test = data.test[0][:, 1]
x_train = data.train[0][:, 1]

n_cols = 2
n_chains = y_mean_by_chain.shape[0]
n_rows = (n_chains + n_cols - 1) // n_cols  # ceil(n_chains / n_cols); 4 chains -> 2x2, as before
fig, axs = plt.subplots(
    nrows=n_rows, ncols=n_cols, figsize=(12, 5 * n_rows),
    sharex='all', sharey='all', squeeze=False,  # always a 2-D axes array, even for one row
)
for i, ax in enumerate(axs.ravel()):
    if i >= n_chains:
        ax.set_visible(False)  # spare panel when the chain count is odd
        continue
    chain_y_mean = y_mean_by_chain[i]  # samples along axis 0, test points along axis 1
    ax.plot(x_test, chain_y_mean.mean(axis=0))
    # 5th-95th percentile bands: "Y_mean" samples in orange, "Y" samples in green.
    ax.fill_between(x_test, *np.percentile(chain_y_mean, (5.0, 95.0), axis=0),
                    alpha=0.5, color="orange")
    ax.fill_between(x_test, *np.percentile(y_by_chain[i], (5.0, 95.0), axis=0),
                    alpha=0.5, color="lightgreen")
    ax.plot(x_train, data.train[1], "kx")
    ax.set_ylim(-6, +6)
    ax.set_title(f"chain {i}")
fig.tight_layout()
# savefig does not create missing directories, so ensure figs/ exists first.
os.makedirs("figs", exist_ok=True)
fig.savefig("figs/hmc-by-chain.png")