In [None]:
import copy
import os
import pickle

import jax.nn
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import optax
import tqdm.notebook as tqdm
from numpyro import handlers
from numpyro import handlers

In [None]:
%load_ext autoreload
%autoreload 2

from experiments.src.experiment import *
from experiments.src.data import DataSlice, ToyData1, Sign
from experiments.src.model import BNNRegressor

In [None]:
# %matplotlib inline
# import matplotlib
# matplotlib.use("nbAgg")  # noqa: E402

# Global matplotlib styling applied to every figure produced by this notebook.
plt.rcParams.update({
    # Grid
    "axes.grid": True,        # show a grid on every axes by default
    "grid.linestyle": "-",    # solid gridlines
    "grid.linewidth": 0.1,    # very thin gridlines
    # "grid.color": "0.5",    # gray gridlines (disabled)
    # Text
    "font.weight": "bold",    # bold fonts
    "xtick.labelsize": 15,    # large tick labels
    "ytick.labelsize": 15,    # large tick labels
    # Lines
    "lines.linewidth": 1,     # thin lines (matplotlib's default is 1.5)
    "lines.color": "k",       # black lines
    # Output
    "savefig.dpi": 300,       # higher-resolution saved figures
})

In [None]:
# --- Runtime: select the accelerator numpyro/JAX should use ---
DEVICE = "gpu"
numpyro.set_platform(DEVICE)
# numpyro.set_host_device_count(HMC_NUM_CHAINS)

# --- Problem / model ---
D_X = 2                  # passed as D_X to both ToyData1 and BNNRegressor
BNN_SIZE = [32, 32, 16]  # passed as D_H (presumably the hidden-layer widths)
PRIOR_SCALE = 5.0        # passed as prior_scale to BNNRegressor
BETA = 0.2               # passed as beta to BNNRegressor; the HMC comparison temporarily uses 1.0

# --- Sequential mean-field VI ---
VI_MAX_ITER = 250_000
VI_NUM_SAMPLES = 500

# --- HMC comparison run ---
DO_HMC = True
HMC_NUM_WARMUP = 100
HMC_NUM_SAMPLES = 200
HMC_NUM_CHAINS = 2

In [None]:
# Step-size schedule for VI, cosine-interpolated between boundary steps.
# Per the optax docs the scales compound multiplicatively, giving schedule values of
# -0.1 at step 0, -0.02 at 1k, -0.01 at 5k, -0.002 at 10k, -0.001 at 100k,
# -5e-4 at 150k and -2.5e-4 at 200k (VI_MAX_ITER is 250k).
# NOTE(review): init_value is negative — presumably the experiment applies this value
# directly as an update scale (descent direction); confirm before changing the sign.
VI_LR_SCHEDULE = optax.piecewise_interpolate_schedule(
    init_value=-0.1,
    boundaries_and_scales={
        1000: 0.2,
        5000: 0.5,
        10_000: 0.2,
        100_000: 0.5,
        150_000: 0.5,
        200_000: 0.5,
    },
    interpolate_type='cosine',
)

## Data

In [None]:
# Toy regression data set: D_X-dimensional inputs, 100 training points,
# sigma_obs=0.05 (presumably the observation-noise standard deviation).
data = ToyData1(
    D_X=D_X,
    sigma_obs=0.05,
    train_size=100,
)

## Model

In [None]:
# Bayesian neural-network regressor with SiLU activations and a single output (D_Y=1).
bnn = BNNRegressor(
    # architecture
    D_X=D_X,
    D_H=BNN_SIZE,
    D_Y=1,
    nonlin=jax.nn.silu,
    biases=True,
    # prior
    prior_type='xavier',
    prior_scale=PRIOR_SCALE,
    # likelihood — alternatives previously tried:
    #   obs_model="classification"
    #   obs_model=1 / (0.05 / 0.26480442)**2
    obs_model='loc_scale',
    beta=BETA,
)

In [None]:
# Display the weight dimension reported by BNNRegressor.get_weight_dim() as this cell's output.
bnn.get_weight_dim()

