In [1012]:
import copy
import os
import time
import abc
import copy
import pickle

import jax.tree_util
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from jax import jit, lax, vmap
import jax.numpy as jnp
import jax.nn
import jax.random as random

import numpyro
from numpyro import handlers
from numpyro.distributions import constraints
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, TraceEnum_ELBO, TraceMeanField_ELBO, autoguide
from numpyro.infer.svi import SVIState

import optax

import tqdm
from typing import Any, Callable, Iterable, Optional, Sequence, Type
from typing_extensions import Self
import functools

In [2]:
# %matplotlib inline
# matplotlib.use("nbAgg")  # noqa: E402

plt.rcParams.update({
    "axes.grid": True,      # show grid by default
    "font.weight": "bold",  # bold fonts
    "xtick.labelsize": 15,  # large tick labels
    "ytick.labelsize": 15,  # large tick labels
    "lines.linewidth": 1,   # thick lines
    "lines.color": "k",     # black lines
    # "grid.color": "0.5",    # gray gridlines
    "grid.linestyle": "-",  # solid gridlines
    "grid.linewidth": 0.1,  # thin gridlines
    "savefig.dpi": 300,     # higher resolution output.
})

In [1127]:
DEVICE = "gpu"
numpyro.set_platform(DEVICE)
# numpyro.set_host_device_count(NUM_CHAINS)
D_X = 3
VI_MAX_ITER = 100_000
BNN_SIZE = [128, 256, 64]
BETA = 1.0
# FIG_PREFIX = f"fig"

## Data

