In [2]:
import jax
import jax.nn
import numpy as np
import numpyro
import optax
from numpyro.handlers import block, trace, seed
from jax import random, vmap
import matplotlib.pyplot as plt
from numpyro.infer import Predictive
from typing import Optional

In [3]:
%load_ext autoreload
%autoreload 2

from experiments.src.experiment import *
from experiments.src.data import *
from experiments.src.model import BNNRegressor

In [4]:
# %matplotlib inline
# import matplotlib
# matplotlib.use("nbAgg")  # noqa: E402

plt.rcParams.update({
    "axes.grid": True,  # show grid by default
    "font.weight": "bold",  # bold fonts
    "xtick.labelsize": 15,  # large tick labels
    "ytick.labelsize": 15,  # large tick labels
    "lines.linewidth": 1,  # thick lines
    "lines.color": "k",  # black lines
    # "grid.color": "0.5",    # gray gridlines
    "grid.linestyle": "-",  # solid gridlines
    "grid.linewidth": 0.1,  # thin gridlines
    "savefig.dpi": 300,  # higher resolution output.
})

In [40]:
DEVICE = "cpu"
numpyro.set_platform(DEVICE)
# numpyro.set_host_device_count(NUM_CHAINS)
D_X = 10
BNN_SIZE = [16, 16]
VI_ITER = 100_000

## Spurious correlation

We hypothesise that MAP is more effective at finding a sparse solution when encountering correlated covariates, resulting in better generalisation performance than that of the Bayesian model-averaged posterior mean.

To test this we create a dataset with two features $x_1, x_2$, of correlation $\rho$, where the response is generated by a $\mathcal{N}(x_1, \sigma^2)$.
In the test dataset however, the covariates are sampled independently with the same marginals, and the response is still calculated as $\mathcal{N}(x_1, \sigma^2)$.

We generalise this to multiple features too.

Switching to independent Laplace priors on the weights now corresponds to Lasso regression in the linear model case, which we know induces sparsity in the MAP solution.
According to our hypothesis, BMA considers models which are less sparse, or picks the wrong feature due to correlation.
This might not be the fault of Bayesian inference per se -- we might have the prior belief that exactly one of the correlated features is explanatory, but any one of them can be -- but it is a simple demonstration that BMA performs poorly under  distribution shift.

Note: rho=0.9999, D_X=10, train_size=20, run key=0 works for id non-linearity, single hidden *unit*, iid Laplace prior scale of $\sqrt{0.002}$

In [41]:
class SpuriouslyCorrelatedData(Data):
    def __init__(self, rho=0.90, sigma_obs=0.05, train_size=100, test_size=500, D_X=2):
        np.random.seed(0)
        common = np.random.normal(scale=np.sqrt(rho), size=(train_size, 1))
        self.X_train = np.random.normal(scale=np.sqrt(1. - rho), size=(train_size, D_X)) + common
        # self.Y_train = np.mean(self.X_train, axis=1)[:, np.newaxis]
        # For now I'm calculating y = x1 + noise; above would calculate as mean + noise
        self.Y_train = self.X_train[:, [0]]
        self.Y_train += np.random.normal(scale=sigma_obs, size=(train_size, 1))

        self.X_test = np.random.normal(size=(test_size, D_X))
        self.Y_test = self.X_test[:, [0]]
        self.Y_test += np.random.normal(scale=sigma_obs, size=(test_size, 1))

    @property
    def train(self) -> tuple[jax.Array, jax.Array]:
        return self.X_train, self.Y_train

    @property
    def test(self) -> tuple[jax.Array, Optional[jax.Array]]:
        return self.X_test, self.Y_test

    def true_predictive(self, X: jax.Array) -> dist.Distribution:
        raise NotImplementedError()

In [52]:
data = SpuriouslyCorrelatedData(rho=0.99, D_X=D_X, train_size=20)

In [87]:
bnn = BNNRegressor(
    nonlin=jax.nn.silu,
    D_X=D_X,
    D_Y=1,
    D_H=BNN_SIZE,
    biases=True,
    obs_model=1 / 0.05 ** 2,
    prior_scale=np.sqrt(0.02),
    # prior_type='xavier',
)
bnn.get_weight_dim()

465