In [None]:
# Sequential mean-field Gaussian VI over `num_inference_steps` experiment blocks.
# Presumably each block trains on one slice of the training data (see
# `_train_idx_slice` in the plotting cell below) — confirm in SequentialExperiment.
sequential_experiment = SequentialExperiment(
    bnn,
    data,
    BasicMeanFieldGaussianVIExperiment,
    num_inference_steps=5,
    max_iter=VI_MAX_ITER,
    num_samples=VI_NUM_SAMPLES,
    lr_schedule=VI_LR_SCHEDULE,
)

In [None]:
# sequential_experiment = ExperimentWithLastBlockReplaced(sequential_experiment, BasicHMCExperiment, num_samples=200, num_warmup=100)

In [None]:
# Run VI training for the sequential experiment, seeded with PRNGKey(0) for reproducibility.
sequential_experiment.train(random.PRNGKey(0))

In [None]:
# Compute predictions with PRNGKey(2). final_only=False presumably produces predictions
# for every block (not just the last), which the per-block plotting loop below relies on.
sequential_experiment.make_predictions(random.PRNGKey(2), final_only=False)

In [None]:
# For every sequential VI block: plot its predictions with the training points that
# precede its slice overlaid, save the figure, and optionally fit the same block's
# data with HMC (beta=1.0) for comparison.
os.makedirs("figs", exist_ok=True)  # savefig raises FileNotFoundError if the dir is missing


def _plot_prior_data(ax, train, end):
    """Overlay the first `end` training points of `train` (an (X, Y) pair) as blue crosses and fix the y-range."""
    ax.plot(train[0][:end, 1], train[1][:end], 'bx')
    ax.set_ylim(-6, 6)


for i, experiment_block in enumerate(sequential_experiment._experiment_blocks):
    fig, ax = plt.subplots()
    experiment_block.make_plots(fig, ax)
    # Index of the first training point in this block's slice. An open-start slice
    # (slice(None, k)) has `.start is None`, which would make `[:end]` select *all*
    # points; treat it as 0 so no points are drawn as "prior data".
    end = experiment_block._data._train_idx_slice.start or 0
    _plot_prior_data(ax, sequential_experiment._data.train, end)
    fig.savefig(f"figs/seq{i}.png")

    if DO_HMC:
        # NOTE(review): the model was built with `beta=BETA`; confirm `BETA` (not e.g.
        # `beta`) is the attribute BNNRegressor reads, otherwise this override is a no-op.
        experiment_block._bnn.BETA = 1.0
        try:
            hmc = BasicHMCExperiment(experiment_block._bnn, experiment_block._data, HMC_NUM_SAMPLES, HMC_NUM_WARMUP,
                                     HMC_NUM_CHAINS)
            hmc.train(random.PRNGKey(0))
            hmc.make_predictions(random.PRNGKey(1))
            hmc._samples = None  # presumably frees raw posterior samples; predictions were computed above
            fig, ax = plt.subplots()
            hmc.make_plots(fig=fig, ax=ax)
            _plot_prior_data(ax, sequential_experiment._data.train, end)
            fig.savefig(f"figs/hmcseq{i}.png")
            del hmc
        finally:
            # Always restore the VI setting, even if HMC fails or is interrupted, so the
            # shared model is not left at beta=1.0 for later cells or re-runs.
            experiment_block._bnn.BETA = BETA

In [None]:
# # Custom plotting for sequential experiment
# fig, axs = plt.subplots(figsize=(20, 8), ncols=3)
# for i, ax in enumerate(axs.ravel()):
#     experiment_block = sequential_experiment._sequential_experiment._experiment_blocks[i]
#     predictions = experiment_block._predictions["Y"][..., 0]
#     mean_predictions = experiment_block._predictions["Y_mean"][..., 0]
#     data = experiment_block._data
#     X, Y = data.train
#     X_test, _ = data.test
#     # compute mean prediction and confidence interval around median
#     mean_means = jnp.mean(mean_predictions, axis=0)
#     mean_percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)
#     # plot training data
#     ax.plot(X[:, 1], Y[:, 0], "kx")
#     # plot predictions & quantiles
#     ax.plot(X_test[:, 1], mean_means, color="blue")
#     ax.fill_between(X_test[:, 1], *mean_percentiles, color="lightblue")
#     ax.set_title(str(data._train_idx_slice))
# fig.tight_layout()
# fig.savefig("figs/sequential-VI-simple3.png")