In [None]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.insert(0, "../..")

import jax
import jax.numpy as jnp
import jaxopt
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm.notebook import tqdm

import jaxgp as jgp
from jaxgp.datasets import Dataset, CustomDataset, NumpyLoader

# Seed NumPy's global RNG so the NumPy draws later in the notebook
# (training inputs and observation noise) are reproducible.
np.random.seed(42)

In [None]:
import logging
import tensorflow_probability.substrates.numpy as tfp

# Root logger. `logging.getLogger("root")` only resolves to the real root
# logger on Python >= 3.9 (earlier versions create a separate logger that is
# merely *named* "root"); calling getLogger() without a name works everywhere.
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Logging filter that drops records whose message mentions "check_types"."""

    def filter(self, record):
        """Return False (suppress) for "check_types" messages, True otherwise."""
        return "check_types" not in record.getMessage()


# NOTE(review): a filter attached to a logger is only consulted for records
# logged through that logger itself, not for records propagated up from
# descendant loggers. If the noisy messages are emitted by a library's own
# logger, the filter has to go on the root handlers instead -- confirm which
# logger emits them.
logger.addFilter(CheckTypesFilter())

In [None]:
# Problem sizes for the toy regression task; later cells read these names.
input_dim = 1  # second axis of the X / Xtest / inducing-point arrays
output_dim = 1  # passed to jgp.SVGP as its last positional argument
num_data = 1000  # number of training points
num_test = 1000  # number of test inputs
num_inducing = 20  # number of inducing points


def func(X):
    """Noisy toy target: ``sin(2x) + 0.3x`` plus Gaussian noise (std 0.1).

    The noise is drawn from NumPy's global RNG with the same shape as ``X``,
    so repeated calls give different values unless the RNG is re-seeded.
    """
    signal = np.sin(2 * X) + 0.3 * X
    noise = np.random.normal(0, 0.1, X.shape)
    return signal + noise


# Training set: inputs drawn uniformly from [-3, 3) with NumPy's global RNG,
# targets from the noisy toy function.
X = np.random.uniform(-3.0, 3.0, (num_data, input_dim))
Y = func(X)

# A JAX PRNG key should be consumed by at most one random draw. Split once:
# `xtest_key` is used for the test inputs below, while `key` stays available
# to later cells as an unused key.
key, xtest_key = jax.random.split(jax.random.PRNGKey(10))

# Test inputs cover [-5, 5), wider than the training range, and are sorted
# along axis 0 so they can be drawn as a line.
Xtest = jnp.sort(
    jax.random.uniform(
        xtest_key, shape=(num_test, input_dim), minval=-5, maxval=5
    ),
    axis=0,
)

In [None]:
# GP prior: quadratic mean function with an RBF kernel, Gaussian likelihood.
mean = jgp.means.Quadratic()
kernel = jgp.kernels.RBF()
gprior = jgp.GPrior(kernel=kernel, mean_function=mean)
likelihood = jgp.likelihoods.Gaussian()
# Initial inducing inputs: uniform draws (default range [0, 1)) rescaled to
# the observed [X.min(), X.max()] interval of the training inputs.
inducing_points = (
    jax.random.uniform(key=key, shape=(num_inducing, input_dim))
    * (X.max() - X.min())
    + X.min()
)
model = jgp.SVGP(gprior, likelihood, inducing_points, output_dim)

# initialise() returns the parameter pytree plus two transforms; the
# optimisers below work on `raw_params`, the output of `unconstrain_trans`.
params, constrain_trans, unconstrain_trans = jgp.initialise(model)
raw_params = unconstrain_trans(params)
# Objective used by every optimiser below. sign=-1.0 presumably negates the
# ELBO so that it can be minimised -- confirm in SVGP.build_elbo.
neg_elbo = model.build_elbo(num_data=num_data, sign=-1.0)

In [None]:
from jaxgp.utils import pytree_shape_info

In [None]:
def loss(raw_params, batch_size=None):
    """Average of ``neg_elbo`` over consecutive, equally sized slices of (X, Y).

    Args:
        raw_params: unconstrained parameter pytree passed to ``neg_elbo``.
        batch_size: number of rows per slice. Defaults to ``num_data``, i.e. a
            single full-data evaluation (the behaviour callers that pass only
            ``raw_params`` have always had).

    Returns:
        The mean of the per-slice ``neg_elbo`` values.

    Raises:
        ValueError: if ``batch_size`` is not positive or does not divide
            ``num_data``.
    """
    if batch_size is None:
        batch_size = num_data
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")
    if num_data % batch_size != 0:
        raise ValueError("num_data need to be divisible by batch_size.")
    num_iters = num_data // batch_size
    neg_elbo_value = 0.0
    for start in range(0, num_data, batch_size):
        stop = start + batch_size
        data = Dataset(X=X[start:stop], Y=Y[start:stop])
        neg_elbo_value += neg_elbo(raw_params, data)
    # Divide by the known number of slices instead of a mutated loop index.
    return neg_elbo_value / num_iters

In [None]:
# Baseline: objective value at the initial parameters, before any optimiser runs.
print("Initial negative elbo = ", loss(raw_params))

In [None]:
# Inspect the parameter pytree. The helper's name suggests it reports the
# shape of each leaf -- confirm in jaxgp.utils.
pytree_shape_info(raw_params)

In [None]:
# Full-data minimisation of `loss` with jaxopt's L-BFGS, started from raw_params.
# Note: `solver` and `soln` are rebound by the ScipyMinimize cell that follows,
# so this result is not the one used later unless that cell is skipped.
solver = jaxopt.LBFGS(fun=loss, verbose=True)
soln = solver.run(raw_params)

In [None]:
# Same objective and starting point, minimised through jaxopt's SciPy wrapper.
# This rebinds `solver` and `soln`; `soln.params` is what a later cell plots.
solver = jaxopt.ScipyMinimize(fun=loss, jit=True, options={"disp": True})
soln = solver.run(raw_params)

In [None]:
def loss_sgd(raw_params, batch):
    """Evaluate ``neg_elbo`` on a single minibatch.

    After the wrapping below, calling ``loss_sgd(raw_params, batch)`` returns
    a ``(value, grads)`` pair, with the gradient taken with respect to
    ``raw_params`` (the first argument).
    """
    return neg_elbo(raw_params, batch)


# Same composition as stacking @jax.jit on top of @jax.value_and_grad.
loss_sgd = jax.jit(jax.value_and_grad(loss_sgd))


# Minibatch training of the negative ELBO with Adam.
batch_size = 50
training_data = CustomDataset(X, Y)
train_dataloader = NumpyLoader(
    training_data, batch_size=batch_size, shuffle=True
)
opt = optax.adam(learning_rate=1e-3)
opt_state = opt.init(raw_params)

num_epochs = 400
loss_history = []
for epoch in tqdm(range(num_epochs)):
    for batch in train_dataloader:
        # Each batch is indexed as (inputs, targets).
        data = Dataset(X=batch[0], Y=batch[1])
        loss_val, grads = loss_sgd(raw_params, data)
        updates, opt_state = opt.update(grads, opt_state)
        # Rebinds the notebook-level raw_params, so re-running this cell
        # continues from the already-trained values rather than restarting.
        raw_params = optax.apply_updates(raw_params, updates)
    # Only the loss of the epoch's last minibatch is recorded.
    loss_history.append(loss_val.item())

In [None]:
# Training curve of the Adam run. Each entry of loss_history is the objective
# on the last minibatch of one epoch, so the curve is a noisy estimate.
plt.figure()
plt.plot(loss_history, label="loss")
plt.xlabel("epoch")
plt.ylabel("negative ELBO (last minibatch)")
plt.legend()
plt.show()

In [None]:
# Full-data objective at the current raw_params (after the Adam loop above).
loss(raw_params)

In [None]:
# Display the "q_sqrt" entry of the unconstrained parameters.
# NOTE(review): presumably the square-root factor of the variational
# covariance -- confirm against jaxgp's SVGP parameter layout.
raw_params["q_sqrt"]

In [None]:
params = constrain_trans(raw_params)
plot_model(params)

In [None]:
def plot_model(params):
    plt.figure(figsize=(12, 6))
    plt.plot(X, Y, "kx", mew=2, alpha=0.5, label="data points")
    plt.plot(
        params["inducing_points"],
        jnp.zeros([num_inducing, input_dim]),
        "|",
        color="tab:red",
        mew=2,
        alpha=0.5,
        label="inducing_points",
    )
    mean, var = model.predict_y(params, Xtest)
    plt.plot(Xtest, mean, "tab:orange", lw=2, label="predicted mean")
    plt.fill_between(
        Xtest[:, 0],
        mean[:, 0] - 1.96 * jnp.sqrt(var[:, 0]),
        mean[:, 0] + 1.96 * jnp.sqrt(var[:, 0]),
        color="tab:blue",
        alpha=0.5,
        label="95% confidence region",
    )
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()

    plt.xlim([-5, 5])
    plt.show()

In [None]:
# Plot the full-batch solver result: `soln` is whatever the most recently
# executed jaxopt cell produced (the ScipyMinimize run when cells run in order).
params = constrain_trans(soln.params)
plot_model(params)

In [None]:
# Full-data objective at raw_params, i.e. the Adam-trained parameters, not
# `soln.params` plotted in the previous cell.
loss(raw_params)

In [None]:
# Second Adam run: continues from the current raw_params with a freshly
# initialised optimiser state, and restarts loss_history.
opt = optax.adam(learning_rate=1e-3)
opt_state = opt.init(raw_params)

num_epochs = 400
loss_history = []
for epoch in tqdm(range(num_epochs)):
    for batch in train_dataloader:
        data = Dataset(X=batch[0], Y=batch[1])
        # Must be loss_sgd: it takes (params, minibatch) and returns
        # (value, grads). The full-batch `loss` accepts only the parameters
        # and returns a scalar, so calling it here would fail.
        loss_val, grads = loss_sgd(raw_params, data)
        updates, opt_state = opt.update(grads, opt_state)
        raw_params = optax.apply_updates(raw_params, updates)
    # Only the loss of the epoch's last minibatch is recorded.
    loss_history.append(loss_val.item())