# Online SGPR.

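Online sparse Gaussian process regression (OSGPR), following Bui et al. (2017), *Streaming Sparse Gaussian Process Approximations*.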

In [None]:
# Development convenience: load IPython's autoreload extension and use mode 2,
# which re-imports edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from jax import jit
import optax as ox

import gpjax as gpx
import tensorflow as tf

# Fix both sources of randomness so the notebook is repeatable from a fresh kernel.
key = jr.PRNGKey(123)  # JAX PRNG key consumed by the data simulation below
tf.random.set_seed(42)  # TensorFlow's global seed

## Dataset

With the necessary modules imported, we simulate a dataset $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{5000}$ with inputs $\boldsymbol{x}$ sampled uniformly on $(-5, 5)$ and corresponding noisy real-valued outputs

$$\boldsymbol{y} \sim \mathcal{N} \left(\sin(4\boldsymbol{x}) + \cos(2\boldsymbol{x}), \mathbf{I} \cdot (0.2)^{2} \right).$$

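We split the sorted data into two GPJax `Dataset`s — `D1`, the first 2500 points, used for the initial sparse fit, and `D2`, the next 100 points, used for the online update — and create test inputs for later.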

In [None]:
n = 5000
noise = 0.2

# Use independent subkeys for the inputs and the observation noise. Passing the
# same key to both samplers would make the two draws share the same random bits.
x_key, noise_key = jr.split(key)

# Sorted inputs as a column vector of shape (n, 1).
x = jr.uniform(key=x_key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)


def f(x):
    """Noise-free latent function sin(4x) + cos(2x), evaluated elementwise."""
    return jnp.sin(4 * x) + jnp.cos(2 * x)


signal = f(x)
y = signal + jr.normal(noise_key, shape=signal.shape) * noise

# Because x is sorted, each slice covers a contiguous stretch of the inputs:
# D1 (first 2500 points) is used for the initial fit, D2 (next 100 points)
# for the online update.
D1 = gpx.Dataset(X=x[:2500], y=y[:2500])
D2 = gpx.Dataset(X=x[2500:2600], y=y[2500:2600])

# Test inputs extend slightly beyond the sampled range of (-5, 5).
xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)

In [None]:
# Two identical grids of 100 evenly spaced inducing inputs on [-5, 5]:
# z_1 initialises the first variational family, z_2 the second.
z_1 = jnp.linspace(-5.0, 5.0, 100).reshape(-1, 1)
z_2 = jnp.linspace(-5.0, 5.0, 100).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.3)
ax.plot(xtest, f(xtest))
# Mark every initial inducing input with a faint vertical line.
for z_i in z_1:
    ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
plt.show()

In [None]:
# GP prior with an RBF kernel and a Gaussian observation model.
prior = gpx.Prior(kernel=gpx.RBF())
# NOTE(review): n is the full simulated size (5000), while the first fit only
# uses D1 (2500 points) — confirm this is the intended num_datapoints.
likelihood = gpx.Gaussian(num_datapoints=n)
# The product of prior and likelihood is what is later passed as `posterior`.
p = prior * likelihood
# First variational family, with inducing inputs initialised at the z_1 grid.
q1 = gpx.VariationalGaussian(prior=prior, inducing_inputs=z_1)

In [None]:
# Stochastic variational inference object pairing the posterior with q1.
svgp = gpx.StochasticVI(
    posterior=p,
    variational_family=q1,
)

In [None]:
# Initial parameter state for the SVGP model, then map the parameters with the
# unconstraining transforms before optimisation.
params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)
params = gpx.transform(params, unconstrainers)

# Negative ELBO built against D1 and compiled with jit; it is called as
# loss_fn(params, batch) by the fitting routine.
loss_fn = jit(
    svgp.elbo(D1, constrainers, negative=True)
)

In [None]:
# Endless, shuffled stream of mini-batches of 100 points drawn from D1.
Dbatched = (
    D1.cache()
    .repeat()
    .shuffle(D1.n)
    .batch(batch_size=100)
    .prefetch(buffer_size=1)
)

optimiser = ox.adam(learning_rate=0.001)

# Run 10,000 Adam steps on the jitted objective.
learned_params = gpx.fit_batches(
    objective=loss_fn,
    params=params,
    trainables=trainables,
    train_data=Dbatched,
    optax_optim=optimiser,
    n_iters=10000,
)

# Apply the constraining transforms to the optimised parameters.
learned_params = gpx.transform(learned_params, constrainers)

## Predictions q1

In [None]:
# Latent and predictive distributions of the first sparse fit at the test inputs.
latent_dist = q1(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray")
ax.plot(xtest, meanf, label="Posterior mean", color="tab:blue")
# Shade one predictive standard deviation either side of the mean.
ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3)
# Vertical lines at the learned inducing inputs.
for z_i in learned_params["variational_family"]["inducing_inputs"]:
    ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
plt.show()

# OSGPR

In [None]:
# Second variational family for the online update, initialised at the z_2 grid.
q2 = gpx.VariationalGaussian(
    prior=prior,
    inducing_inputs=z_2,
)

In [None]:
# Online SGPR object: the old family q1 and its fitted parameters are passed
# alongside the new family q2.
osgpr = gpx.variational_inference.OSGPR(
    posterior=p,
    variational_family_old=q1,
    variational_family=q2,
    params_old=learned_params,
)

In [None]:
# Initial parameter state for the OSGPR model, then map the parameters with the
# unconstraining transforms before optimisation.
params, trainables, constrainers, unconstrainers = gpx.initialise(osgpr)
params = gpx.transform(params, unconstrainers)

# The fitting routine minimises its objective, so request the *negative* ELBO,
# matching the SVGP loss built earlier with negative=True. With negative=False
# the optimiser would drive the ELBO down instead of up.
loss_fn = jit(osgpr.elbo(constrainers, negative=True))

In [None]:
# Evaluate the OSGPR objective once at the initial parameters on the new batch.
# D2 is already built in the dataset cell, so it is not reconstructed here.
loss_fn(params, D2)

In [None]:
# D2 is already built in the dataset cell, so it is reused rather than
# reconstructed. It holds 100 points, so a batch size of 100 means each step
# sees a batch as large as the whole of D2.
Dbatched = (
    D2.cache()
    .repeat()
    .shuffle(D2.n)
    .batch(batch_size=100)
    .prefetch(buffer_size=1)
)

optimiser = ox.adam(learning_rate=0.001)

# Run 1,000 Adam steps on the jitted OSGPR objective.
learned_params_new = gpx.fit_batches(
    objective=loss_fn,
    params=params,
    trainables=trainables,
    train_data=Dbatched,
    optax_optim=optimiser,
    n_iters=1000,
)

# Apply the constraining transforms to the optimised parameters.
learned_params_new = gpx.transform(learned_params_new, constrainers)

In [None]:
# Predict with the updated family q2. Every quantity in this cell must come from
# the OSGPR parameters (learned_params_new), not the earlier SVGP parameters.
latent_dist = q2(learned_params_new)(xtest)
predictive_dist = likelihood(latent_dist, learned_params_new)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray")
ax.plot(xtest, meanf, label="Posterior mean", color="tab:blue")
# Shade one predictive standard deviation either side of the mean.
ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3)
# Vertical lines at the inducing inputs learned during the online update.
for z_i in learned_params_new["variational_family"]["inducing_inputs"]:
    ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
plt.show()

## System configuration

In [None]:
# Print a record of the environment the notebook was run in, using the
# `watermark` IPython extension.
%reload_ext watermark
%watermark -n -u -v -iv -w -a 'Daniel Dodd'