In [None]:
import copy
import os
import pickle

import jax.nn
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import optax
import tqdm.notebook as tqdm
from numpyro import handlers

In [None]:
%load_ext autoreload
%autoreload 2

from experiments.src.experiment import *
from experiments.src.data import *
from experiments.src.model import BNNRegressor

In [None]:
# Global matplotlib style shared by every figure in this notebook.
plt.rcParams["axes.grid"] = True          # grid on by default
plt.rcParams["font.weight"] = "bold"      # bold text everywhere
plt.rcParams["xtick.labelsize"] = 15      # large x tick labels
plt.rcParams["ytick.labelsize"] = 15      # large y tick labels
plt.rcParams["lines.linewidth"] = 1       # thinner than matplotlib's 1.5 default
plt.rcParams["lines.color"] = "k"         # fallback line colour; plot() uses axes.prop_cycle instead
plt.rcParams["grid.linestyle"] = "-"      # solid gridlines
plt.rcParams["grid.linewidth"] = 0.1      # very thin gridlines
plt.rcParams["savefig.dpi"] = 300         # high-resolution saved figures

In [None]:
# Backend for JAX/NumPyro. NumPyro's set_platform only takes effect if called
# before JAX starts computing, so it lives in this early config cell.
DEVICE = "gpu"
numpyro.set_platform(DEVICE)

# Input dimension passed to BNNRegressor(D_X=...).
D_X = 1

# Mean-field Gaussian VI settings (pretraining and retraining).
VI_MAX_ITER = 250_000
VI_NUM_SAMPLES = 500

# HMC settings (baseline and retraining).
HMC_NUM_WARMUP = 100
HMC_NUM_SAMPLES = 200
HMC_NUM_CHAINS = 2

# Hidden-layer widths passed to BNNRegressor(D_H=...).
# NOTE(review): 68 stands out among the powers of two - confirm it is not a typo for 64.
BNN_SIZE = [68, 128, 128, 128, 64]
# Passed to BNNRegressor(beta=...); its meaning is defined in the model code.
BETA = 0.1

# Figures below are written to figs/; create it so savefig does not raise
# FileNotFoundError on a fresh checkout.
os.makedirs("figs", exist_ok=True)

In [None]:
# Learning-rate schedule for every VI run. Between boundaries the value is
# cosine-interpolated; at each boundary step it reaches init_value times the
# product of all scales up to and including that boundary.
# NOTE(review): init_value is negative - presumably the VI experiment applies this
# schedule as a raw step scale (negative = descent). Confirm before reusing it with
# an optimizer that expects a positive learning rate.
VI_LR_SCHEDULE = optax.piecewise_interpolate_schedule(
    interpolate_type='cosine',
    init_value=-0.1,
    boundaries_and_scales={
        1000: 0.2,
        5000: 0.5,
        10_000: 0.2,
        100_000: 0.5,
        150_000: 0.5,
        200_000: 0.5,
    },
)

## Data

In [None]:
# Pretraining task: LinearData with intercept +0.5 and beta=0.0 (presumably a flat line).
pretrain_data = LinearData(beta=0.0, intercept=0.5)

In [None]:
# Seed the global NumPy RNG so the ordering below (and any later global draws) is reproducible.
np.random.seed(0)
# Random ordering of indices 0..49. Legacy RandomState.choice(arange(50), 50, replace=False)
# is implemented as permutation(50)[:50], so this yields exactly the same array and RNG state.
random_perm = np.random.permutation(50)
# Shifted task: intercept -0.5 instead of +0.5, presented in the random order above.
shifted_data = PermutedData(
    LinearData(intercept=-0.5, beta=0.0),
    random_perm,
)

Take prefixes of the shifted data to see how the amount of retraining data affects the result.

In [None]:
# Requested prefix lengths of the shifted data.
retrain_sizes = [1, 2, 5, 10, 25, 50, 75]
# The shifted data is built from a 50-entry permutation. A prefix longer than the data
# (e.g. 75) presumably yields the same points as the full set, which would rerun the
# 50-point experiments and overwrite their figures. Keep one dataset per distinct
# actual size, in order of first appearance; same-size prefixes hold the same points.
retrain_datasets = list({
    prefix.train[1].shape[0]: prefix
    for prefix in (DataSlice(shifted_data, slice(size)) for size in retrain_sizes)
}.values())

## Model

In [None]:
# Bayesian MLP regressor: D_X inputs, one output, SiLU activations, hidden widths
# from BNN_SIZE, with biases; 'xavier' prior with prior_scale=10.
# Other obs_model values tried here: "classification", or a numeric value such as
# 1 / (0.05 / 0.26480442)**2.
bnn = BNNRegressor(
    D_X=D_X,
    D_Y=1,
    D_H=BNN_SIZE,
    nonlin=jax.nn.silu,
    biases=True,
    prior_type='xavier',
    prior_scale=10,
    obs_model='loc_scale',
    beta=BETA,
)


In [None]:
# Display get_weight_dim() as this cell's output (presumably the total number of
# network weights, i.e. the dimensionality of the posterior - confirm in BNNRegressor).
bnn.get_weight_dim()

