# Global inducing points for BNNs

One central challenge with Bayesian Neural Networks (BNNs) is handling the posterior over their parameters. In Variational Inference (VI), we approximate the exact posterior with a tractable distribution, which can then be used to make predictions. Unfortunately, common choices such as factored approximate posteriors {cite}`blundell2015weight` are typically poor approximations to the true posterior. For example, the usual factored Gaussian approximation tends to underfit and to give poor estimates of the predictive uncertainty {cite}`foong2019between`. Recently, Ober and Aitchison {cite}`ober2021global` introduced an approximate posterior which goes beyond the factored approximation and yields improved results. Their approximation is based on inducing points, an idea which is common in the Gaussian Process (GP) literature. Notably, the same approximate posterior can be used for Deep GPs (DGPs) as well as BNNs, highlighting the similarities between these models.

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_probability as tfp

from tqdm import tqdm

from check_shape import check_shape

tfk = tf.keras
tfd = tfp.distributions

matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'

## Prior model

Suppose we wish to perform regression, learning a map from inputs $\mathbf{X} \in \mathbb{R}^{K \times D_x}$ to the corresponding outputs $\mathbf{Y} \in \mathbb{R}^{K \times D_y}$. Consider a fully connected neural network made up of $L$ hidden layers with $N_1, \dots, N_L$ hidden units respectively, followed by a final linear layer which maps from the $N_L$ units to the output dimension $D_y$. Let the weights of the layers be

$$\begin{align}
\mathbf{W} = \{\mathbf{W}_l \in \mathbb{R}^{N_{l-1} \times N_l} \}_{l = 1}^{L+1},
\end{align}$$

where $N_0 = D_x$ and $N_{L+1} = D_y$. Thus the network takes the form

$$\begin{align}
\mathbf{F}_1 &= \mathbf{X} \mathbf{W}_1, \\
\mathbf{F}_l &= \phi(\mathbf{F}_{l - 1}) \mathbf{W}_l, \qquad l = 2, \dots, L+1,
\end{align}$$

where $\phi$ is an elementwise nonlinearity. Note that in this notation the weights post-multiply the activations rather than pre-multiplying them, so each row of $\mathbf{F}_l$ holds the activations for one input point. Now suppose we place a prior $p(\mathbf{W})$ over the weights, together with a Gaussian likelihood function

$$\begin{align}
p(\mathbf{Y} | \mathbf{W}, \mathbf{X}) = \prod_{k = 1}^K \mathcal{N}(\mathbf{y}_k; \phi(\mathbf{f}_{L, k}) \mathbf{W}_{L+1}, \sigma_n^2 \mathbf{I}),
\end{align}$$

where $\mathbf{y}_k$ and $\mathbf{f}_{L, k}$ are the $k^{th}$ rows of $\mathbf{Y}$ and $\mathbf{F}_L$, so that $\phi(\mathbf{f}_{L, k}) \mathbf{W}_{L+1}$ is the $k^{th}$ row of the network output $\mathbf{F}_{L+1}$. The posterior over the weights of this network has no closed form and must therefore be approximated.
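
To make the notation concrete, here is a minimal NumPy sketch (not part of the notebook's TensorFlow code) of the forward pass $\mathbf{F}_1 = \mathbf{X}\mathbf{W}_1$, $\mathbf{F}_l = \phi(\mathbf{F}_{l-1})\mathbf{W}_l$ and of the Gaussian log-likelihood. It assumes a ReLU nonlinearity, no biases and illustrative layer sizes; `forward` and `gaussian_log_lik` are hypothetical helper names.

```python
import numpy as np

def forward(X, weights, phi=lambda a: np.maximum(a, 0.0)):
    """Return F_{L+1}, using F_1 = X W_1 and F_l = phi(F_{l-1}) W_l (no biases)."""
    F = X @ weights[0]
    for W in weights[1:]:
        F = phi(F) @ W
    return F

def gaussian_log_lik(Y, F_out, noise_std):
    """Sum over rows k of log N(y_k; f_{L+1,k}, noise_std^2 I)."""
    K, D_y = Y.shape
    sq_err = np.sum((Y - F_out) ** 2)
    return -0.5 * K * D_y * np.log(2.0 * np.pi * noise_std**2) - 0.5 * sq_err / noise_std**2

rng = np.random.default_rng(0)
K, D_x, hidden, D_y = 8, 3, [5, 4], 2
sizes = [D_x] + hidden + [D_y]  # N_0 = D_x, N_1, ..., N_L, N_{L+1} = D_y
# W_l has shape (N_{l-1}, N_l); the 1/sqrt(N_{l-1}) scaling is an illustrative choice
weights = [rng.standard_normal((n_in, n_out)) / np.sqrt(n_in)
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]

X = rng.standard_normal((K, D_x))
F_out = forward(X, weights)
Y = F_out + 0.1 * rng.standard_normal(F_out.shape)
print(F_out.shape, gaussian_log_lik(Y, F_out, noise_std=0.1))
```

Each row of `X` flows through the weights independently, which is why the likelihood factorises over rows $k$.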

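However, conditioned on the weights of all previous layers, the posterior over the weights of the last layer is analytic. Writing $\mathbf{w}_{L+1, d}$ for the $d^{th}$ column of $\mathbf{W}_{L+1}$, we have

$$\begin{align}
p(\mathbf{w}_{L+1, d} | \mathbf{W}_{1:L}, \mathbf{X}, \mathbf{Y}) \propto \mathcal{N}\left(\mathbf{y}_{\cdot, d}; \phi(\mathbf{F}_L) \mathbf{w}_{L+1, d}, \sigma_n^2 \mathbf{I}\right) p(\mathbf{w}_{L+1, d}),
\end{align}$$

where $\mathbf{y}_{\cdot, d}$ is the $d^{th}$ column of $\mathbf{Y}$. Note that the features $\phi(\mathbf{F}_L)$ depend on $\mathbf{W}_{1:L}$, but once these weights are fixed the problem is Bayesian linear regression. If the prior is Gaussian, $p(\mathbf{w}_{L+1, d}) = \mathcal{N}(\mathbf{w}_{L+1, d}; \mathbf{0}, N_L^{-1} \mathbf{I})$, then $\mathbf{w}_{L+1,d}$ has conditional posterior

$$\begin{align}
p(\mathbf{w}_{L+1, d} | \mathbf{W}_{1:L}, \mathbf{X}, \mathbf{Y}) &= \mathcal{N}\left(\mathbf{w}_{L+1, d}; \boldsymbol{\mu}_{L+1, d}, \boldsymbol{\Sigma}_{L+1}\right), \\
\boldsymbol{\mu}_{L+1, d} &= \sigma_n^{-2} \boldsymbol{\Sigma}_{L+1} \phi(\mathbf{F}_L)^\top \mathbf{y}_{\cdot, d}, \\
\boldsymbol{\Sigma}_{L+1} &= \left(N_L \mathbf{I} + \sigma_n^{-2} \phi(\mathbf{F}_L)^\top \phi(\mathbf{F}_L)\right)^{-1}.
\end{align}$$

The covariance is shared by all $D_y$ columns; only the means differ.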

Ober and Aitchison {cite}`ober2021global` draw inspiration from this to propose an approximate posterior in which the weights of each layer are Gaussian conditioned on the weights of all previous layers, while the joint distribution over all weights is not Gaussian.
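
Conditioned on $\mathbf{W}_{1:L}$, the last layer is a Bayesian linear regression whose posterior takes a few lines to compute. The NumPy sketch below is illustrative only: it assumes an isotropic Gaussian prior with precision `prior_prec`, uses a random matrix in place of the features $\phi(\mathbf{F}_L)$, and `last_layer_posterior` is a hypothetical name.

```python
import numpy as np

def last_layer_posterior(features, Y, noise_std, prior_prec):
    """Conditional posterior of W_{L+1} given the features phi(F_L).

    Assumes an isotropic Gaussian prior N(0, prior_prec^{-1} I) on every column.
    Returns the (N_L, D_y) posterior means (column d is the mean of w_{L+1,d})
    and the (N_L, N_L) covariance shared by all columns.
    """
    N_L = features.shape[1]
    precision = prior_prec * np.eye(N_L) + features.T @ features / noise_std**2
    cov = np.linalg.inv(precision)
    means = cov @ features.T @ Y / noise_std**2
    return means, cov

rng = np.random.default_rng(0)
K, N_L, D_y, noise_std = 30, 4, 2, 0.3
features = rng.standard_normal((K, N_L))  # stands in for phi(F_L) at fixed W_{1:L}
Y = features @ rng.standard_normal((N_L, D_y)) + noise_std * rng.standard_normal((K, D_y))
means, cov = last_layer_posterior(features, Y, noise_std, prior_prec=float(N_L))
```

As a sanity check, the posterior mean coincides with the ridge-regression solution $(\phi^\top\phi + \sigma_n^2 \cdot \texttt{prior\_prec}\,\mathbf{I})^{-1}\phi^\top\mathbf{Y}$, and a single covariance serves every output column.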

## Approximate posterior

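The approximate posterior introduces $M$ inducing inputs $\mathbf{U}_0 \in \mathbb{R}^{M \times D_x}$, which are propagated through the network in the same way as the data,

$$\begin{align}
\mathbf{U}_1 &= \mathbf{U}_0 \mathbf{W}_1, \\
\mathbf{U}_l &= \phi(\mathbf{U}_{l - 1}) \mathbf{W}_l,
\end{align}$$

so a single set of inducing points is shared by all layers, which is what makes them *global*. For each layer $l$, we also learn pseudo-outputs $\mathbf{v}^l_d \in \mathbb{R}^M$, one for each of the $N_l$ output units, and a diagonal precision matrix $\boldsymbol{\Lambda}_l \in \mathbb{R}^{M \times M}$. Writing $\mathbf{w}^l_d$ for the $d^{th}$ column of $\mathbf{W}_l$, the approximate posterior takes the form

$$\begin{align}
q\left(\mathbf{W}_l | \{\mathbf{W}_{l'}\}_{l' = 1}^{l-1}\right) &\propto \prod_{d = 1}^{N_l} \mathcal{N}\left(\mathbf{v}^l_d; \phi(\mathbf{U}_{l-1}) \mathbf{w}^l_d, \boldsymbol{\Lambda}^{-1}_l\right) p(\mathbf{w}^l_d),
\end{align}$$

with $\phi(\mathbf{U}_0)$ read as $\mathbf{U}_0$ for the first layer. This mirrors the conditional posterior of the last layer, with the inducing activations $\phi(\mathbf{U}_{l-1})$ playing the role of the features, the pseudo-outputs $\mathbf{v}^l_d$ that of the targets, and $\boldsymbol{\Lambda}_l$ that of the noise precision. With the prior $p(\mathbf{w}^l_d) = \mathcal{N}(\mathbf{w}^l_d; \mathbf{0}, N_{l-1}^{-1} \mathbf{I})$, rearranging into an explicit distribution over the weights $\mathbf{w}^l_d$ gives

$$\begin{align}
q\left(\mathbf{W}_l | \{\mathbf{W}_{l'}\}_{l' = 1}^{l-1}\right) &= \prod_{d = 1}^{N_l} \mathcal{N}\left(\mathbf{w}^l_d; \boldsymbol{\mu}_{l, d}^w, \boldsymbol{\Sigma}_l^w\right), \\
\boldsymbol{\mu}_{l, d}^w &= \boldsymbol{\Sigma}_l^w \phi\left(\mathbf{U}_{l-1}\right)^\top \boldsymbol{\Lambda}_l \mathbf{v}_d^l, \\
\boldsymbol{\Sigma}_l^w &= \left( N_{l-1} \mathbf{I} + \phi\left(\mathbf{U}_{l-1}\right)^\top \boldsymbol{\Lambda}_l \phi\left(\mathbf{U}_{l-1}\right) \right)^{-1}.
\end{align}$$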
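
The per-layer mean and covariance can be computed directly from these formulas. The NumPy sketch below is a simplified stand-in for the TensorFlow layer that follows: `gip_layer_posterior` is a hypothetical name, the inputs are random, and the prior precision is passed in as a plain scalar. It uses the same `einsum` pattern as the layer to form $\phi^\top \boldsymbol{\Lambda} \phi$ without building the $M \times M$ diagonal matrix.

```python
import numpy as np

def gip_layer_posterior(phi_U, V, pseudo_prec, prior_prec):
    """Mean and shared covariance of q(W_l | W_1, ..., W_{l-1}) for one layer.

    phi_U:       (M, N_in) inducing activations phi(U_{l-1}).
    V:           (M, N_out) pseudo-outputs; column d is v_d^l.
    pseudo_prec: (M,) diagonal of Lambda_l.
    prior_prec:  scalar prior precision (N_{l-1} in the text).
    """
    N_in = phi_U.shape[1]
    # phi^T Lambda phi without forming the M x M diagonal matrix
    phiT_lambda_phi = np.einsum("mi,m,mj->ij", phi_U, pseudo_prec, phi_U)
    cov = np.linalg.inv(prior_prec * np.eye(N_in) + phiT_lambda_phi)  # Sigma_l^w
    means = cov @ phi_U.T @ (pseudo_prec[:, None] * V)  # column d is mu_{l,d}^w
    return means, cov

rng = np.random.default_rng(1)
M, N_in, N_out = 6, 4, 3
phi_U = rng.standard_normal((M, N_in))
V = rng.standard_normal((M, N_out))
pseudo_prec = np.exp(rng.standard_normal(M))
means, cov = gip_layer_posterior(phi_U, V, pseudo_prec, prior_prec=float(N_in))
print(means.shape, cov.shape)  # one mean per output unit, one shared covariance
```

Setting every entry of `pseudo_prec` to $\sigma_n^{-2}$ and using data features and targets in place of `phi_U` and `V` recovers the Bayesian linear regression posterior of the last layer. Note that the layer code below stores the means transposed, as a `(Dout, Din)` matrix.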

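We now define a `GlobalInducingDenseLayer`, which propagates the data activations $\mathbf{F}_l$ and the inducing activations $\mathbf{U}_l$ through one layer, samples that layer's weights from $q$, and computes the layer's contribution to the total KL divergence. The layer also absorbs a bias by appending a column of ones to its inputs, so it works with one more input dimension than the incoming activations have, and scales the prior accordingly.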

In [None]:
class GlobalInducingDenseLayer(tfk.layers.Layer):
    
    def __init__(self,
                 num_input,
                 num_output,
                 num_inducing,
                 nonlinearity,
                 dtype,
                 name="global_inducing_fully_connected_layer",
                 **kwargs):
        
        super().__init__(name=name, dtype=dtype, **kwargs)
        
        self.num_input = num_input + 1
        self.num_output = num_output
        self.num_inducing = num_inducing
        
        # Set nonlinearity for the layer
        self.nonlinearity = (lambda x: x) if nonlinearity is None else \
                            getattr(tf.nn, nonlinearity)
    
    def build(self, input_shape):
        
        # Set up prior mean, scale and distribution
        self.prior_mean = tf.zeros(
            shape=(self.num_output, self.num_input),
            dtype=self.dtype
        )
        
        self.prior_scale = tf.ones(
            shape=(self.num_output, self.num_input),
            dtype=self.dtype
        )
        self.prior_scale = self.prior_scale / self.num_input**0.5
        
        self.prior = tfd.MultivariateNormalDiag(
            loc=self.prior_mean,
            scale_diag=self.prior_scale
        )
        
        # Set up pseudo observation means and variances
        self.pseudo_means = tf.zeros(
            shape=(self.num_inducing, self.num_output),
            dtype=self.dtype
        )
        self.pseudo_mean = tf.Variable(self.pseudo_means)
        
        self.pseudo_log_prec = tf.zeros(
            shape=(self.num_inducing,),
            dtype=self.dtype
        )
        self.pseudo_log_prec = tf.Variable(self.pseudo_log_prec)
        
        
    @property
    def pseudo_precision(self):
        # Use the variable created in build (pseudo_log_prec)
        return tf.math.exp(self.pseudo_log_prec)
    
        
    def q_prec_cov_chols(self, Uin):
        
        phiU = self.nonlinearity(Uin)
        pseudo_prec = tf.math.exp(self.pseudo_log_prec)
        
        # Compute precision matrix of multivariate normal
        phiT_lambda_phi = tf.einsum("mi, m, mj -> ij", phiU, pseudo_prec, phiU)
        
        q_prec = tf.linalg.diag(self.prior_scale[0, :]**-2.) + phiT_lambda_phi
        
        # Compute cholesky of approximate posterior precision
        q_prec_chol = tf.linalg.cholesky(q_prec)
        
        # Compute cholesky of approximate posterior covariance
        iq_prec_chol = tf.linalg.triangular_solve(
            q_prec_chol,
            tf.eye(q_prec_chol.shape[0], dtype=self.dtype),
            lower=True
        )
        
        q_cov = tf.matmul(iq_prec_chol, iq_prec_chol, transpose_a=True)
        q_cov_chol = tf.linalg.cholesky(q_cov)
        
        return q_prec_chol, q_cov_chol
    
    
    def q_mean(self, Uin, prec_chol):
        
        phiU = self.nonlinearity(Uin)
        pseudo_prec = tf.math.exp(self.pseudo_log_prec)
        
        mean = tf.matmul(
            phiU,
            pseudo_prec[:, None] * self.pseudo_mean,
            transpose_a=True
        )
        
        mean = tf.linalg.cholesky_solve(prec_chol, mean)
        mean = tf.transpose(mean, [1, 0])
        
        return mean
        
        
    def call(self, Fin, Uin):
        
        # Augment input features with ones to absorb bias
        Fones = tf.ones(shape=(Fin.shape[0], 1), dtype=self.dtype)
        Fin = tf.concat([Fin, Fones], axis=-1)
        
        Uones = tf.ones(shape=(Uin.shape[0], 1), dtype=self.dtype)
        Uin = tf.concat([Uin, Uones], axis=-1)
        
        Din = self.num_input
        Dout = self.num_output
        M = self.num_inducing
        
        # Check shape of input features Fin and pseudo-means
        check_shape(
            [Fin, Uin, self.pseudo_mean],
            [(-1, Din), (M, Din), (M, Dout)]
        )
        
        # Compute cholesky factors of q precision and covariance.
        # These are common between all weight columns, i.e. the covariance
        # between weights leading to a neuron in the next layer is shared
        # between all next neurons.
        q_prec_chol, q_cov_chol = self.q_prec_cov_chols(Uin)
        
        check_shape(
            [q_prec_chol, q_cov_chol],
            [(Din, Din), (Din, Din)]
        )
        
        # Compute means of q. There is a different mean vector for
        # each column of weights.
        q_mean = self.q_mean(Uin, q_prec_chol)
        
        check_shape(q_mean, (Dout, Din))
        
        # Sample approximate posterior for the weights
        q_cov_chol = tf.stack([q_cov_chol]*Dout, axis=0)
        q = tfd.MultivariateNormalTriL(loc=q_mean, scale_tril=q_cov_chol)
        wT = q.sample()
        w = tf.transpose(wT, [1, 0])
        
        check_shape(w, (Din, Dout))
        
        # Compute this layer's contribution to the KL term of the ELBO
        kl_term = q.kl_divergence(self.prior)
        kl_term = tf.reduce_sum(kl_term)
        
        # Compute log-probability of weights under prior
        log_p = self.prior.log_prob(wT)
        log_p = tf.reduce_sum(log_p)
        
        # Compute log-probability of weights under approximate posterior
        log_q = q.log_prob(wT)
        log_q = tf.reduce_sum(log_q)
        
        # Compute Fout and Uout and return
        Fout = tf.matmul(self.nonlinearity(Fin), w)
        Uout = tf.matmul(self.nonlinearity(Uin), w)
        
        return Fout, Uout, kl_term, log_p, log_q
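
In `q_prec_cov_chols`, the covariance Cholesky factor is obtained by inverting the precision Cholesky factor, forming `q_cov` explicitly and factorising it again. An alternative that avoids the second factorisation, not used in this notebook, rests on a simple identity: if the precision is $\mathbf{P} = \mathbf{L}\mathbf{L}^\top$ with $\mathbf{L}$ lower-triangular, then $\mathbf{L}^{-\top}\boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ has covariance $\mathbf{L}^{-\top}\mathbf{L}^{-1} = \mathbf{P}^{-1}$. The NumPy sketch below checks this on an arbitrary precision matrix; it uses `np.linalg.solve` for brevity where a triangular solver would exploit the structure of $\mathbf{L}^\top$.

```python
import numpy as np

rng = np.random.default_rng(2)
D = 5
A = rng.standard_normal((D, D))
precision = A @ A.T + D * np.eye(D)  # an arbitrary symmetric positive-definite precision
mean = rng.standard_normal(D)

L = np.linalg.cholesky(precision)  # precision = L L^T, L lower-triangular

# If eps ~ N(0, I), then L^{-T} eps ~ N(0, L^{-T} L^{-1}) = N(0, precision^{-1}).
# A triangular solver (e.g. tf.linalg.triangular_solve with lower=True, adjoint=True)
# would exploit the structure; np.linalg.solve is used here only for brevity.
scale = np.linalg.solve(L.T, np.eye(D))  # L^{-T}, upper-triangular

# Each row of eps is transformed as scale @ eps_row, then shifted by the mean
samples = mean + rng.standard_normal((100_000, D)) @ scale.T
```

Since $\mathbf{L}^{-\top}$ is upper-triangular, it could not be passed directly as the `scale_tril` of `tfd.MultivariateNormalTriL`; adopting this trick in the layer would require sampling and evaluating densities by hand or through a different distribution class.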

We then stack three `GlobalInducingDenseLayer`s to form a `GlobalInducingFullyConnectedNetwork`. Following Ober and Aitchison, we use two hidden layers of $50$ units each.

In [None]:
class GlobalInducingFullyConnectedNetwork(tfk.Model):

    def __init__(self,
                 num_input,
                 num_output,
                 inducing_points,
                 nonlinearity,
                 dtype,
                 name="global_inducing_fully_connected",
                 **kwargs):
        
        super().__init__(name=name, dtype=dtype, **kwargs)
        
        self.num_input = num_input
        self.num_output = num_output
        self.inducing_points = inducing_points
        self.num_inducing = inducing_points.shape[0]
        self.nonlinearity = nonlinearity
        self.num_hidden = [50, 50]
        
        
    def build(self, input_shape):
        
        self.inducing_points = tf.Variable(self.inducing_points)
        
        self.l1 = GlobalInducingDenseLayer(
            num_input=self.num_input,
            num_output=self.num_hidden[0],
            num_inducing=self.num_inducing,
            nonlinearity=None,
            dtype=self.dtype
        )
        
        self.l2 = GlobalInducingDenseLayer(
            num_input=self.num_hidden[0],
            num_output=self.num_hidden[1],
            num_inducing=self.num_inducing,
            nonlinearity=self.nonlinearity,
            dtype=self.dtype
        )
        
        self.l3 = GlobalInducingDenseLayer(
            num_input=self.num_hidden[1],
            num_output=self.num_output,
            num_inducing=self.num_inducing,
            nonlinearity=self.nonlinearity,
            dtype=self.dtype
        )
        
        self.log_noise = tf.Variable(
            tf.convert_to_tensor(-2., dtype=self.dtype)
        )
        
    @property
    def noise(self):
        return tf.math.exp(self.log_noise)
        
    @tf.function
    def call(self, x):
        
        F1, U1, kl1, log_p1, log_q1 = self.l1(x, self.inducing_points)
        F2, U2, kl2, log_p2, log_q2 = self.l2(F1, U1)
        F3, U3, kl3, log_p3, log_q3 = self.l3(F2, U2)
        
        means = F3
        scales = self.noise * tf.ones_like(F3)
        
        kl = tf.reduce_sum([kl1, kl2, kl3])
        
        log_p = tf.reduce_sum([log_p1, log_p2, log_p3])
        log_q = tf.reduce_sum([log_q1, log_q2, log_q3])
        
        return means, scales, kl, log_p, log_q
    
    
    def elbo(self, x, y):

        means, scales, kl, _, _ = self(x)

        cond_lik = tfd.Normal(loc=means, scale=scales)
        cond_lik = tf.reduce_sum(cond_lik.log_prob(y))

        elbo = cond_lik - kl

        return elbo, cond_lik, kl

    
    def iwbo(self, x, y, num_samples):
        
        @tf.function
        def call(x):
            return self.call(x)

        iwbo = []

        for i in range(num_samples):

            means, scales, kl, log_p, log_q = call(x)

            cond_lik = tfd.Normal(loc=means, scale=scales)
            cond_lik = tf.reduce_sum(cond_lik.log_prob(y))

            iwbo.append(cond_lik + log_p - log_q)

        iwbo = tf.stack(iwbo, axis=0)
        iwbo = tf.math.reduce_logsumexp(iwbo) - np.log(num_samples)

        return iwbo
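
Stripped of the bias augmentation, the KL and log-density terms, and the Keras bookkeeping, one forward pass of the network above samples the weights layer by layer, with each layer's Gaussian depending on the inducing activations produced by the weights already sampled. This dependence is why the joint $q$ over all weights is not Gaussian even though every conditional is. The NumPy sketch below is a simplified illustration, not the notebook's implementation: `sample_gip_network` and the `(V, pseudo_prec)` list layout are hypothetical, it uses a ReLU nonlinearity, and it assumes the prior precision equals each layer's input width.

```python
import numpy as np

def sample_gip_network(X, U0, layers, rng, phi=lambda a: np.maximum(a, 0.0)):
    """Draw one joint weight sample from q and propagate X and U0 through the net.

    layers: list of (V, pseudo_prec) pairs, one per layer (hypothetical layout).
    The first layer applies the identity to its inputs, matching F_1 = X W_1.
    """
    F, U = X, U0
    weights = []
    for l, (V, pseudo_prec) in enumerate(layers):
        act = (lambda a: a) if l == 0 else phi
        phi_F, phi_U = act(F), act(U)
        N_in = phi_U.shape[1]
        # Conditional Gaussian for this layer, given the layers sampled so far
        precision = N_in * np.eye(N_in) + phi_U.T @ (pseudo_prec[:, None] * phi_U)
        cov = np.linalg.inv(precision)
        means = cov @ phi_U.T @ (pseudo_prec[:, None] * V)  # (N_in, N_out)
        chol = np.linalg.cholesky(cov)
        # Each column of W gets its own mean but the same covariance
        W = means + chol @ rng.standard_normal(means.shape)
        weights.append(W)
        # Data and inducing activations are pushed through the same sampled weights
        F, U = phi_F @ W, phi_U @ W
    return F, U, weights

rng = np.random.default_rng(3)
M, K, sizes = 10, 20, [1, 50, 50, 1]
X = rng.uniform(-2.0, 2.0, size=(K, 1))
U0 = X[:M].copy()  # inducing inputs initialised at a subset of the data
layers = [(np.zeros((M, n_out)), np.ones(M)) for n_out in sizes[1:]]
F, U, weights = sample_gip_network(X, U0, layers, rng)
```

Because the inducing inputs here are a copy of the first `M` data inputs and pass through exactly the same sampled weights, `U` equals the first `M` rows of `F`: the inducing points behave like extra inputs carried through the network.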

In [None]:
num_data = 100
num_input = 1
std_noise = 3.

x1 = tf.random.uniform(minval=-4., maxval=-2., shape=(num_data // 2, 1))
x2 = tf.random.uniform(minval=2., maxval=4., shape=(num_data // 2, 1))

x = tf.concat([x1, x2], axis=0)
y = tf.concat([x1, x2], axis=0) ** 3. + std_noise * tf.random.normal(shape=(num_data, 1))

x = (x - tf.reduce_mean(x)) / tf.math.reduce_std(x)
y = (y - tf.reduce_mean(y)) / tf.math.reduce_std(y)

# Figure to plot on 
plt.figure(figsize=(14, 4))

# Plot data
plt.scatter(
    x[:, 0],
    y[:, 0],
    marker="+",
    c="black"
)

# Format plot
plt.xlim([-2.5, 2.5])
plt.ylim([-3.5, 3.5])

plt.xticks(np.linspace(-2., 2., 3), fontsize=24)
plt.yticks(np.linspace(-3., 3., 3), fontsize=24)

plt.xlabel("$x$", fontsize=32)
plt.ylabel("$y$", fontsize=32)

plt.show()

In [None]:
# We decorate a single gradient descent step with tf.function. On the first
# call of single_step, TensorFlow traces and compiles the computational graph.
# After that, all calls to single_step reuse the compiled graph, which is
# much faster than the default eager mode execution. In this case, the gain
# is roughly a 20x speedup (on a CPU), which can be checked by commenting
# out the decorator and rerunning the training script.

@tf.function
def single_step(model, optimiser, x, y):

    with tf.GradientTape() as tape:

        elbo, cond_lik, kl = model.elbo(x, y)
        loss = - elbo / x.shape[0]

    gradients = tape.gradient(loss, model.trainable_variables)
    optimiser.apply_gradients(zip(gradients, model.trainable_variables))

    return elbo, cond_lik, kl

In [None]:
# Set model constants
num_input = 1
num_output = 1
num_inducing = 100
dtype = tf.float32
nonlinearity = "relu"
num_steps = int(1e5)

# Initialise inducing points at a random subset of the training inputs
# (with num_inducing equal to num_data, as here, this is all of them, shuffled)
inducing_idx = tf.random.shuffle(tf.range(x.shape[0]))[:num_inducing]
inducing_points = tf.gather(x, inducing_idx)

# Initialise model
model = GlobalInducingFullyConnectedNetwork(
    num_input=num_input,
    num_output=num_output,
    inducing_points=inducing_points,
    nonlinearity=nonlinearity,
    dtype=dtype
)

# Initialise optimiser
optimiser = tfk.optimizers.Adam(learning_rate=1e-3)
    
# Set progress bar and suppress warnings
progress_bar = tqdm(range(1, num_steps+1))
tf.get_logger().setLevel('ERROR')

# Lists for keeping track of quantities of interest
train_elbo = []
train_cond_lik = []
train_kl = []

# Train model
for i in progress_bar:
        
    elbo, cond_lik, kl = single_step(
        model=model,
        optimiser=optimiser,
        x=x,
        y=y
    )

    if i % 1000 == 0:
        
        progress_bar.set_description(
            f"ELBO {elbo:.1f}, "
            f"Cond-lik. {cond_lik:.1f}, "
            f"KL {kl:.1f}"
        )
        
    train_elbo.append(elbo)
    train_cond_lik.append(cond_lik)
    train_kl.append(kl)

In [None]:
# Number of samples to draw, three will be plotted
num_samples = 100

# Input locations to plot
x_plot = tf.linspace(-4., 4., 100)[:, None]

# Draw samples from model
samples = [model(x_plot)[0] for i in range(num_samples)]

# Compute mean and standard deviation of samples
mean = tf.reduce_mean(samples, axis=0)
std = tf.math.reduce_std(samples, axis=0)

# Figure to plot on 
plt.figure(figsize=(14, 4))

# Plot epistemic uncertainty
plt.fill_between(
    x_plot[:, 0],
    mean[:, 0] - 2.*std[:, 0],
    mean[:, 0] + 2.*std[:, 0],
    color="tab:gray",
    alpha=0.4,
    zorder=1
)

# Plot three samples
for i, color in enumerate(["tab:red", "tab:green", "tab:blue"]):
    
    plt.plot(
        x_plot[:, 0],
        samples[i][:, 0],
        color=color,
        zorder=2
    )

# Plot inducing input locations along the bottom of the axes
plt.scatter(
    model.inducing_points[:, 0],
    -3.*tf.ones_like(model.inducing_points[:, 0]),
    marker="x",
    c="tab:purple",
    zorder=3
)

# Plot data
plt.scatter(
    x[:, 0],
    y[:, 0],
    marker="+",
    c="black",
    zorder=3
)

# Format plot
plt.xlim([-3.5, 3.5])
plt.ylim([-4.5, 4.5])

plt.xticks(np.linspace(-3., 3., 3), fontsize=24)
plt.yticks(np.linspace(-4., 4., 5), fontsize=24)

plt.xlabel("$x$", fontsize=32)
plt.ylabel("$y$", fontsize=32)

plt.show()

In [None]:
def moving_average(array, n):
    
    cumsum = np.cumsum(array)
    cumsum[n:] = cumsum[n:] - cumsum[:-n]
    
    moving_average = cumsum[n - 1:] / n
    
    return moving_average

# Plot the ELBO smoothed by a moving average over windows of 1000 steps
plt.plot(moving_average(tf.stack(train_elbo).numpy(), n=1000))
plt.ylim([-25, -15])
plt.ylabel("ELBO (moving average)")
plt.show()

## How tight is the GIP ELBO?

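We can also check how tight the ELBO of the global inducing point (GIP) approximate posterior is, using importance sampling. With $S$ weight samples $\mathbf{W}^{(s)} \sim q$, the importance-weighted bound (IWBO)

$$\begin{align}
\text{IWBO}_S = \log \frac{1}{S} \sum_{s = 1}^S \frac{p(\mathbf{Y} | \mathbf{W}^{(s)}, \mathbf{X}) \, p(\mathbf{W}^{(s)})}{q(\mathbf{W}^{(s)})}
\end{align}$$

is, in expectation, a lower bound on $\log p(\mathbf{Y} | \mathbf{X})$ that is at least as tight as the ELBO and approaches it as $S \to \infty$. The difference between the IWBO and the ELBO is therefore a lower bound on how far the ELBO falls short of the log marginal likelihood (for the learned noise level).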

In [None]:
num_repetitions = 10
num_iwbo_samples = 1000

elbos = [model.elbo(x=x, y=y)[0] for i in range(num_repetitions)]
iwbos = [model.iwbo(x=x, y=y, num_samples=num_iwbo_samples) for i in range(num_repetitions)]

print(
    f"ELBO: {tf.reduce_mean(elbos): 7.3f} +/- {2.*tf.math.reduce_std(elbos)/num_repetitions**0.5:.3f} "
    f"(estimated with {num_repetitions} ELBO samples)"
)

print(
    f"IWBO: {tf.reduce_mean(iwbos): 7.3f} +/- {2.*tf.math.reduce_std(iwbos)/num_repetitions**0.5:.3f} "
    f"(estimated with {num_repetitions} IWBO samples, each using {num_iwbo_samples} weight samples)"
)
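
As a sanity check on the relationship between the two bounds, the NumPy sketch below uses a hypothetical one-dimensional conjugate model, $w \sim \mathcal{N}(0, 1)$ and $y \mid w \sim \mathcal{N}(w, 0.5^2)$, where $\log p(y)$ is available in closed form, together with a deliberately mismatched Gaussian $q(w)$. It is not part of the BNN experiment; it only illustrates that, for a reasonable number of samples, the ELBO estimate sits below the IWBO estimate, which in turn sits close to the exact log marginal likelihood.

```python
import numpy as np

def log_normal(x, mean, std):
    """Elementwise log-density of a univariate Gaussian."""
    return -0.5 * np.log(2.0 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2

def logsumexp(a):
    """Numerically stable log(sum(exp(a)))."""
    a_max = np.max(a)
    return a_max + np.log(np.sum(np.exp(a - a_max)))

# Hypothetical conjugate toy model: w ~ N(0, 1), y | w ~ N(w, 0.5^2), one observation
prior_std, noise_std, y = 1.0, 0.5, 1.0
# Deliberately mismatched Gaussian approximate posterior q(w) = N(0.5, 0.6^2)
q_mean, q_std = 0.5, 0.6

rng = np.random.default_rng(0)
num_samples = 20_000
w = q_mean + q_std * rng.standard_normal(num_samples)

# Log importance weights log p(y | w) + log p(w) - log q(w), as in `iwbo` above
log_weights = (log_normal(y, w, noise_std)
               + log_normal(w, 0.0, prior_std)
               - log_normal(w, q_mean, q_std))

elbo = np.mean(log_weights)                          # Monte Carlo ELBO estimate
iwbo = logsumexp(log_weights) - np.log(num_samples)  # IWBO with S = num_samples
log_py = log_normal(y, 0.0, np.sqrt(prior_std**2 + noise_std**2))  # exact log p(y)

print(f"ELBO {elbo:.3f}, IWBO {iwbo:.3f}, exact log p(y) {log_py:.3f}")
```

Here the ELBO estimates the KL term by Monte Carlo, whereas the notebook's `elbo` uses the analytic KL; both are unbiased estimates of the same bound.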

## References

```{bibliography} ./references.bib
```