In [72]:
# Set Laplace prior for equivalent of Lasso regression, classically inducing sparsity
bnn = bnn.with_prior(dist.Laplace(scale=bnn.prior[0].base_dist.scale).to_event(1), bnn.prior[1])

In [88]:
delta = AutoDeltaVIExperiment(bnn, data, max_iter=VI_ITER, lr_schedule=optax.constant_schedule(-0.01))

In [89]:
delta.train(random.PRNGKey(1))
delta.make_predictions(random.PRNGKey(1))

Initial eval loss: 2709.2339 (lik: -3329.8101, kl: -620.5762)


100%|██████████| 50/50 [00:07<00:00,  6.27it/s, init loss: 2709.2339, avg. train loss / eval. loss [98000-100000]: -329.8375 / -329.7882]



SVI elapsed time: 8.267685174942017


In [90]:
delta._params

{'prec_obs_loc': DeviceArray(400.00003, dtype=float32),
 'w_loc': DeviceArray([ 7.14827547e-05,  1.03929685e-03, -1.02534681e-03,
               3.37931397e-06, -9.96556349e-09,  7.64857973e-08,
               2.31798243e-04,  2.77147591e-01,  1.18252565e-07,
               8.61674853e-05,  3.04525933e-08,  4.35796323e-08,
               1.12666690e-03, -1.07384018e-04,  6.74544775e-04,
               2.38919881e-07,  7.51016778e-05,  1.03490439e-03,
              -1.02116249e-03,  3.19651463e-06, -4.70031170e-09,
               7.23475466e-08,  2.34135601e-04,  2.36684844e-01,
               1.12510158e-07,  8.59741631e-05,  3.09402886e-08,
               3.18646443e-08,  1.11263560e-03, -1.07569700e-04,
               6.70220121e-04,  2.53432802e-07,  7.94269508e-05,
               1.04482437e-03, -1.02828024e-03,  3.01744353e-06,
              -3.40514106e-09,  7.12111543e-08,  2.33534462e-04,
               2.20589072e-01,  1.03097612e-07,  8.62102825e-05,
               2.84990218

In [91]:
map_posterior = delta._predictions['Y_mean'][0]
map_mse = np.mean(np.square(map_posterior - data.test[1]))
map_mse

0.9058541

In [92]:
hmc = BasicHMCExperiment(bnn, data, init_params={'w': delta._params['w_loc']},
                         num_samples=400, num_warmup=300)

In [93]:
hmc.train(random.PRNGKey(0))
hmc.make_predictions(random.PRNGKey(1))

sample: 100%|██████████| 700/700 [00:07<00:00, 97.60it/s, 31 steps of size 1.24e-01. acc. prob=0.93] 



MCMC elapsed time: 8.11474609375


In [94]:
hmc_mean_predictions = hmc._predictions['Y_mean'].mean(axis=0)
hmc_mse = np.mean(np.square(hmc_mean_predictions - data.test[1]))
hmc_mse

0.90449476

In [30]:
hmc._samples['w'][-5:]

DeviceArray([[ 1.3847658e-02,  1.7952403e-02, -2.3003712e-03,
              -2.6926407e-04,  3.7997626e-02,  2.8197473e-01,
               7.5710572e-02, -5.0192405e-03,  3.8544316e-02,
               2.7631648e-02,  3.1657892e-01,  2.5192289e+00,
              -4.8109064e-01],
             [ 1.0599532e-01,  7.1835779e-02,  1.9686129e-02,
               1.4376019e-02,  4.7238357e-02,  1.1606112e-01,
              -4.3411851e-03,  1.6112406e-02,  3.1297136e-02,
               2.3165448e-03,  3.2873976e-01,  2.9378760e+00,
              -5.7730156e-01],
             [ 1.2263170e-01,  7.9705320e-02,  1.5255280e-03,
               5.2526459e-02,  1.1021169e-02,  7.7841409e-02,
               6.3210919e-02,  1.0040593e-02,  1.3078890e-03,
               5.6479238e-03,  3.5915419e-01,  2.9328864e+00,
              -6.5556222e-01],
             [ 1.5390995e-01,  6.7853101e-02,  6.9738016e-02,
               4.1303501e-02,  1.8976253e-02,  5.6318004e-02,
               8.3732806e-02,  3.887910