This notebook tests the decoupled samplers on batched GP models (fantasy models), comparing decoupled sample paths against exact posterior samples.

In [1]:
import torch
from gp_sampling.decoupled_samplers import decoupled_sampler
from gp_sampling.utils.random_gp import generate_random_gp
%matplotlib
import matplotlib.pyplot as plt
from botorch.sampling import IIDNormalSampler
from time import time


Using matplotlib backend: TkAgg


In [2]:
# Close the current figure (if any) before this experiment builds a new one.
plt.close()

# Experiment configuration (pure constants; defining them first does not
# touch the random number generator).
num_fant_X = 3  # number of fantasy points, each its own batch element
num_fantasies = 4  # number of fantasy draws passed to IIDNormalSampler
sample_count = 5  # number of sample paths drawn per batch element
sample_shape = torch.Size([sample_count])

# Random 1-d GP with 10 training points (standardized=False).
model = generate_random_gp(dim=1, num_train=10, standardized=False)

# Fantasy inputs of shape num_fant_X x 1 x 1 -- presumably batch x q x d
# following the BoTorch convention; TODO confirm.
fantasy_X = torch.rand(num_fant_X, 1, 1)
fantasy_model = model.fantasize(
    X=fantasy_X,
    sampler=IIDNormalSampler(num_fantasies),
)

# Decoupled sampler over the batched fantasy model, using 256 basis functions.
ds = decoupled_sampler(
    model=fantasy_model,
    sample_shape=sample_shape,
    num_basis=256,
)

In [3]:
with torch.no_grad():
    test_X = torch.linspace(0, 1, 100).reshape(-1, 1)
    # Decoupled samples are sample_shape x num_fantasies x num_fant_X x 100 x 1.
    # Drop the trailing output dim and move sample_shape to the end so that
    # ds_samples[i, j] is a 100 x sample_count matrix (one column per path).
    ds_samples = ds(test_X).squeeze(-1).detach().permute(1, 2, 3, 0)
    # Build the posterior once and reuse it for samples, mean and variance
    # (the original rebuilt it three times).
    posterior = fantasy_model.posterior(test_X)
    # Same layout as ds_samples: sample_shape permuted to the end.
    exact_samples = posterior.rsample(
        sample_shape=sample_shape
    ).squeeze(-1).detach().permute(1, 2, 3, 0)
    post_mean = posterior.mean.squeeze(-1)
    post_var = posterior.variance
    post_std = torch.sqrt(post_var).squeeze(-1)

# squeeze=False keeps axs 2-D even if num_fantasies or num_fant_X is 1,
# so axs[i, j] indexing always works.
fig, axs = plt.subplots(num_fantasies, num_fant_X, squeeze=False)

for i in range(num_fantasies):
    for j in range(num_fant_X):
        ax = axs[i, j]
        ax.plot(test_X, ds_samples[i, j], label="decoupled", color="red")
        ax.plot(test_X, exact_samples[i, j], label="exact", color="blue")
        ax.plot(test_X, post_mean[i, j], label="post_mean", color="green")
        # Translucent +/- 2 std band; an opaque band would hide the green
        # posterior-mean line drawn over it.
        ax.fill_between(
            test_X.reshape(-1),
            post_mean[i, j] - 2 * post_std[i, j],
            post_mean[i, j] + 2 * post_std[i, j],
            color="green",
            alpha=0.2,
        )
        # plt.grid() would only affect the last (current) axes.
        ax.grid(True)

# Every subplot draws sample_count lines per label, so collapse duplicate
# labels and show a single legend for the whole figure.
handles, labels = axs[0, 0].get_legend_handles_labels()
unique = dict(zip(labels, handles))
fig.legend(list(unique.values()), list(unique.keys()))
plt.show()

