The decoupled samplers perform surprisingly poorly in Thompson sampling, which raises
suspicion of a buggy implementation. In this notebook, we analyze the sample quality of
the decoupled samplers to better understand their behavior.

**Update:** the bug has since been fixed.

In [1]:
import torch
from gp_sampling.decoupled_samplers import decoupled_sampler
from gp_sampling.utils.random_gp import generate_random_gp
%matplotlib
import matplotlib.pyplot as plt
from time import time


Using matplotlib backend: TkAgg


In [2]:
# Compare decoupled-sampler draws against exact posterior draws on a random 1-D GP,
# overlaid on the posterior mean and its +/- 2 std band.
plt.close()

# Seed torch's global RNG so the posterior/decoupled sample draws (and the random GP,
# presumably, if generate_random_gp draws from torch's RNG) are reproducible on Restart & Run All.
torch.manual_seed(0)

num_test = 100  # number of evaluation points on [0, 1]
model = generate_random_gp(dim=1, num_train=10, standardized=False)
sample_count = 5
sample_shape = torch.Size([sample_count])
ds = decoupled_sampler(
    model=model,
    sample_shape=sample_shape,
    num_basis=256
)
with torch.no_grad():
    test_X = torch.linspace(0, 1, num_test).reshape(-1, 1)
    # Reshape to (sample_count, num_test), then transpose so each column is one sample
    # path -- plt.plot draws one line per column of a 2-D array.
    ds_samples = ds(test_X).reshape(sample_count, num_test).detach().t()
    # Build the posterior once and reuse it rather than re-evaluating it three times.
    posterior = model.posterior(test_X)
    exact_samples = posterior.rsample(
        sample_shape=sample_shape).reshape(sample_count, num_test).detach().t()
    post_mean = posterior.mean.reshape(-1)
    post_var = posterior.variance
    post_std = torch.sqrt(post_var).reshape(-1)

fig, ax = plt.subplots()
# Draw the band first and translucently so the same-colored mean line and the
# sample paths remain visible on top of it.
ax.fill_between(
    test_X.reshape(-1),
    post_mean - 2 * post_std,
    post_mean + 2 * post_std,
    color="green",
    alpha=0.2,
    label="post_mean +/- 2 std",
)
# Label only the first line of each group; the remaining lines keep matplotlib's
# underscore-prefixed default labels and are omitted from the legend, so each
# group appears once instead of once per sample.
ds_lines = ax.plot(test_X, ds_samples, color="red")
ds_lines[0].set_label("decoupled")
exact_lines = ax.plot(test_X, exact_samples, color="blue")
exact_lines[0].set_label("exact")
ax.plot(test_X, post_mean, label="post_mean", color="green")
ax.set(
    title="Decoupled vs. exact posterior samples",
    xlabel="x",
    ylabel="f(x)",
)
ax.legend()
ax.grid(True)
plt.show()