In [942]:
class Data(abc.ABC):
    @property
    @abc.abstractmethod
    def train(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def test(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        raise NotImplementedError()

    # @abc.abstractmethod
    def true_predictive(self, X: jnp.ndarray) -> dist.Distribution:
        raise NotImplementedError()

In [943]:
# Create partial view decorator of data
class DataSlice(Data):
    def __init__(self, data: Data, train_idx_slice: slice):
        self._data = data
        self._train_idx_slice: slice = train_idx_slice

    @property
    def train(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        X, Y = self._data.train
        return X[self._train_idx_slice], Y[self._train_idx_slice]

    @property
    def test(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        return self._data.test

    def true_predictive(self, X: jnp.ndarray) -> dist.Distribution:
        return self._data.true_predictive(X)

In [361]:
# Reverse dataset
class ReverseData(Data):
    def __init__(self, data: Data):
        self._data = data

    @property
    def train(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        X, Y = self._data.train
        return X[::-1, ...], Y[::-1, ...]

    @property
    def test(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        return self._data.test

    def true_predictive(self, X: jnp.ndarray) -> dist.Distribution:
        return self._data.true_predictive(X)

In [362]:
class PermutedData(Data):
    def __init__(self, data: Data, perm: np.array):
        self._data = data
        assert perm.shape[0] == data.train[0].shape[0], "wrong len"
        perm_copy = perm.copy()
        perm_copy.sort()
        assert np.all(perm_copy == np.arange(len(perm))), "not a permutation"
        self._perm = perm

    @property
    def train(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        X, Y = self._data.train
        return X[self._perm], Y[self._perm]

    @property
    def test(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        return self._data.test

    def true_predictive(self, X: jnp.ndarray) -> dist.Distribution:
        return self._data.true_predictive(X)

In [619]:
# Define toy regression problem
# create artificial regression dataset
class ToyData1(Data):
    def __init__(self, D_X: int = 3, sigma_obs: float = 0.05, train_size: int = 50, test_size: int = 500):
        self.D_X = D_X
        self.sigma_obs = sigma_obs
        D_Y = 1  # create 1d outputs
        np.random.seed(0)
        X = jnp.concatenate((jnp.linspace(-1, -0.4, train_size // 2),
                             jnp.linspace(0.4, 1, train_size - (train_size // 2))))
        X = jnp.power(X[:, np.newaxis], jnp.arange(D_X))  # XXX ?bias included in model
        W = 0.5 * np.random.randn(D_X)
        # y = w0 + w1*x + w2*x**2 + 1/2 (1/2+x)**2 * sin(4x)
        Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
        Y += sigma_obs * np.random.randn(train_size)
        Y = Y[:, np.newaxis]
        Y -= jnp.mean(Y)
        Y /= jnp.std(Y)

        assert X.shape == (train_size, D_X)
        assert Y.shape == (train_size, D_Y)

        X_test = jnp.linspace(-1.7, 1.7, test_size)
        X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X))

        self._X = X
        self._Y = Y
        self._X_test = X_test
        self._Y_test = None

    @property
    def train(self):
        return (self._X, self._Y)

    @property
    def test(self):
        return (self._X_test, self._Y_test)

    def true_predictive(self, X: jnp.ndarray) -> dist.Distribution:
        raise NotImplementedError()

In [618]:
data = ToyData1(D_X=D_X, train_size=100)

0.26480442


## Model

In [1137]:
class BayesianNeuralNetwork:
    # TODO might be cleaner to make this class immutable: done? maybe weird interactions with deepcopy
    def __init__(self,
                 nonlin: Callable[[jnp.ndarray], jnp.ndarray],
                 D_X: int,
                 D_Y: int,
                 D_H: list[int],
                 biases: bool,
                 obs_model: str | float = "loc_scale",
                 prior_scale: float = 1.0,
                 prior_type: str = "iid",  # or "xavier"
                 ):
        """ :param obs_model: float: precision of Gaussian / "loc_scale": predict both / "inv_gamma": Gamma
            hyper-prior on precision
        """
        self._nonlin = nonlin
        # map scales into R+ using softplus ie log(1+exp(.))
        self._scale_nonlin = lambda xs: jax.nn.softplus(xs) + 1e-1  # Add eps so lik doesn't vanish
        self.D_X = D_X
        self.D_Y = D_Y
        self.D_H = D_H
        self._biases = biases
        if obs_model == "loc_scale":
            if self.D_Y > 1:
                raise NotImplementedError("Should predict a cov matrix... not impl yet")
            self.OBS_MODEL = "loc_scale"
            self.D_Y += 1
            assert self.D_Y == 2
        elif obs_model == "inv_gamma":
            self.OBS_MODEL = "inv_gamma"
            self._prior_prec_obs = dist.Gamma(3.0, 1.0)
        elif isinstance(obs_model, float):
            self.OBS_MODEL = "const_prec"
            # Abstract const parameter into dist; mask according to convention below, see guides
            self._prior_prec_obs = dist.Delta(obs_model).mask(False)
        # add trainable numpyro.param too?
        else:
            raise ValueError(obs_model)

        assert prior_type in ("iid", "xavier")
        prior_scales = self._scale_init(prior_scale, prior_type)
        # Initialise priors to independent standard normals
        self._prior_w = dist.Normal(jnp.zeros(self.get_weight_dim()), prior_scales).to_event(1)
        # self._prior_w = dist.MultivariateNormal(jnp.zeros(self.get_weight_dim()),
        #                                         jnp.diag(jnp.full((self.get_weight_dim(),), prior_scale)))


    def get_weight_dim(self) -> int:
        if self._biases:
            dim = 0
            prev = self.D_X
            for width in self.D_H:
                dim += prev * width + width
                prev = width
            dim += prev * self.D_Y + self.D_Y
            return dim
        else:
            dim = 0
            prev = self.D_X
            for width in self.D_H:
                dim += prev * width
                prev = width
            dim += prev * self.D_Y
            return dim

    def _wi_from_flat(self, a: jnp.ndarray, depth: int, bias: bool = False) -> jnp.ndarray:
        # set bias to return bias of that layer
        assert a.shape[0] == self.get_weight_dim()
        assert 0 <= depth <= len(self.D_H)
        if bias:
            assert self._biases
        prev = self.D_X
        idx = 0
        layer = 0
        for width in self.D_H:
            if depth == layer:
                if not bias:
                    return a[idx:(idx + prev * width)].reshape((prev, width))
                else:
                    idx += prev * width
                    return a[idx:(idx+width)]#.reshape((width, 1))
            idx += prev * width
            if self._biases:
                idx += width
            layer += 1
            prev = width
        assert depth == layer == len(self.D_H)
        if not bias:
            return a[idx:(idx + prev * self.D_Y)].reshape((prev, self.D_Y))
        else:
            idx += prev * self.D_Y
            return a[idx:(idx + self.D_Y)]#.reshape((self.D_Y, 1))

    def _scale_init(self, prior_scale, prior_type: str) -> jnp.array:
        res = np.full((self.get_weight_dim(),), prior_scale, dtype=float)
        if prior_type == "iid":
            return res
        assert prior_type == "xavier"
        idx = np.arange(self.get_weight_dim())
        for depth, width in enumerate([self.D_X] + self.D_H):
            res[self._wi_from_flat(idx, depth)] /= float(width)
            if self._biases:
                res[self._wi_from_flat(idx, depth, bias=True)] /= float(width)
        return res

    #noinspection PyPep8Naming
    def __call__(self, X: jnp.ndarray, Y: Optional[jnp.ndarray] = None):
        N, D_X = X.shape
        assert D_X == self.D_X

        # sample weights from prior
        with handlers.scale(scale=BETA):
            w = numpyro.sample("w", self._prior_w)

        pre_activ = jnp.matmul(X, self._wi_from_flat(w, depth=0))
        if self._biases:
            pre_activ += self._wi_from_flat(w, depth=0, bias=True)
        for depth in range(1, 1+len(self.D_H)):
            pre_activ = jnp.matmul(self._nonlin(pre_activ), self._wi_from_flat(w, depth))
            if self._biases:
                pre_activ += self._wi_from_flat(w, depth, bias=True)

        if self.OBS_MODEL == "loc_scale":
            assert pre_activ.shape[-1] == 2
            Y_mean = numpyro.deterministic("Y_mean", pre_activ[..., [0]])
            if Y is not None:
                assert Y_mean.shape == Y.shape
            Y_scale = numpyro.deterministic("Y_scale", self._scale_nonlin(pre_activ[..., [1]]))
            # observe data
            with numpyro.plate("data", N):
                numpyro.sample("Y", dist.Normal(Y_mean, Y_scale).to_event(1), obs=Y)

        else:
            assert hasattr(self, "_prior_prec_obs")
            # we put a prior on the observation noise
            prec_obs = numpyro.sample("prec_obs", self._prior_prec_obs)
            sigma_obs = numpyro.deterministic("sigma_obs", 1.0 / jnp.sqrt(prec_obs))

            Y_mean = numpyro.deterministic("Y_mean", pre_activ)
            if Y is not None:
                assert Y_mean.shape == Y.shape

            # observe data
            with numpyro.plate("data", N):
                numpyro.sample("Y", dist.Normal(Y_mean, jnp.full((N, self.D_Y), sigma_obs)).to_event(1), obs=Y)

    @property
    def prior(self) -> tuple[dist.Distribution, Optional[dist.Distribution]]:
        """ :returns prior on w and (prec_obs if exists) """
        return self._prior_w,  self._prior_prec_obs if hasattr(self, "_prior_prec_obs")  else None

    def with_prior(self, prior_w: dist.Distribution, prior_prec_obs: Optional[dist.Distribution] = None) -> Self:
        cpy = copy.deepcopy(self)
        cpy._prior_w = prior_w
        cpy._prior_prec_obs = prior_prec_obs
        return cpy

In [1151]:
bnn = BayesianNeuralNetwork(
    nonlin=jax.nn.silu,
    D_X=D_X,
    D_Y=1,
    D_H=BNN_SIZE,
    biases=True,
    prior_scale=10,
    prior_type='xavier',
    obs_model='loc_scale',
    # obs_model=1 / (0.05 / 0.26480442)**2,
)


In [1152]:
bnn.get_weight_dim()

50114

In [1153]:
fresh_bnn = copy.deepcopy(bnn)

## Experiment

In [948]:
class Experiment(abc.ABC):
    def __init__(self, bnn: BayesianNeuralNetwork, data: Data):
        self._bnn: BayesianNeuralNetwork = bnn
        self._data: Data = data
        # Initialise state
        self._predictions: Optional[dict] = None  # numpyro trace on data.test predictive
        # self._predictions: Optional[jnp.ndarray] = None  # of shape (num_samples, X_test.shape[0])

    @abc.abstractmethod
    def train(self, rng_key_train: random.PRNGKey):
        pass

    @abc.abstractmethod
    def make_predictions(self, rng_key_predict: random.PRNGKey):
        pass

    def make_plots(self, fig=None, ax=None, **kwargs) -> plt.Figure:
        assert self._predictions is not None
        X, Y = self._data.train
        X_test, _ = self._data.test
        # compute mean prediction and confidence interval around median
        Y_mean_pred, Y_pred = self._predictions["Y_mean"][..., 0], self._predictions["Y"][..., 0]
        mean_means = jnp.mean(Y_mean_pred, axis=0)
        mean_percentiles = np.percentile(Y_mean_pred, [5.0, 95.0], axis=0)
        Y_percentiles = np.percentile(Y_pred, [5.0, 95.0], axis=0)
        # plotting
        if fig is None or ax is None:
            fig, ax = plt.subplots(figsize=(8, 6))
        # plot training data
        ax.plot(X[:, 1], Y[:, 0], "kx")
        # plot predictions & quantiles
        ax.plot(X_test[:, 1], mean_means, color="blue")
        ax.fill_between(X_test[:, 1], *mean_percentiles, color="orange", alpha=0.5, label="90% CI on mean")
        ax.fill_between(X_test[:, 1], *Y_percentiles, color="lightgreen", alpha=0.5, label="90% prediction")
        return fig

    def run(self, rng_key: random.PRNGKey):
        rng_key_train, rng_key_predict = random.split(rng_key)
        self.train(rng_key_train)
        self.make_predictions(rng_key_predict)
        fig = self.make_plots()
        return fig

In [949]:
class SequentialExperimentBlock(Experiment):
    @property
    @abc.abstractmethod
    def posterior(self) -> tuple[dist.Distribution, dist.Distribution]:
        """ Returns distribution on w and prec_obs """
        raise NotImplementedError()

### HMC

In [919]:
class BasicHMCExperiment(Experiment):
    def __init__(self, bnn: BayesianNeuralNetwork, data: Data, num_samples: int = 2_000,
                 num_warmup: int = 1_000, num_chains: int = 1, group_by_chain: bool = False):
        super().__init__(bnn, data)
        self._num_samples = num_samples
        self._num_warmup = num_warmup
        self._num_chains = num_chains
        self._group_by_chain = group_by_chain
        # Initialise state
        self._samples: Optional[dict] = None

    def train(self, rng_key_train: random.PRNGKey):
        start = time.time()
        X, Y = self._data.train
        kernel = NUTS(self._bnn)
        mcmc = MCMC(
            kernel,
            num_warmup=self._num_warmup,
            num_samples=self._num_samples,
            num_chains=self._num_chains,
            chain_method="vectorized",
            progress_bar=False if DEVICE == "gpu" or "NUMPYRO_SPHINXBUILD" in os.environ else True,
        )
        mcmc.run(rng_key_train, X, Y)
        # mcmc.print_summary()
        print("\nMCMC elapsed time:", time.time() - start)
        self._samples = mcmc.get_samples(group_by_chain=self._group_by_chain)

    def make_predictions(self, rng_key_predict: random.PRNGKey):
        assert self._samples is not None
        X_test, _ = self._data.test
        if not self._group_by_chain:
            self._predictions = Predictive(self._bnn, self._samples, return_sites=['w', 'Y_mean', 'Y_std', 'Y'])(rng_key_predict, X=X_test, Y=None)#['Y'][..., 0]
        else:
            def pred(rng_key, samples):
                return Predictive(self._bnn, samples)(rng_key, X=X_test, Y=None)
            self._predictions = vmap(pred)(random.split(rng_key_predict, self._num_chains), self._samples)

In [1123]:
experiment = BasicHMCExperiment(
    bnn,
    data,
    num_samples = 60, num_warmup = 30,
    # num_chains = 4, group_by_chain=True
)
experiment.run(random.PRNGKey(0))#.savefig("figs/simple_hmc_4.png")

KeyboardInterrupt: 

In [323]:
# fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,10), sharex=True, sharey=True)
# for i, ax in enumerate(axs.ravel()):
#     ax.plot(data.test[0][:, 1], experiment._predictions["Y_mean"][..., 0][i].mean(axis=0))
#     ax.fill_between(data.test[0][:, 1], *np.percentile(experiment._predictions["Y_mean"][..., 0][i], (5.0, 95.0), axis=0), alpha=0.5, color="orange")
#     ax.fill_between(data.test[0][:, 1], *np.percentile(experiment._predictions["Y"][..., 0][i], (5.0, 95.0), axis=0), alpha=0.5, color="lightgreen")
#     ax.plot(data.train[0][:, 1], data.train[1], "kx")
#     ax.set_ylim(-6, +6)
# fig.tight_layout()
# fig.savefig("figs/hmc-by-chain.png")

<IPython.core.display.Javascript object>

### VI

In [890]:
class EvalLoss:
    def __init__(self, num_particles=1):
        self.num_particles = num_particles

    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
        def single_particle_elbo(rng_key):
            params = param_map.copy()
            model_seed, guide_seed = random.split(rng_key)
            seeded_model = handlers.seed(model, model_seed)
            seeded_guide = handlers.seed(guide, guide_seed)
            subs_guide = handlers.substitute(seeded_guide, data=param_map)
            guide_trace = handlers.trace(subs_guide).get_trace(*args, **kwargs)
            subs_model = handlers.substitute(handlers.replay(seeded_model, guide_trace), data=params)
            model_trace = handlers.trace(subs_model).get_trace(*args, **kwargs)
            # check_model_guide_match(model_trace, guide_trace)
            # _validate_model(model_trace, plate_warning="loose")
            # _check_mean_field_requirement(model_trace, guide_trace)

            elbo_lik = 0
            elbo_kl = 0
            for name, model_site in model_trace.items():
                if model_site["type"] == "sample":
                    if model_site["is_observed"]:
                        elbo_lik = elbo_lik + numpyro.infer.elbo._get_log_prob_sum(model_site)
                    else:
                        guide_site = guide_trace[name]
                        try:
                            kl_qp = dist.kl.kl_divergence(guide_site["fn"], model_site["fn"])
                            kl_qp = dist.util.scale_and_mask(kl_qp, scale=guide_site["scale"])
                            elbo_kl = elbo_kl + jnp.sum(kl_qp)
                            # elbo_lik = elbo_lik - jnp.sum(kl_qp)
                        except NotImplementedError:
                            raise NotImplementedError()
                        #     elbo_particle = (
                        #             elbo_particle
                        #             + numpyro.infer.elbo._get_log_prob_sum(model_site)
                        #             - numpyro.infer.elbo._get_log_prob_sum(guide_site)
                        #     )

            # handle auxiliary sites in the guide
            for name, site in guide_trace.items():
                if site["type"] == "sample" and name not in model_trace:
                    assert site["infer"].get("is_auxiliary") or site["is_observed"]
                    elbo_lik = elbo_lik - numpyro.infer.elbo._get_log_prob_sum(site)

            return elbo_lik, elbo_kl

        if self.num_particles == 1:
            elbo_lik, elbo_kl = single_particle_elbo(rng_key)
            return {"elbo_lik": elbo_lik, "elbo_kl": elbo_kl}
        else:
            rng_keys = random.split(rng_key, self.num_particles)
            elbo_liks, elbo_kls = vmap(single_particle_elbo)(rng_keys)
            assert jnp.all(elbo_kls == elbo_kls[0])
            return {
                "elbo_lik": jnp.mean(elbo_liks),
                "elbo_kl": elbo_kls[0],
                "loss": -jnp.mean(elbo_liks) + elbo_kls[0]
            }

In [1000]:
class BasicVIExperiment(SequentialExperimentBlock):
    def __init__(self, bnn: BayesianNeuralNetwork, data: Data, num_samples: int = 2_000,
                 max_iter: int = 150_000):
        super().__init__(bnn, data)
        self._num_samples = num_samples
        self._max_iter = max_iter
        # Initialise state
        self._svi: Optional[SVI] = None
        self._guide: Optional[Callable] = None
        self._saved_svi_state: Optional[SVIState] = None
        self._losses: jnp.array = jnp.array([])
        self._eval_losses: jnp.array = jnp.array([]).reshape((0,3))
        self._params: Optional[dict] = None

    def train(self, rng_key_train: random.PRNGKey, num_iter: Optional[int] = None):
        if num_iter is None:
            num_iter = self._max_iter

        start = time.time()
        X, Y = self._data.train

        if self._svi is None:
            self._guide = self._get_guide()
            # Custom optimizer to prevent effect of exploding gradients (by tail ELBO estimates)
            # Taken from phuijse.github.io/BLNNbook
            lr_schedule = optax.constant_schedule(-0.0005)
            # lr_schedule = optax.polynomial_schedule(
            #     init_value=-0.01, end_value=-0.00001, power=1, transition_steps=5*VI_MAX_ITER)
            clipped_adam = optax.chain(optax.clip_by_global_norm(10.0),
                                       optax.scale_by_adam(),
                                       optax.scale_by_schedule(lr_schedule))
            optimizer = clipped_adam  # Default taken from ashleve/lightning-hydra-template
            train_loss = TraceMeanField_ELBO(num_particles=16)
            self._svi = SVI(self._bnn, self._guide, optimizer, train_loss)
        eval_loss = EvalLoss(num_particles=64)
        rng_key_train, rng_key_eval, rng_key_init_loss = random.split(rng_key_train, 3)

        if self._saved_svi_state is None:
            self._saved_svi_state = self._svi.init(rng_key_train, X=X, Y=Y)

        def body_fn(svi_state, _):
            svi_state, loss = self._svi.stable_update(svi_state, X=X, Y=Y)
            return svi_state, loss

        init_eval_loss = eval_loss.loss(
            rng_key_init_loss, self._svi.get_params(self._saved_svi_state), self._bnn, self._guide, X=X, Y=Y)
        print("Initial eval loss: {:.4f} (lik: {:.4f}, kl: {:.4f})".format(
            init_eval_loss["loss"], init_eval_loss["elbo_lik"], init_eval_loss["elbo_kl"]))

        batch = max(num_iter // 50, 1)
        with tqdm.trange(1, num_iter // batch + 1) as t:
            for i in t:
                self._saved_svi_state, batch_losses = lax.scan(body_fn, self._saved_svi_state, None, length=batch)
                self._losses = jnp.concatenate((self._losses, batch_losses))
                valid_losses = [x for x in batch_losses if x == x]
                num_valid = len(valid_losses)
                if num_valid == 0:
                    avg_loss = float("nan")
                else:
                    avg_loss = sum(valid_losses) / num_valid
                # Compute full loss
                rng_key_eval, rng_key_eval_curr = random.split(rng_key_eval)
                eval_loss_res = eval_loss.loss(rng_key_eval_curr, self._svi.get_params(self._saved_svi_state), self._bnn, self._guide, X=X, Y=Y)
                self._eval_losses = jnp.append(
                    self._eval_losses, jnp.array([[eval_loss_res["loss"], eval_loss_res["elbo_lik"], eval_loss_res["elbo_kl"]]]),
                    axis=0
                )
                t.set_postfix_str(
                    "init loss: {:.4f}, avg. train loss / eval. loss [{}-{}]: {:.4f} / {:.4f}".format(
                        self._losses[0], (i-1)*batch, i*batch, avg_loss, eval_loss_res["loss"]
                    ),
                    refresh=False,
                )
        self._params = self._svi.get_params(self._saved_svi_state)
        print("\nSVI elapsed time:", time.time() - start)

    def make_predictions(self, rng_key_predict: random.PRNGKey):
        assert self._params is not None and self._guide is not None
        X_test, _ = self._data.test
        predictive = Predictive(model=self._bnn, guide=self._guide,
                                params=self._params, num_samples=self._num_samples)
        self._predictions = predictive(rng_key_predict, X=X_test, Y=None)#['Y'][..., 0]

    def show_convergence_plot(self, fig=None, ax=None) -> plt.Figure:
        if fig is None or ax is None:
            fig, ax = plt.subplots()
        ax.plot(self._eval_losses[:, 0], label="loss")
        ax.plot(self._eval_losses[:, 1], label="lik")
        ax.plot(-self._eval_losses[:, 2], label="-kl")
        return fig

    @property
    @abc.abstractmethod
    def posterior(self) -> tuple[dist.Distribution, dist.Distribution]:
        """ :returns distribution of w and (prec_obs if in the model)
            Note if prec_obs has a Delta distribution, it should be marked as masked so that
            hack with keeping it constant under another Delta approximation doesn't blow up loss
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _get_guide(self) -> Callable[[jnp.ndarray, Optional[jnp.ndarray]], Any]:
        # This needs to enforce that if self._bnn's prior on prec_obs is masked then
        # in the guide, prec_obs is treated as a constant and not as a numpy.param + Delta
        # so that gradients exist and loss is not inf
        raise NotImplementedError()

#### Mean-Field

In [1001]:
class BasicMeanFieldGaussianVIExperiment(BasicVIExperiment):
    def _get_guide(self) -> Callable[[jnp.ndarray, Optional[jnp.ndarray]], Any]:
        bnn_weight_dim = self._bnn.get_weight_dim()
        def guide(X, Y=None):
            w_loc = numpyro.param("w_loc", lambda rng_key: dist.Normal(scale=0.25).sample(rng_key, (bnn_weight_dim,)))
            w_scale = numpyro.param("w_scale", jnp.full((bnn_weight_dim,), 1e-5), constraint=constraints.softplus_positive)
            with handlers.scale(scale=BETA):
                numpyro.sample("w", dist.Normal(w_loc, w_scale).to_event(1))
            _, prec_obs_prior = self._bnn.prior
            if prec_obs_prior is not None:
                # See comment above for initialising prec_obs to its point mass as it is masked!
                # Taking the prior mean returns the delta mass location in the Delta case
                prec_obs_loc = numpyro.param("prec_obs_loc", prec_obs_prior.mean, constraint=constraints.positive)
                prec_obs_dist = dist.Delta(prec_obs_loc)
                if isinstance(prec_obs_prior, dist.MaskedDistribution):
                    # Treat prec_obs as constant here, decouple from parameter completely,
                    # otherwise it would give MAP on uniform improper prior
                    del prec_obs_dist  # Lose dependence on "prec_obs_loc" numpyro.param
                    prec_obs_dist = dist.Delta(prec_obs_prior.mean)
                numpyro.sample("prec_obs", prec_obs_dist)
        return guide

    @property
    def posterior(self) -> tuple[dist.Distribution, Optional[dist.Distribution]]:
        assert self._params is not None
        w_posterior = dist.Normal(loc=self._params["w_loc"], scale=self._params["w_scale"]).to_event(1)
        # Note for further VI it is a problem that support(prec_obs) is a single point,
        # therefore we mask this distribution so KL computation is ignored, and make sure to
        # initialise the delta guide to this point!
        prec_obs_posterior = dist.Delta(self._params["prec_obs_loc"]).mask(False) \
            if "prec_obs_loc" in self._params.keys() else None
        return w_posterior, prec_obs_posterior

In [1002]:
class AutoMeanFieldNormalVIExperiment(BasicVIExperiment):
    def _get_guide(self)-> Callable[[jnp.ndarray, Optional[jnp.ndarray]], Any]:
        return autoguide.AutoNormal(self._bnn, init_loc_fn=numpyro.infer.init_to_sample, init_scale=1e-5)

    @property
    def posterior(self) -> tuple[dist.Distribution, dist.Distribution]:
        raise NotImplementedError()

In [1003]:
class AutoDeltaVIExperiment(BasicVIExperiment):
    def _get_guide(self) -> Callable[[jnp.ndarray, Optional[jnp.ndarray]], Any]:
        return autoguide.AutoDelta(self._bnn, init_loc_fn=numpyro.infer.init_to_sample)

    @property
    def posterior(self) -> tuple[dist.Distribution, dist.Distribution]:
        raise NotImplementedError()

#### Full-rank

In [1004]:
class BasicFullRankGaussianVIExperiment(BasicVIExperiment):
    def _get_guide(self) -> Callable[[jnp.ndarray, Optional[jnp.ndarray]], Any]:
        bnn_weight_dim = self._bnn.get_weight_dim()
        def guide(X, Y=None):
            w_loc = numpyro.param("w_loc", lambda rng_key: dist.Normal().sample(rng_key, (bnn_weight_dim,)))
            w_cov = numpyro.param("w_cov", 0.1*jnp.eye(bnn_weight_dim), constraint=constraints.positive_definite)
            with handlers.scale(scale=BETA):
                numpyro.sample("w", dist.MultivariateNormal(w_loc, w_cov))
            _, prec_obs_prior = self._bnn.prior
            if prec_obs_prior is not None:
                # See comment above for initialising prec_obs to its point mass as it is masked!
                # Taking the prior mean returns the delta mass location in the Delta case
                prec_obs_loc = numpyro.param("prec_obs_loc", prec_obs_prior.mean, constraint=constraints.positive)
                prec_obs_dist = dist.Delta(prec_obs_loc)
                if isinstance(prec_obs_prior, dist.MaskedDistribution):
                    # Treat prec_obs as constant here, decouple from parameter completely,
                    # otherwise it would give MAP on uniform improper prior
                    del prec_obs_dist  # Lose dependence on "prec_obs_loc" numpyro.param
                    prec_obs_dist = dist.Delta(prec_obs_prior.mean)
                numpyro.sample("prec_obs", prec_obs_dist)
        return guide

    @property
    def posterior(self) -> tuple[dist.Distribution, Optional[dist.Distribution]]:
        assert self._params is not None
        w_posterior = dist.MultivariateNormal(loc=self._params["w_loc"],
                                              covariance_matrix=self._params["w_cov"])
        # Note for further VI it is a problem that support(prec_obs) is a single point,
        # therefore we mask this distribution so KL computation is ignored, and make sure to
        # initialise the delta guide to this point!
        prec_obs_posterior = dist.Delta(self._params["prec_obs_loc"]).mask(False) \
            if "prec_obs_loc" in self._params.keys() else None
        return w_posterior, prec_obs_posterior

### Laplace

In [1083]:
class AutoLaplaceExperiment(BasicVIExperiment):
    def __init__(self, bnn: BayesianNeuralNetwork, data: Data, diag: bool = True, shrink: float = 25.0,
                 num_samples: int = 2_000, max_iter: int = 150_000):
        super().__init__(bnn, data, num_samples, max_iter)
        self._diag = diag
        self._shrink = shrink

    def _get_guide(self) -> Callable[[jnp.ndarray, Optional[jnp.ndarray]], Any]:
        if not self._diag:
            hessian_fn = lambda f, x: jax.hessian(f)(x) + jnp.eye(x.shape[-1]) * self._shrink
        else:
            hessian_fn = lambda f, x: jnp.diag(jnp.diag(jax.hessian(f)(x))) + jnp.eye(x.shape[-1]) * self._shrink

        self._guide = autoguide.AutoLaplaceApproximation(
            self._bnn, init_loc_fn=functools.partial(numpyro.infer.init_to_uniform, radius=1.2),
            hessian_fn=hessian_fn
        )
        return self._guide

    @property
    def posterior(self) -> tuple[dist.Distribution, Optional[dist.Distribution]]:
        return self._guide.get_posterior(self._params), None

    def make_predictions(self, rng_key_predict: random.PRNGKey):
        assert self._params is not None and self._guide is not None
        X_test, _ = self._data.test
        posterior = self._guide.get_posterior(self._params)
        samples = posterior.sample(rng_key_predict, sample_shape=(self._num_samples,))
        predictive = Predictive(model=self._bnn, posterior_samples={'w': samples})
        self._predictions = predictive(rng_key_predict, X=X_test, Y=None)#['Y'][..., 0]


In [1157]:
experiment = AutoLaplaceExperiment(
    fresh_bnn,
    data,
    num_samples=10_000,
    max_iter=VI_MAX_ITER,
    diag=True,
    shrink=100.0,
)
rk = random.PRNGKey(0)
# experiment.run(random.PRNGKey(1))
# experiment.show_convergence_plot()
# experiment.run(random.PRNGKey(1)).savefig("figs/simple_mfvi_4.png")
# experiment.show_convergence_plot().savefig("figs/simple_mfvi_conv_4.png")

In [1158]:
assert False
for i in tqdm.tqdm(range(15)):
    rk, rkc = random.split(rk)
    experiment.train(rkc)
    with open("svi-state.pkl", "wb") as f:
        pickle.dump(experiment._saved_svi_state, f)
    rk, rkc = random.split(rk)
    experiment.make_predictions(rkc)
    experiment.make_plots().savefig(f"figs/post-pred-{i}.png")
    fig, ax = plt.subplots()
    # ax.plot(experiment._losses[::(VI_MAX_ITER//50)])
    ax.plot(experiment._eval_losses)
    ax.set_ylim(experiment._eval_losses.min(), jnp.percentile(experiment._eval_losses, 90.0))
    fig.savefig(f"figs/loss-{i}.png")

AssertionError: 

In [1159]:
# with open("/tmp/svi-state.pkl", "rb") as f:
#     print(pickle.load(f))

In [1161]:
# go again:
rk, rkc = random.split(rk)
experiment.train(rkc, num_iter=10_000)

Initial eval loss: 3764533.7500 (lik: -612.6980, kl: 3763921.0000)


100%|██████████| 50/50 [00:38<00:00,  1.31it/s, init loss: 3764533.7500, avg. train loss / eval. loss [9800-10000]: -90956.2188 / -90956.4297]


SVI elapsed time: 41.19251585006714





In [1162]:
rk, rkc = random.split(rk)
experiment.make_predictions(rkc)
experiment.make_plots().show()

KeyboardInterrupt: 

In [964]:
fig, ax = plt.subplots()
ax.plot(-experiment._eval_losses[:, 1], label="-loglik")
ax.plot(experiment._eval_losses[:, 2], label="+kl")
ax.legend()
fig.show()

<IPython.core.display.Javascript object>

In [612]:
fig, ax = plt.subplots()
ax.plot(experiment._losses[1000::10*(VI_MAX_ITER // VI_MAX_ITER)])
# ax.plot(experiment._eval_losses)
ax.set_ylim(0, 1000)
fig.show()

<IPython.core.display.Javascript object>

### Loglikelihood analysis

In [193]:
def sample_posterior(rng_key):
    fitted_guide = handlers.substitute(handlers.seed(experiment._get_guide(), rng_key), experiment._params)
    trace = handlers.trace(fitted_guide).get_trace(X=experiment._data.train[0], Y=None)
    # print(numpyro.util.format_shapes(trace))
    return dict(w=trace["w"]["value"])

In [224]:
n = 6
post_fns = vmap(sample_posterior)(random.split(random.PRNGKey(0), n))

In [229]:
def loglik(rng_key, params):
    rng_lik, rng_gen = random.split(rng_key)
    model = handlers.substitute(handlers.seed(experiment._bnn, rng_lik), params)
    trace = handlers.trace(model).get_trace(X=experiment._data.train[0], Y=experiment._data.train[1])
    # print(numpyro.util.format_shapes(trace, compute_log_prob=True))
    w_node = trace["w"]
    y_node = trace["Y"]

    model = handlers.substitute(handlers.seed(experiment._bnn, rng_gen), params)
    trace =  handlers.trace(model).get_trace(X=experiment._data.test[0], Y=None)
    return dict(prior_logprob=w_node['fn'].log_prob(w_node['value']),
                loglik=y_node['fn'].log_prob(y_node['value']).sum(),
                Y_mean=trace["Y_mean"]["value"].squeeze(),
                Y_scale=trace["Y_scale"]["value"].squeeze())

In [230]:
ll = vmap(loglik)(random.split(random.PRNGKey(1), n), post_fns)

In [237]:
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(9, 6))
for i, ax in enumerate(axs.ravel()):
    ax.plot(data.train[0][:, 1], data.train[1][:, 0], 'kx', alpha=0.2)
    ax.plot(data.test[0][:, 1], ll['Y_mean'][i])
    ax.fill_between(data.test[0][:, 1], ll['Y_mean'][i] - ll['Y_scale'][i] * 2,
                    ll['Y_mean'][i] + ll['Y_scale'][i] * 2, alpha=0.2)
    ax.set_title(f"loglik={ll['loglik'][i]:.3f}\nlogprior={ll['prior_logprob'][i]:.3f}")
fig.tight_layout()
fig.show()

<IPython.core.display.Javascript object>

In [None]:
# fig, ax = plt.subplots()
# for i in range(10):
#     ax.plot(fresh_data.test[0][:, 1], preds[i])
# plt.show()

In [None]:
# fig.savefig("figs/posterior_predictive_1.png")

### Sequential experiment

In [None]:
class SequentialExperiment(SequentialExperimentBlock):
    def __init__(self, bnn: BayesianNeuralNetwork, data: Data, Block: Type[SequentialExperimentBlock],
                 num_inference_steps: int = 2, **block_kwargs):
        """ :param num_inference_steps: split data into this many chunks, and
                                        do Bayesian inference sequentially on them
        """
        super().__init__(bnn, data)
        self._num_inference_steps = num_inference_steps
        self._Block: Type[SequentialExperimentBlock] = Block
        self._block_kwargs: dict = block_kwargs
        # Initialise state
        self._experiment_blocks: list[Block] = list()

    @property
    def posterior(self) -> tuple[dist.Distribution, Optional[dist.Distribution]]:
        assert len(self._experiment_blocks) > 0
        return self._experiment_blocks[-1].posterior

    def train(self, rng_key_train: random.PRNGKey):
        train_len = self._data.train[0].shape[0]
        rng_key_array_train: random.PRNGKeyArray = random.split(rng_key_train, num=self._num_inference_steps)
        for step_idx, rng_key_train_step in enumerate(rng_key_array_train):
            chunk = slice(step_idx*(train_len//self._num_inference_steps),
                          min(train_len, (step_idx+1)*(train_len//self._num_inference_steps)))
            data_view = DataSlice(self._data, chunk)
            experiment_block = self._Block(self._bnn, data_view, **self._block_kwargs)
            experiment_block.train(rng_key_train_step)
            self._bnn = self._bnn.with_prior(*experiment_block.posterior)
            self._experiment_blocks.append(experiment_block)

    def make_predictions(self, rng_key_predict: random.PRNGKey, final_only: bool = True):
        # Delegate to final experiment block
        assert len(self._experiment_blocks) > 0
        if final_only:
            self._experiment_blocks[-1].make_predictions(rng_key_predict)
        else:
            rng_key_array: random.PRNGKeyArray = random.split(rng_key_predict, len(self._experiment_blocks))
            for experiment_block, rng_key in zip(self._experiment_blocks, rng_key_array):
                experiment_block.make_predictions(rng_key)

    def make_plots(self, final_only: bool = True, **kwargs) -> plt.Figure:
        # fig, ax = plt.subplots(nrows=len(self._experiment_blocks))
        assert len(self._experiment_blocks) > 0
        if final_only:
            return self._experiment_blocks[-1].make_plots()
        else:
            for experiment_block in self._experiment_blocks:
                experiment_block.make_plots()
        return None

In [None]:
class ExperimentWithLastBlockReplaced(Experiment):
    def __init__(self, sequential_experiment: SequentialExperiment, LastBlock: Type[Experiment], **kwargs):
        super().__init__(sequential_experiment._bnn, sequential_experiment._data)
        self._LastBlock: Type[Experiment] = LastBlock
        self._sequential_experiment: SequentialExperiment = sequential_experiment
        self._kwargs = kwargs
        self._last_block: Optional[Experiment] = None

    def train(self, rng_key_train: random.PRNGKey):
        rng_seq, rng_hmc = random.split(rng_key_train)
        self._sequential_experiment.train(rng_seq)
        last_seq_block = self._sequential_experiment._experiment_blocks[-1]
        last_block = self._LastBlock(bnn=last_seq_block._bnn, data=last_seq_block._data, **self._kwargs)
        last_block.train(rng_hmc)
        self._last_block = last_block
        # self._sequential_experiment._experiment_blocks[-1] = last_block

    def make_predictions(self, rng_key_predict: random.PRNGKey, **kwargs):
        rng_seq, rng_hmc = random.split(rng_key_predict)
        self._sequential_experiment.make_predictions(rng_seq, **kwargs)
        self._last_block.make_predictions(rng_hmc)

    def make_plots(self, final_only: bool = True, **kwargs) -> plt.Figure:
        self._sequential_experiment.make_plots(final_only)
        self._last_block.make_plots(**kwargs)

## Space for running experiments

In [None]:
tr = handlers.trace(handlers.seed(bnn, random.PRNGKey(0))).get_trace(X=data.train[0], Y=None)
print(numpyro.util.format_shapes(tr, compute_log_prob=True))

In [None]:
# fresh_bnn = copy.deepcopy(bnn)
fresh_data = copy.deepcopy(data)
first_half_data = DataSlice(fresh_data, slice(50))
second_half_data = DataSlice(fresh_data, slice(50,100))

In [None]:
first_half_experiment = BasicMeanFieldGaussianVIExperiment(fresh_bnn, first_half_data, num_samples=1000, max_iter=VI_MAX_ITER)
first_half_experiment.run(random.PRNGKey(0)).savefig("figs/manual_first_half_VI_4.png")
first_half_experiment.show_convergence_plot().savefig("figs/manual_first_half_VI_conv_4.png")

In [None]:
# rk = random.PRNGKey(0)
# # For consistency with exp
# rk, _ = random.split(rk)

In [None]:
# first_half_experiment._data.train[0][..., 1].max()

In [None]:
# rk, rkc = random.split(rk)
# first_half_experiment.train(rkc, num_iter=1_000)

In [None]:
# rk, rkc = random.split(rk)
# first_half_experiment.make_predictions(rkc)
# first_half_experiment.make_plots().show()

In [None]:
# plt.plot(first_half_experiment._losses[::(VI_MAX_ITER // 50)])
# plt.plot(first_half_experiment._eval_losses)

In [None]:
second_half_HMC_experiment = BasicHMCExperiment(fresh_bnn.with_prior(*first_half_experiment.posterior), second_half_data, num_samples=1000, num_warmup=500)
second_half_HMC_experiment.run(random.PRNGKey(0)).savefig("figs/manual_second_half_HMC_4.png")

In [None]:
second_half_VI_experiment = BasicMeanFieldGaussianVIExperiment(fresh_bnn.with_prior(*first_half_experiment.posterior), second_half_data, num_samples=1000, max_iter=5*VI_MAX_ITER)
second_half_VI_experiment.run(random.PRNGKey(0)).savefig("figs/manual_second_half_VI_4.png")

In [None]:
# fresh_data = copy.deepcopy(data)

In [None]:
sequential_experiment = SequentialExperiment(fresh_bnn, fresh_data, BasicMeanFieldGaussianVIExperiment, num_inference_steps=2, max_iter = VI_MAX_ITER)

In [None]:
sequential_experiment = ExperimentWithLastBlockReplaced(sequential_experiment, BasicHMCExperiment, num_samples=200, num_warmup=100)

In [None]:
# sequential_experiment.train(random.PRNGKey(0))

In [None]:
# sequential_experiment.make_predictions(random.PRNGKey(2), final_only=False)

In [None]:
# sequential_experiment.make_plots(final_only=False)

In [None]:
# # Custom plotting for sequential experiment
# fig, axs = plt.subplots(figsize=(8, 4), ncols=2)
# for i, ax in enumerate(axs.ravel()):
#     experiment_block = sequential_experiment._sequential_experiment._experiment_blocks[i]
#     predictions = experiment_block._predictions["Y"][..., 0]
#     mean_predictions = experiment_block._predictions["Y_mean"][..., 0]
#     data = experiment_block._data
#     X, Y = data.train
#     X_test, _ = data.test
#     # compute mean prediction and confidence interval around median
#     mean_means = jnp.mean(mean_predictions, axis=0)
#     mean_percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)
#     # plot training data
#     ax.plot(X[:, 1], Y[:, 0], "kx")
#     # plot predictions & quantiles
#     ax.plot(X_test[:, 1], mean_means, color="blue")
#     ax.fill_between(X_test[:, 1], *mean_percentiles, color="lightblue")
#     ax.set_title(str(data._train_idx_slice))
# fig.tight_layout()
# fig.savefig("figs/sequential-VI-simple3.png")

#### Draw samples from prior predictive

In [1156]:
fig, ax =  plt.subplots()
with handlers.seed(rng_seed=random.PRNGKey(1)):
    t = data.test[0]
    for _ in range(50):
        prior_fn = handlers.trace(fresh_bnn).get_trace(X=t, Y=None)["Y_scale"]["value"]
        ax.plot(t[:,1], prior_fn, alpha=0.5)
        ax.set_yscale('log')
# plt.savefig("figs/prior_pred_4.png")

<IPython.core.display.Javascript object>

#### Reverse dataset for two halves sequential experiment

In [None]:
# fresh_bnn = copy.deepcopy(bnn)

In [None]:
reversed_data = ReverseData(data)

In [None]:
reversed_sequential_experiment = SequentialExperiment(fresh_bnn, reversed_data, BasicFullRankGaussianVIExperiment)

In [None]:
# reversed_sequential_experiment.train(random.PRNGKey(0))

In [None]:
# reversed_sequential_experiment.make_predictions(random.PRNGKey(1), final_only=False)

In [None]:
# # Custom plotting for reversed sequential experiment
# fig, axs = plt.subplots(figsize=(8, 4), ncols=2)
# for i, ax in enumerate(axs.ravel()):
#     experiment_block = reversed_sequential_experiment._experiment_blocks[i]
#     predictions = experiment_block._predictions["Y"][..., 0]
#     data = experiment_block._data
#     X, Y = data.train
#     X_test, _ = data.test
#     # compute mean prediction and confidence interval around median
#     mean_predictions = jnp.mean(predictions, axis=0)
#     percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)
#     # plot training data
#     ax.plot(X[:, 1], Y[:, 0], "kx")
#     # plot predictions & quantiles
#     ax.plot(X_test[:, 1], mean_predictions, color="blue")
#     ax.fill_between(X_test[:, 1], *percentiles, color="lightblue")
#     ax.set_title(str(data._train_idx_slice))
# fig.tight_layout()
# # fig.savefig("figs/sequential-full-rank-VI-reversed1.png")

#### Randomized dataset experiment

In [None]:
train_len = data.train[0].shape[0]
random_perm = np.random.choice(np.arange(train_len), size=train_len, replace=False)
permuted_data = PermutedData(data, random_perm)

In [841]:
plt.close('all')