The decoupled samplers have unexpected behavior with this implementation. The goal here
 is to compare it in detail with the original implementation, and debug the problem.

The bug has been identified and fixed.

In [60]:
import torch
from botorch.models import SingleTaskGP
%matplotlib
import matplotlib.pyplot as plt
from gp_sampling import decoupled_sampler


Using matplotlib backend: TkAgg


In [61]:
# generate the train data
train_X = torch.arange(0, 1, 0.1).reshape(-1, 1)
train_Y = train_X * torch.sin(20*train_X)

# Initialize the GP model
model = SingleTaskGP(train_X, train_Y)

# Set GP hyper-parameters
# model.covar_module.outputscale = 0.1  # This is a different scale
model.covar_module.base_kernel.lengthscale = 0.1
model.likelihood.noise = 5e-3

# test points
test_X = torch.arange(0, 1, 0.01).reshape(-1, 1)

# clear previous plots
plt.close()

# posterior summary
with torch.no_grad():
    post_mean = model.posterior(test_X).mean.reshape(-1)
    post_std = torch.sqrt(model.posterior(test_X).variance).reshape(-1)
plt.plot(test_X.reshape(-1), post_mean, label="post_mean", color="green")
plt.fill_between(
    test_X.reshape(-1),
    post_mean - 2 * post_std,
    post_mean + 2 * post_std,
    alpha=0.5
)

# plot several posterior samples
num_samples = 10
sample_shape = torch.Size([num_samples])
with torch.no_grad():
    exact_samples = model.posterior(test_X).rsample(sample_shape=sample_shape).squeeze(-1)
plt.plot(
    test_X.reshape(-1),
    exact_samples.t(),
    label="exact",
    color="blue"
)


[<matplotlib.lines.Line2D at 0x7fbb3b5b5b20>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b5940>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b5d90>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b5760>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b59a0>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b5970>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b58e0>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b50d0>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b5dc0>,
 <matplotlib.lines.Line2D at 0x7fbb3b5b5f70>]

In [62]:
# initialize the decoupled sampler
ds = decoupled_sampler(
    model=model,
    sample_shape=sample_shape,
    num_basis=256
)

# sample from decoupled
with torch.no_grad():
    ds_samples = ds(test_X)

# plot the samples
plt.plot(
    test_X.reshape(-1),
    ds_samples.squeeze(-1).t(),
    label="decoupled",
    color="red"
)


[<matplotlib.lines.Line2D at 0x7fbb3bf79730>,
 <matplotlib.lines.Line2D at 0x7fbb3bf79df0>,
 <matplotlib.lines.Line2D at 0x7fbb3bf79d90>,
 <matplotlib.lines.Line2D at 0x7fbb3bf79640>,
 <matplotlib.lines.Line2D at 0x7fbb3bf79a30>,
 <matplotlib.lines.Line2D at 0x7fbb3bf794c0>,
 <matplotlib.lines.Line2D at 0x7fbb3bf79d00>,
 <matplotlib.lines.Line2D at 0x7fbb3bf797c0>,
 <matplotlib.lines.Line2D at 0x7fbb3bf79c10>,
 <matplotlib.lines.Line2D at 0x7fbb3bf58eb0>]

In [63]:
#plot
plt.grid(True)
plt.legend()
plt.title("torch")
plt.show()