In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable

from botorch.models import HigherOrderGP
from botorch.optim.fit import fit_gpytorch_torch
from gpytorch.mlls import ExactMarginalLogLikelihood

# sns.set_style("whitegrid")
sns.set_palette("bright")

## Example Test Problem

In [None]:
def generate_sample(pars, noise_sd=0.01, size=32):
    """Draw one noisy test surface ``sin(2*a*x) * cos(0.4*b*y)`` on an integer grid.

    Args:
        pars: indexable pair ``(a, b)``; entries 0 and 1 are read. Presumably a
            length-2 tensor (the notebook passes rows of ``torch.randn(n, 2)``).
        noise_sd: standard deviation of the additive Gaussian noise.
        size: number of x grid points. Note the grid is asymmetric: x runs over
            ``1..size`` while y runs over ``0..size``.

    Returns:
        Tensor of shape ``(size, size + 1)``: the noiseless surface plus
        ``noise_sd * randn`` noise drawn from torch's global RNG.
    """
    # meshgrid with the default "ij" layout: grid_x varies along rows, grid_y along columns.
    grid_x, grid_y = torch.meshgrid(
        torch.arange(1, size + 1),
        torch.arange(0, size + 1),
    )
    wave_x = torch.sin(2. * pars[0] * grid_x)
    wave_y = torch.cos(0.4 * pars[1] * grid_y)
    clean = wave_x * wave_y
    return clean + noise_sd * torch.randn_like(clean)

def generate_data(x, noise_sd=0.01, size=32):
    """Generate one noisy surface per parameter row of ``x`` and stack them.

    Args:
        x: iterable of parameter pairs, each passed to ``generate_sample``; in this
            notebook an ``(n, 2)`` tensor such as ``torch.randn(n, 2)``.
        noise_sd: noise standard deviation forwarded to ``generate_sample``.
        size: grid size forwarded to ``generate_sample`` (same default of 32), so
            callers can vary the resolution without bypassing this helper.

    Returns:
        Tensor of shape ``(len(x), size, size + 1)``.

    Raises:
        RuntimeError: from ``torch.stack`` when ``x`` is empty.
    """
    return torch.stack(
        [generate_sample(pars, noise_sd=noise_sd, size=size) for pars in x]
    )

In [None]:
SEED = 210
torch.manual_seed(SEED)

# Draw one random parameter pair and evaluate the noisy surface, shape (32, 33).
example_surface = generate_sample(torch.randn(2))

fig, ax = plt.subplots(1, 1, figsize=(5, 5), facecolor="w")
surface_image = ax.imshow(example_surface)

# Thin colorbar axis attached to the right edge of the image axis.
colorbar_axis = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)

fig.colorbar(surface_image, cax=colorbar_axis).set_label(size=20, label=r"$f(x,y)$")
fig.savefig("./hogp_example_function.pdf", bbox_inches="tight")

In [None]:
train_x = torch.randn(50, 2)
train_y = generate_data(train_x)

In [None]:
_, ax = plt.subplots(1,1,figsize=(5, 5), facecolor="w")

f = plt.imshow(train_y.var(dim=0))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

plt.colorbar(f, cax=cax).set_label(size=20, label=r"$\mathbb{V}(f(x,y))$", )
plt.savefig("./hogp_example_function.pdf", bbox_inches="tight")

In [None]:
model = HigherOrderGP(train_x, train_y, latent_init="default")
mll = ExactMarginalLogLikelihood(model.likelihood, model)

In [None]:
fit_gpytorch_torch(mll);

In [None]:
test_x = torch.randn(1, 2)
test_y = generate_data(test_x, noise_sd=0.)

In [None]:
nonsmooth_post = model.posterior(test_x)

In [None]:
# Posterior variance of the default-init model; index 0 selects the single test point.
fig, ax = plt.subplots(1, 1, figsize=(5, 5), facecolor="w")
variance_image = ax.imshow(nonsmooth_post.variance[0].detach())

colorbar_axis = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)

fig.colorbar(variance_image, cax=colorbar_axis)
# plt.savefig("./hogp_zhe_variance.pdf", bbox_inches="tight")

In [None]:
plt.plot(model.latent_parameters[0].detach())

In [None]:
plt.plot(model.latent_parameters[1].detach())

In [None]:
model.eval()
true_post = model(test_x)

In [None]:
_, ax = plt.subplots(1,1,figsize=(5, 5), facecolor="w")
f = plt.imshow(true_post.covariance_matrix.detach())

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

plt.colorbar(f, cax=cax)
# plt.savefig("./hogp_zhe_covariance.pdf", bbox_inches="tight")

In [None]:
smooth_model = HigherOrderGP(train_x, train_y, latent_init="gp")
mll = ExactMarginalLogLikelihood(smooth_model.likelihood, smooth_model)

In [None]:
# fit_gpytorch_torch(mll);

In [None]:
smooth_post = smooth_model.posterior(test_x)

In [None]:
_, ax = plt.subplots(1,1,figsize=(5, 5), facecolor="w")
f = plt.imshow(smooth_post.variance[0].detach())

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

plt.colorbar(f, cax=cax)
# plt.savefig("./hogp_smooth_variance.pdf", bbox_inches="tight")

In [None]:
plt.plot(smooth_model.latent_parameters[0].detach())

In [None]:
plt.plot(smooth_model.latent_parameters[1].detach())

In [None]:
smooth_model.eval()
true_post = smooth_model(test_x)

In [None]:
_, ax = plt.subplots(1,1,figsize=(5, 5), facecolor="w")
f = plt.imshow(true_post.covariance_matrix.detach())

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

plt.colorbar(f, cax=cax)
plt.savefig("./hogp_smooth_covariance.pdf", bbox_inches="tight")

In [None]:
true_post.covariance_matrix.shape

In [None]:
f = plt.imshow(true_post.covariance_matrix.detach()[:256, :256])
plt.colorbar(f)

In [None]:
# Empirical covariance of the training surfaces, for comparison with the model covariance.
# Each sample is flattened to one row; the transpose makes every grid location a
# variable, because np.cov treats rows as variables by default (rowvar=True).
# Uses np from the top-level imports: `cov` was previously only imported in a later
# cell, so this cell raised NameError on Restart & Run All.
empirical_cov = np.cov(train_y.reshape(len(train_y), -1).numpy().T)
f = plt.imshow(empirical_cov)
plt.colorbar(f)

In [None]:
from numpy import cov

In [None]:
sns.set_style("whitegrid")

In [None]:
fig, ax = plt.subplots(1, 1, figsize = (5, 5))

plt.plot(smooth_model.latent_parameters[0].detach(), color = "blue", label = "GP Latent 0")
plt.plot(smooth_model.latent_parameters[1].detach(), color = "blue", linestyle="--", label = "GP Latent 1")

plt.plot(model.latent_parameters[0].detach(), color = "orange", label = "Latent 0")
plt.plot(model.latent_parameters[1].detach(), color = "orange", linestyle="--", label = "Latent 1")
plt.legend(fontsize=16, ncol=2, loc="lower center")
plt.xlabel("x", fontsize = 20)
plt.ylabel("Latent", fontsize=20)
plt.ylim((-1.5, 1))
plt.savefig("./hogp_latents.pdf", bbox_inches="tight")