In [None]:
import sys

# Put the directory two levels up at the front of the module search path.
# NOTE(review): presumably so the in-repo `jaxgp` checkout is the one imported — confirm.
sys.path.insert(0, "../..")

# IPython autoreload, mode 2: modules are re-imported before each cell runs,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

(sgpr_heteroskedastic)=

# SGPR regression with heteroskedastic noise

In [None]:
import jax

# Make JAX raise as soon as a NaN is produced, so numerical problems surface
# where they originate. The flag is set through the `jax.config` object
# because the `jax.config` submodule (`from jax.config import config`) is no
# longer available in newer JAX releases.
jax.config.update("jax_debug_nans", True)
import copy
import jaxgp as jgp
import jax.numpy as jnp
from jaxgp.sgpr_heteroskedastic import HeteroskedasticSGPR
import jaxopt
import numpy as np
import matplotlib.pyplot as plt
from jaxgp.contrib.train_utils import train_model

In [None]:
# Seeded legacy NumPy random state (seed 123).
# NOTE(review): no cell visible here reads `rng` — confirm it is still needed.
rng = np.random.RandomState(123)


def readCsvFile(fileName):
    """Read a numeric text file and return its values as one column.

    Args:
        fileName: Path of the file handed to ``np.loadtxt``.

    Returns:
        A 2-D array of shape ``(n_values, 1)`` holding every parsed value.
    """
    values = np.loadtxt(fileName)
    return values.reshape(-1, 1)


def getTrainingTestData():
    """Load the Snelson input/output files and split them into train and test.

    Rows whose index is a multiple of 4 (0, 4, 8, ...) form the training set;
    every other row goes to the test set.

    Returns:
        Tuple ``(Xtrain, Ytrain, Xtest, Ytest)`` of single-column 2-D arrays.
    """
    inputs = readCsvFile("data/snelson_train_inputs.dat")
    outputs = readCsvFile("data/snelson_train_outputs.dat")

    # Row positions are derived from the inputs only, as in a plain index loop.
    row_ids = np.arange(inputs.shape[0])
    train_rows = row_ids[row_ids % 4 == 0]
    test_rows = row_ids[row_ids % 4 != 0]

    return (
        inputs[train_rows, :],
        outputs[train_rows, :],
        inputs[test_rows, :],
        outputs[test_rows, :],
    )


# Load the split, then reorder the training pairs so the inputs are ascending.
X, y, Xtest, ytest = getTrainingTestData()
inds = jnp.argsort(X[:, 0])
X, y = X[inds], y[inds]

## Sanity check: SGPR with homoskedastic noise

In [None]:
# Initial inducing inputs: every 5th training input.
Z = X.copy()[::5]
train_data = jgp.Dataset(X=X, Y=y)
kernel = jgp.kernels.RBF(active_dims=tuple(range(X.shape[-1])))
mean = jgp.means.Quadratic(input_dim=X.shape[-1])
# The heteroskedastic SGPR class is paired with a plain Gaussian likelihood
# here; no per-point noise variances are passed in.
model = HeteroskedasticSGPR(
    train_data=train_data,
    gprior=jgp.GPrior(kernel=kernel, mean_function=mean),
    likelihood=jgp.likelihoods.Gaussian(),
    inducing_points=Z,
)

# Only the second value returned by `initialise` is kept; it is applied to the
# optimised parameters below.
_, constrain_trans, _ = jgp.initialise(model)

# Alternative call that passes `fixed_params` to `train_model`.
# NOTE(review): presumably this keeps the inducing points at Z during training — confirm.
# soln = train_model(model, fixed_params={"inducing_points": Z})
soln = train_model(model)
posterior = model.posterior()
# NOTE(review): presumably maps the optimiser's parameters back to their constrained space — confirm.
final_params = constrain_trans(soln.params)
# print(final_params)

In [None]:
# print params' shape info
params_container = copy.deepcopy(soln.params)
params_container = jax.tree_map(lambda v: v.shape, params_container)
print(params_container)

In [None]:
# Predict on an evenly spaced grid of 100 points in [-3, 10].
Xtest = jnp.linspace(-3, 10, 100)
pred_mean, pred_var = posterior.predict_f(Xtest, final_params, full_cov=False)

# Training data, learned inducing inputs (drawn along y = 0) and predictive mean.
plt.plot(X, y, "o", color="k", markersize=2)
plt.plot(final_params["inducing_points"], np.zeros_like(Z), "x", color="tab:red")
plt.plot(Xtest, pred_mean, color="tab:orange", linewidth=2)

# Shade the band of two predictive standard deviations around the mean.
two_std = 2 * np.sqrt(pred_var.squeeze())
plt.fill_between(
    Xtest.squeeze(),
    pred_mean.squeeze() - two_std,
    pred_mean.squeeze() + two_std,
    alpha=0.5,
    color="tab:blue",
)

In [None]:
# Consistency check: the mean/covariance returned by `model.compute_qu` should
# agree with the posterior prediction evaluated at the inducing points.
qu_mean, qu_cov = model.compute_qu(final_params)
f_at_Z_mean, f_at_Z_cov = posterior.predict_f(
    final_params["inducing_points"], final_params, full_cov=True
)
# The mean comparison uses a looser absolute tolerance (1e-3) than the covariance (1e-5).
assert jnp.allclose(qu_mean, f_at_Z_mean, rtol=1e-5, atol=1e-3)
# qu_cov is reshaped to (1, M, M), with M = number of inducing points, before
# the comparison. NOTE(review): assumes predict_f returns a leading axis of size 1 — confirm.
assert jnp.allclose(
    qu_cov.reshape(1, Z.shape[0], Z.shape[0]), f_at_Z_cov, rtol=1e-5, atol=1e-5
)

