In [113]:
import pickle

import jax.nn
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import time
import tqdm.notebook as tqdm

In [114]:
%load_ext autoreload
%autoreload 2

from experiments.src.experiment import *
from experiments.src.data import ToyData1
from experiments.src.model import BNNRegressor

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [112]:
# Global matplotlib styling applied to every figure produced in this notebook.
plt.rcParams.update({
    "axes.grid": True,       # show a grid on every axes by default
    "font.weight": "bold",   # bold text everywhere
    "xtick.labelsize": 15,   # large x tick labels
    "ytick.labelsize": 15,   # large y tick labels
    "lines.linewidth": 1,    # thin 1 pt lines (matplotlib's default is 1.5)
    # NOTE: lines.color does not recolour plot() lines -- those come from
    # axes.prop_cycle -- it only applies to artists that bypass the colour cycle.
    "lines.color": "k",
    "grid.linestyle": "-",   # solid gridlines
    "grid.linewidth": 0.1,   # very thin gridlines
    "savefig.dpi": 300,      # high-resolution output for savefig
})

In [129]:
# Backend selection: run JAX / NumPyro computations on the GPU.
DEVICE = "gpu"
numpyro.set_platform(DEVICE)

# Problem sizes shared by the dataset and the model below.
D_X = 2                              # input dimension, passed to both ToyData1 and BNNRegressor
BNN_SIZE = [64, 128, 128, 128, 64]   # layer widths handed to the BNN as `D_H`

In [116]:
# Toy dataset; `train_size` sets the number of training examples and `D_X`
# is the same input dimension used to build the model.
data = ToyData1(
    train_size=100,
    D_X=D_X,
)

In [130]:
# Bayesian neural-network regressor: D_X inputs, a single output, layer widths
# BNN_SIZE (passed as `D_H`), SiLU activations and bias terms enabled.
# Prior: 'xavier'-type with prior_scale=10 (exact meaning defined by
# BNNRegressor -- TODO confirm how the scale enters the prior).
bnn = BNNRegressor(
    D_X=D_X,
    D_Y=1,
    D_H=BNN_SIZE,
    nonlin=jax.nn.silu,
    biases=True,
    prior_type='xavier',
    prior_scale=10,
)

In [131]:
# Length of the model's flattened parameter vector. The sampling loop uses the
# same value as the last dimension of experiment._samples["w"].
bnn.get_weight_dim()

49922

## Experiment

In [119]:
# Wall-clock budget for the sampling loop. Everything is in seconds, because the
# loop compares this against differences of time-module timestamps.
HOURS = 3600            # seconds in one hour
MAX_TIME = HOURS * 16   # 16-hour budget

In [124]:
# HMC experiment on the model and dataset above. The run requests 4 chains and
# 50 warmup + 50 sampling steps, and keeps draws grouped by chain (the exact
# semantics of these arguments are defined by BasicHMCExperiment).
experiment = BasicHMCExperiment(
    bnn,
    data,
    num_warmup=50,
    num_samples=50,
    num_chains=4,
    group_by_chain=True,
)

In [126]:
# Time-budgeted HMC loop. Each round it:
#   1. runs `experiment.train` and pickles the samples to samples{cnt}.pkl;
#   2. computes predictions and pickles them to preds{cnt}.pkl;
#   3. clears both from the experiment before the next round.
# The budget is checked before a round starts, so the last round may run past
# MAX_TIME.
SEED = 1
rng_key = random.PRNGKey(SEED)

tqdm_obj = tqdm.tqdm()            # used only as a status line, not a counter
start = time.monotonic()          # monotonic clock: unaffected by system clock changes
cnt = 0
try:
    while time.monotonic() - start <= MAX_TIME:
        # Derive fresh keys every round. A fixed PRNGKey(1) would feed the same
        # random key to every round and to both training and prediction.
        rng_key, train_key, pred_key = random.split(rng_key, 3)

        experiment.train(train_key, progress_bar=True)
        with open(f"samples{cnt}.pkl", "wb") as f:
            pickle.dump(experiment._samples, f)

        experiment.make_predictions(pred_key)
        with open(f"preds{cnt}.pkl", "wb") as f:
            pickle.dump(experiment._predictions, f)

        # Clear this round's draws before the next round. Presumably this stops
        # `train` from accumulating them -- confirm in BasicHMCExperiment.
        # The reset keeps a (chain, draw, weight) layout with zero draws. The
        # chain count is read from the current samples rather than hard-coding
        # num_chains=4, assuming a leading chain axis (group_by_chain=True).
        n_chains = experiment._samples["w"].shape[0]
        experiment._samples["w"] = jnp.empty((n_chains, 0, bnn.get_weight_dim()))
        experiment._predictions = None

        info = f"{time.monotonic() - start:.2f} / {MAX_TIME:.2f}"
        tqdm_obj.clear()
        tqdm_obj.display(info)
        cnt += 1