## Experiment

Pretrain on the first dataset; the resulting VI posterior is then used as the prior when retraining on the shifted data.

In [None]:
# Baseline: HMC on the pretraining data using the `bnn` model exactly as built above.
pretrain_hmc_baseline = BasicHMCExperiment(
    bnn,
    pretrain_data,
    HMC_NUM_SAMPLES,
    HMC_NUM_WARMUP,
    HMC_NUM_CHAINS,
)
pretrain_hmc_baseline.train(random.PRNGKey(0))

In [None]:
# Compute the baseline's predictions (presumably posterior-predictive samples) using a
# PRNG key (1) distinct from the training key (0). As the cell's last expression, any
# return value is displayed.
pretrain_hmc_baseline.make_predictions(random.PRNGKey(1))

In [None]:
# Plot the HMC baseline's predictions for the pretraining stage.
fig, ax = plt.subplots()
pretrain_hmc_baseline.make_plots(fig=fig, ax=ax)
ax.set_ylim(-2, 2)
ax.set_title("HMC baseline for pretraining stage")
fig.tight_layout()
# Save under figs/ like every other figure in this notebook (was the working directory).
fig.savefig("figs/pretrain-hmc-baseline.png")

In [None]:
# Release the HMC baseline (and whatever samples it holds) before the VI runs.
# Not idempotent: re-running this cell raises NameError.
del pretrain_hmc_baseline

In [None]:
# Mean-field Gaussian VI on the pretraining data. Its posterior becomes the prior
# for the retraining stage further down.
pretrain_experiment = BasicMeanFieldGaussianVIExperiment(
    bnn,
    pretrain_data,
    num_samples=VI_NUM_SAMPLES,
    max_iter=VI_MAX_ITER,
    lr_schedule=VI_LR_SCHEDULE,
)

In [None]:
# Fit the VI approximation on the pretraining data (PRNG key 0, matching the HMC baseline).
# As the cell's last expression, any return value of train() is displayed.
pretrain_experiment.train(random.PRNGKey(0))

In [None]:
# Predictions of the pretrained VI model, using a PRNG key (1) distinct from training's (0).
pretrain_experiment.make_predictions(random.PRNGKey(1))
fig, ax = plt.subplots()
pretrain_experiment.make_plots(fig=fig, ax=ax)
ax.set_ylim(-2.0, 2.0)
# Title and tight layout, consistent with the other plot cells, so the saved
# figure stands on its own.
ax.set_title("VI for pretraining stage")
fig.tight_layout()
fig.savefig("figs/vi-pretrained-predictive.png")

In [None]:
# Hand the VI posterior from pretraining to the model as its new prior.
# Not idempotent: both `pretrain_experiment` and `bnn` are deleted, so re-running
# this cell raises NameError.
pretrained_prior = pretrain_experiment.posterior
del pretrain_experiment  # release the experiment before building the new model
pretrained_bnn = bnn.with_prior(*pretrained_prior)
del bnn  # later cells must use pretrained_bnn, not the model with the original prior

In [None]:
def _retrain_and_plot(experiment, method_name, data_size):
    """Train ``experiment``, plot its predictions and save the figure under figs/.

    Shared by the VI and HMC branches so both use the same PRNG keys (0 for
    training, 1 for predictions), y-limits, titles and file naming.

    Args:
        experiment: a BasicMeanFieldGaussianVIExperiment or BasicHMCExperiment.
        method_name: short label ("VI" / "HMC"), used as-is in the title and
            lower-cased in the file name.
        data_size: number of retraining points, shown in the title and file name.
    """
    experiment.train(random.PRNGKey(0))
    experiment.make_predictions(random.PRNGKey(1))
    fig, ax = plt.subplots()
    experiment.make_plots(fig=fig, ax=ax)
    ax.set_title(f"{method_name} retrained on {data_size} points")
    ax.set_ylim(-2.0, 2.0)
    fig.tight_layout()
    fig.savefig(f"figs/{method_name.lower()}-retrained-on-{data_size}.png")


for retrain_data in retrain_datasets:
    # Number of points actually in this prefix; used for both the title and the file name.
    retrain_data_size = retrain_data.train[1].shape[0]
    # VI on the pretrained prior. No global reference is kept, so the experiment can be
    # garbage-collected once the call returns (replaces the explicit `del`).
    _retrain_and_plot(
        BasicMeanFieldGaussianVIExperiment(
            pretrained_bnn, retrain_data, num_samples=VI_NUM_SAMPLES,
            max_iter=VI_MAX_ITER, lr_schedule=VI_LR_SCHEDULE),
        "VI",
        retrain_data_size,
    )
    # HMC on the same pretrained prior and data, for comparison.
    _retrain_and_plot(
        BasicHMCExperiment(
            pretrained_bnn, retrain_data, HMC_NUM_SAMPLES, HMC_NUM_WARMUP, HMC_NUM_CHAINS),
        "HMC",
        retrain_data_size,
    )