In [None]:
import copy
import pickle
from pathlib import Path

import jax.nn
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import optax
import tqdm.notebook as tqdm
from numpyro import handlers
from numpyro import handlers

In [None]:
%load_ext autoreload
%autoreload 2

from experiments.src.experiment import *
from experiments.src.data import DataSlice, ToyData1, Sign
from experiments.src.model import BNNRegressor

In [None]:
# Global matplotlib styling applied to every figure this notebook draws or saves.
# (For interactive figures, switch backend with `%matplotlib widget` / nbAgg.)
plt.rcParams.update({
    "axes.grid": True,       # show a grid on every axes by default
    "font.weight": "bold",   # bold text everywhere
    "xtick.labelsize": 15,   # large tick labels
    "ytick.labelsize": 15,   # large tick labels
    "lines.linewidth": 1,    # thin lines (matplotlib's default is 1.5)
    "lines.color": "k",      # NOTE: no effect on plot(); line colours come from axes.prop_cycle
    "grid.linestyle": "-",   # solid gridlines
    "grid.linewidth": 0.1,   # very thin gridlines
    "savefig.dpi": 300,      # high-resolution saved figures
})

In [None]:
# --- Backend ---------------------------------------------------------------
# Platform for numpyro/JAX; call before any JAX computation runs.
DEVICE = "gpu"
numpyro.set_platform(DEVICE)

# --- Problem / model size ----------------------------------------------------
D_X = 2                             # input dimensionality (passed as D_X to data and model)
BNN_SIZE = [64, 128, 128, 128, 64]  # passed to BNNRegressor as D_H (presumably hidden-layer widths)

# --- Inference and sweep settings -------------------------------------------
VI_MAX_ITER = 250_000               # max_iter for each VI run
BETAS = np.logspace(-8, 2, num=11)  # beta values 1e-8, 1e-7, ..., 1e2 (one per decade)

## Data

In [None]:
# Toy training set shared by every run of the sweep.
# NOTE(review): sigma_obs presumably sets the observation-noise std of the
# generated targets — confirm in experiments.src.data.ToyData1.
data = ToyData1(
    D_X=D_X,
    train_size=100,
    sigma_obs=0.05,
)

## Model: sweep over β with mean-field Gaussian VI

In [None]:
# Sweep the BNN's `beta` argument over BETAS. For each value: build the model,
# fit it with mean-field Gaussian VI, draw predictions, and save a plot to
# figs/beta_sweep_<beta>.png. Every fitted experiment is kept in
# `experiments_by_beta` so results can be compared after the sweep.
FIG_DIR = Path("figs")
# Create the output directory before training: otherwise savefig raises
# FileNotFoundError only after the first (long) VI run has already finished.
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Learning-rate schedule. It does not depend on beta, so build it once.
# NOTE(review): init_value is negative. That is only correct if the experiment
# applies this schedule directly to the updates (e.g. a scale-by-schedule step);
# if it is passed as the learning rate of an optax optimizer such as
# optax.adam, the sign flips and VI would ascend the loss. Confirm in
# BasicMeanFieldGaussianVIExperiment.
lr_schedule = optax.piecewise_interpolate_schedule(
    interpolate_type='cosine',
    init_value=-0.1,
    boundaries_and_scales={1000: 0.2, 5000: 0.5, 10_000: 0.2, 100_000: 0.5, 150_000: 0.5, 200_000: 0.5},
)

experiments_by_beta = {}
for beta in BETAS:
    bnn = BNNRegressor(
        nonlin=jax.nn.silu,
        D_X=D_X,
        D_Y=1,
        D_H=BNN_SIZE,
        biases=True,
        prior_scale=np.sqrt(2),
        prior_type='xavier',
        # Alternatives tried previously: obs_model="classification", or a fixed
        # value obs_model=1 / (0.05 / 0.26480442)**2.
        obs_model='loc_scale',
        beta=beta,
    )
    experiment = BasicMeanFieldGaussianVIExperiment(
        bnn, data, num_samples=10_000, max_iter=VI_MAX_ITER, lr_schedule=lr_schedule,
    )
    # Same PRNG keys for every beta, so the runs differ only in beta.
    experiment.train(random.PRNGKey(0))
    experiment.make_predictions(random.PRNGKey(1))
    experiments_by_beta[beta] = experiment

    fig, ax = plt.subplots()
    experiment.make_plots(fig=fig, ax=ax)
    ax.set_ylim(-6, 6)
    # suptitle rather than set_title so any axes title from make_plots is kept.
    fig.suptitle(f"beta = {beta:.0e}")
    fig.savefig(FIG_DIR / f"beta_sweep_{beta}.png")
    # Display this figure now (instead of all of them at the end of the cell)
    # and release it so figures don't accumulate across iterations.
    plt.show()
    plt.close(fig)