finally:
    # Release the progress widget even if a round fails or is interrupted.
    tqdm_obj.close()

0it [00:00, ?it/s]


  0%|          | 0/10 [00:00<?, ?it/s][A
sample:  10%|█         | 1/10 [00:00<00:01,  8.58it/s][A
sample:  30%|███       | 3/10 [00:00<00:00,  8.79it/s][A
sample:  40%|████      | 4/10 [00:00<00:00,  8.45it/s][A
sample:  60%|██████    | 6/10 [00:00<00:00,  9.56it/s][A
sample:  70%|███████   | 7/10 [00:00<00:00,  8.51it/s][A
sample:  80%|████████  | 8/10 [00:00<00:00,  8.84it/s][A
sample: 100%|██████████| 10/10 [00:01<00:00,  8.58it/s][A



MCMC elapsed time: 1.3155150413513184



  0%|          | 0/10 [00:00<?, ?it/s][A
sample:  10%|█         | 1/10 [00:00<00:00,  9.17it/s][A
sample:  20%|██        | 2/10 [00:00<00:01,  7.65it/s][A
sample:  40%|████      | 4/10 [00:00<00:00, 10.20it/s][A
sample:  60%|██████    | 6/10 [00:00<00:00, 10.21it/s][A
sample:  80%|████████  | 8/10 [00:00<00:00,  9.30it/s][A
sample: 100%|██████████| 10/10 [00:01<00:00,  9.88it/s][A



MCMC elapsed time: 1.1089098453521729



  0%|          | 0/10 [00:00<?, ?it/s][A
sample:  10%|█         | 1/10 [00:00<00:01,  5.76it/s][A
sample:  20%|██        | 2/10 [00:00<00:01,  5.26it/s][A
sample:  30%|███       | 3/10 [00:00<00:01,  4.72it/s][A
sample:  40%|████      | 4/10 [00:00<00:01,  5.91it/s][A
sample:  50%|█████     | 5/10 [00:01<00:02,  2.27it/s][A
sample:  60%|██████    | 6/10 [00:01<00:01,  2.88it/s][A
sample:  70%|███████   | 7/10 [00:01<00:00,  3.49it/s][A
sample: 100%|██████████| 10/10 [00:02<00:00,  4.50it/s][A



MCMC elapsed time: 2.3191730976104736



  0%|          | 0/10 [00:00<?, ?it/s][A
sample:  10%|█         | 1/10 [00:00<00:01,  7.86it/s][A
sample:  20%|██        | 2/10 [00:00<00:02,  3.42it/s][A
sample:  40%|████      | 4/10 [00:00<00:00,  6.50it/s][A
sample:  50%|█████     | 5/10 [00:00<00:00,  7.13it/s][A
sample:  70%|███████   | 7/10 [00:01<00:00,  7.12it/s][A
sample: 100%|██████████| 10/10 [00:01<00:00,  7.49it/s][A



MCMC elapsed time: 1.4198660850524902


In [None]:
# with open("hmc-samples-first-half.pkl", "wb") as f:
#     pickle.dump(experiment._samples, f)

In [None]:
# experiment.make_predictions(random.PRNGKey(1))

In [None]:
# with open("hmc-preds-first-half.pkl", "wb") as f:
#     pickle.dump(experiment._predictions, f)

In [None]:
# fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 10), sharex='all', sharey='all')
# for i, ax in enumerate(axs.ravel()):
#     ax.plot(data.test[0][:, 1], experiment._predictions["Y_mean"][..., 0][i].mean(axis=0))
#     ax.fill_between(data.test[0][:, 1],
#                     *np.percentile(experiment._predictions["Y_mean"][..., 0][i], (5.0, 95.0), axis=0), alpha=0.5,
#                     color="orange")
#     ax.fill_between(data.test[0][:, 1], *np.percentile(experiment._predictions["Y"][..., 0][i], (5.0, 95.0), axis=0),
#                     alpha=0.5, color="lightgreen")
#     ax.plot(data.train[0][:, 1], data.train[1], "kx")
#     ax.set_ylim(-6, +6)
# fig.tight_layout()
# fig.savefig("figs/hmc-by-chain.png")