## Heteroskedastic noise

In [None]:
# Initial inducing inputs: every 5th training input.
Z = X.copy()[::5]
train_data = jgp.Dataset(X=X, Y=y)
kernel = jgp.kernels.RBF(active_dims=tuple(range(X.shape[-1])))
mean = jgp.means.Constant()
# One noise variance per training point: 0.01 for the first half of the rows
# of X and 1.0 for the remaining rows.
noise_variance = jnp.concatenate(
    [
        0.01 * jnp.ones(X.shape[0] // 2),
        1.0 * jnp.ones(X.shape[0] - X.shape[0] // 2),
    ]
)
# The per-point variances are handed to the model through `sigma_sq_user`.
model = HeteroskedasticSGPR(
    train_data=train_data,
    gprior=jgp.GPrior(kernel=kernel, mean_function=mean),
    likelihood=jgp.likelihoods.FixedHeteroskedasticGaussian(),
    sigma_sq_user=noise_variance,
    inducing_points=Z,
)

# Only the second value returned by `initialise` is kept; it is applied to the
# optimised parameters below.
_, constrain_trans, _ = jgp.initialise(model)

# Alternative call that passes `fixed_params` to `train_model`.
# NOTE(review): this variant names X, not Z, as the inducing points — confirm that is intended.
# soln = train_model(model, fixed_params={"inducing_points": X})
soln = train_model(model)
posterior = model.posterior()
final_params = constrain_trans(soln.params)
# Report the optimiser's final objective value.
print("After optimization negative elbo = ", soln.state.fun_val)

In [None]:
# Prediction inputs: the training inputs merged with a 100-point grid over
# [-3, 10], in ascending order.
Xtest = jnp.sort(jnp.concatenate([X[:, 0], jnp.linspace(-3, 10, 100)]))
pred_mean, pred_var = posterior.predict_f(Xtest, final_params, full_cov=False)

# Training data, learned inducing inputs (drawn along y = 0) and predictive mean.
plt.plot(X, y, "o", color="k", markersize=2)
plt.plot(final_params["inducing_points"], np.zeros_like(Z), "x", color="tab:red")
plt.plot(Xtest, pred_mean, color="tab:orange", linewidth=2)

# Shade the band of two predictive standard deviations around the mean.
two_std = 2 * np.sqrt(pred_var.squeeze())
plt.fill_between(
    Xtest.squeeze(),
    pred_mean.squeeze() - two_std,
    pred_mean.squeeze() + two_std,
    alpha=0.5,
    color="tab:blue",
)

## Compare with exact GP

In [None]:
from jaxgp.gpr import GPR

# Exact GP on the same training data, with the same kernel type and constant
# mean as the heteroskedastic SGPR model.
train_data = jgp.Dataset(X=X, Y=y)
kernel = jgp.kernels.RBF(active_dims=tuple(range(X.shape[-1])))
mean = jgp.means.Constant()
# One noise variance per training point: 0.01 for the first half of the rows
# of X and 1.0 for the remaining rows.
noise_variance = jnp.concatenate(
    [
        0.01 * jnp.ones(X.shape[0] // 2),
        1.0 * jnp.ones(X.shape[0] - X.shape[0] // 2),
    ]
)
# Note: this rebinds `model`; the SGPR model object is no longer reachable by that name.
model = GPR(
    train_data=train_data,
    gprior=jgp.GPrior(kernel=kernel, mean_function=mean),
    sigma_sq=noise_variance,
)

In [None]:
# Train the exact GP. This rebinds `constrain_trans` and `soln`, which
# previously referred to the SGPR run.
_, constrain_trans, _ = jgp.initialise(model)
soln = train_model(model)

In [None]:
# Exact-GP results are stored under separate names so the SGPR `final_params`,
# `pred_mean` and `pred_var` stay available for the comparison plot.
final_params_exact_gp = constrain_trans(soln.params)
gp_post = model.posterior(final_params_exact_gp)
# Evaluated on the same `Xtest` inputs used for the SGPR prediction.
pred_mean_exact_gp, pred_var_exact_gp = gp_post.predict_f(Xtest)

In [None]:
# Overlay the exact-GP and SGPR predictions on the training data.
# Every artist gets its own legend label: reusing "Exact GP" / "SGPR" for both
# the mean line and the uncertainty band produced duplicate legend entries,
# and the unlabelled inducing-point markers share a colour with the exact-GP line.
plt.plot(X, y, "o", color="k", markersize=2)
plt.plot(
    final_params["inducing_points"],
    np.zeros_like(Z),
    "x",
    color="tab:red",
    label="SGPR inducing points",
)
plt.plot(
    Xtest,
    pred_mean_exact_gp,
    color="tab:red",
    linewidth=2,
    label="Exact GP mean",
)
plt.plot(Xtest, pred_mean, color="tab:orange", linewidth=2, label="SGPR mean")

# Bands of two predictive standard deviations around each mean.
two_std_exact_gp = 2 * np.sqrt(pred_var_exact_gp.squeeze())
two_std_sgpr = 2 * np.sqrt(pred_var.squeeze())
plt.fill_between(
    Xtest.squeeze(),
    pred_mean_exact_gp.squeeze() - two_std_exact_gp,
    pred_mean_exact_gp.squeeze() + two_std_exact_gp,
    alpha=0.5,
    color="tab:blue",
    label="Exact GP ± 2 std",
)
plt.fill_between(
    Xtest.squeeze(),
    pred_mean.squeeze() - two_std_sgpr,
    pred_mean.squeeze() + two_std_sgpr,
    alpha=0.3,
    color="tab:green",
    label="SGPR ± 2 std",
)
plt.legend()